From 6c9f8f61560f0eaec574984f1d27b89a316ff96c Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 30 May 2022 12:28:26 +0800 Subject: [PATCH 01/82] apiutil, audit, autoscaling, cache: testify the tests (#5061) ref tikv/pd#4813 Testify the pkg/apiutil, pkg/audit, pkg/autoscaling, pkg/cache tests. Signed-off-by: JmPotato Co-authored-by: ShuNing --- go.mod | 1 + pkg/apiutil/apiutil_test.go | 32 ++- pkg/audit/audit_test.go | 50 +++-- pkg/autoscaling/calculation_test.go | 94 ++++----- pkg/autoscaling/prometheus_test.go | 59 +++--- pkg/cache/cache_test.go | 292 ++++++++++++++-------------- 6 files changed, 244 insertions(+), 284 deletions(-) diff --git a/go.mod b/go.mod index 875eb6efe7c..5eeb2656499 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/sasha-s/go-deadlock v0.2.0 github.com/spf13/cobra v1.0.0 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.7.0 github.com/swaggo/http-swagger v0.0.0-20200308142732-58ac5e232fba github.com/swaggo/swag v1.6.6-0.20200529100950-7c765ddd0476 github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965 diff --git a/pkg/apiutil/apiutil_test.go b/pkg/apiutil/apiutil_test.go index ab418f3f5cf..94cf96c3f26 100644 --- a/pkg/apiutil/apiutil_test.go +++ b/pkg/apiutil/apiutil_test.go @@ -20,19 +20,11 @@ import ( "net/http/httptest" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/unrolled/render" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testUtilSuite{}) - -type testUtilSuite struct{} - -func (s *testUtilSuite) TestJsonRespondErrorOk(c *C) { +func TestJsonRespondErrorOk(t *testing.T) { rd := render.New(render.Options{ IndentJSON: true, }) @@ -41,15 +33,15 @@ func (s *testUtilSuite) TestJsonRespondErrorOk(c *C) { var input map[string]string output := map[string]string{"zone": "cn", "host": "local"} err := ReadJSONRespondError(rd, response, body, &input) - c.Assert(err, IsNil) - c.Assert(input["zone"], Equals, output["zone"]) - c.Assert(input["host"], Equals, output["host"]) + require.NoError(t, err) + require.Equal(t, output["zone"], input["zone"]) + require.Equal(t, output["host"], input["host"]) result := response.Result() defer result.Body.Close() - c.Assert(result.StatusCode, Equals, 200) + require.Equal(t, 200, result.StatusCode) } -func (s *testUtilSuite) TestJsonRespondErrorBadInput(c *C) { +func TestJsonRespondErrorBadInput(t *testing.T) { rd := render.New(render.Options{ IndentJSON: true, }) @@ -57,20 +49,18 @@ func (s *testUtilSuite) TestJsonRespondErrorBadInput(c *C) { body := io.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\", \"host\":\"local\"}")) var input []string err := ReadJSONRespondError(rd, response, body, &input) - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "json: cannot unmarshal object into Go value of type []string") + require.EqualError(t, err, "json: cannot unmarshal object into Go value of type []string") result := response.Result() defer result.Body.Close() - c.Assert(result.StatusCode, Equals, 400) + require.Equal(t, 400, result.StatusCode) { body := io.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\",")) var input []string err := ReadJSONRespondError(rd, response, body, &input) - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "unexpected end of JSON input") + require.EqualError(t, err, "unexpected end of JSON input") result := response.Result() defer result.Body.Close() - c.Assert(result.StatusCode, Equals, 400) + require.Equal(t, 400, result.StatusCode) } } diff --git a/pkg/audit/audit_test.go 
b/pkg/audit/audit_test.go index b5951231f3d..66df5298b8b 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -24,32 +24,22 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/requestutil" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testAuditSuite{}) - -type testAuditSuite struct { -} - -func (s *testAuditSuite) TestLabelMatcher(c *C) { +func TestLabelMatcher(t *testing.T) { matcher := &LabelMatcher{"testSuccess"} labels1 := &BackendLabels{Labels: []string{"testFail", "testSuccess"}} - c.Assert(matcher.Match(labels1), Equals, true) - + require.True(t, matcher.Match(labels1)) labels2 := &BackendLabels{Labels: []string{"testFail"}} - c.Assert(matcher.Match(labels2), Equals, false) + require.False(t, matcher.Match(labels2)) } -func (s *testAuditSuite) TestPrometheusHistogramBackend(c *C) { +func TestPrometheusHistogramBackend(t *testing.T) { serviceAuditHistogramTest := prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "pd", @@ -70,44 +60,48 @@ func (s *testAuditSuite) TestPrometheusHistogramBackend(c *C) { info.ServiceLabel = "test" info.Component = "user1" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - c.Assert(backend.ProcessHTTPRequest(req), Equals, false) + require.False(t, backend.ProcessHTTPRequest(req)) endTime := time.Now().Unix() + 20 req = req.WithContext(requestutil.WithEndTime(req.Context(), endTime)) - c.Assert(backend.ProcessHTTPRequest(req), Equals, true) - c.Assert(backend.ProcessHTTPRequest(req), Equals, true) + require.True(t, backend.ProcessHTTPRequest(req)) + require.True(t, backend.ProcessHTTPRequest(req)) info.Component = "user2" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - c.Assert(backend.ProcessHTTPRequest(req), Equals, true) + require.True(t, backend.ProcessHTTPRequest(req)) // For test, sleep time needs longer than the push interval time.Sleep(1 * time.Second) req, _ = http.NewRequest("GET", ts.URL, nil) resp, err := http.DefaultClient.Do(req) - c.Assert(err, IsNil) + require.NoError(t, err) defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - c.Assert(strings.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",method=\"HTTP\",service=\"test\"} 2"), Equals, true) - c.Assert(strings.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",method=\"HTTP\",service=\"test\"} 1"), Equals, true) + require.Contains(t, output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",method=\"HTTP\",service=\"test\"} 2") + require.Contains(t, output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",method=\"HTTP\",service=\"test\"} 1") } -func (s *testAuditSuite) TestLocalLogBackendUsingFile(c *C) { +func TestLocalLogBackendUsingFile(t *testing.T) { backend := NewLocalLogBackend(true) fname := initLog() defer os.Remove(fname) req, _ := http.NewRequest("GET", "http://127.0.0.1:2379/test?test=test", strings.NewReader("testBody")) - c.Assert(backend.ProcessHTTPRequest(req), Equals, false) + require.False(t, backend.ProcessHTTPRequest(req)) info := requestutil.GetRequestInfo(req) req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - c.Assert(backend.ProcessHTTPRequest(req), Equals, true) + require.True(t, 
backend.ProcessHTTPRequest(req)) b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) - c.Assert(output[3], Equals, fmt.Sprintf(" [\"Audit Log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+ - "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", - time.Unix(info.StartTimeStamp, 0).String())) + require.Equal( + t, + fmt.Sprintf(" [\"Audit Log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+ + "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", + time.Unix(info.StartTimeStamp, 0).String()), + output[3], + ) } func BenchmarkLocalLogAuditUsingTerminal(b *testing.B) { diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index 958b44926ef..8334c96ecc5 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -19,37 +19,21 @@ import ( "encoding/json" "fmt" "math" + "reflect" "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&calculationTestSuite{}) - -type calculationTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *calculationTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *calculationTestSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { +func TestGetScaledTiKVGroups(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // case1 indicates the tikv cluster with not any group existed - case1 := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + case1 := mockcluster.NewCluster(ctx, config.NewTestOptions()) case1.AddLabelsStore(1, 1, map[string]string{}) case1.AddLabelsStore(2, 1, map[string]string{ "foo": "bar", @@ -59,7 +43,7 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { }) // case2 indicates the tikv cluster with 1 auto-scaling group existed - case2 := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + case2 := mockcluster.NewCluster(ctx, config.NewTestOptions()) case2.AddLabelsStore(1, 1, map[string]string{}) case2.AddLabelsStore(2, 1, map[string]string{ groupLabelKey: fmt.Sprintf("%s-%s-0", autoScalingGroupLabelKeyPrefix, TiKV.String()), @@ -71,7 +55,7 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { }) // case3 indicates the tikv cluster with other group existed - case3 := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + case3 := mockcluster.NewCluster(ctx, config.NewTestOptions()) case3.AddLabelsStore(1, 1, map[string]string{}) case3.AddLabelsStore(2, 1, map[string]string{ groupLabelKey: "foo", @@ -80,12 +64,12 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { groupLabelKey: "foo", }) - testcases := []struct { + testCases := []struct { name string informer core.StoreSetInformer healthyInstances []instance expectedPlan []*Plan - errChecker Checker + noError bool }{ { name: "no scaled tikv group", @@ -105,7 +89,7 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { }, }, expectedPlan: nil, - errChecker: IsNil, + noError: true, }, { name: "exist 1 scaled tikv group", @@ -135,7 +119,7 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { }, }, }, - errChecker: IsNil, + noError: 
true, }, { name: "exist 1 tikv scaled group with inconsistency healthy instances", @@ -155,7 +139,7 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { }, }, expectedPlan: nil, - errChecker: NotNil, + noError: false, }, { name: "exist 1 tikv scaled group with less healthy instances", @@ -181,7 +165,7 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { }, }, }, - errChecker: IsNil, + noError: true, }, { name: "existed other tikv group", @@ -201,18 +185,18 @@ func (s *calculationTestSuite) TestGetScaledTiKVGroups(c *C) { }, }, expectedPlan: nil, - errChecker: IsNil, + noError: true, }, } - for _, testcase := range testcases { - c.Log(testcase.name) - plans, err := getScaledTiKVGroups(testcase.informer, testcase.healthyInstances) - if testcase.expectedPlan == nil { - c.Assert(plans, HasLen, 0) - c.Assert(err, testcase.errChecker) + for _, testCase := range testCases { + t.Log(testCase.name) + plans, err := getScaledTiKVGroups(testCase.informer, testCase.healthyInstances) + if testCase.expectedPlan == nil { + require.Len(t, plans, 0) + require.Equal(t, testCase.noError, err == nil) } else { - c.Assert(plans, DeepEquals, testcase.expectedPlan) + require.True(t, reflect.DeepEqual(testCase.expectedPlan, plans)) } } } @@ -228,7 +212,7 @@ func (q *mockQuerier) Query(options *QueryOptions) (QueryResult, error) { return result, nil } -func (s *calculationTestSuite) TestGetTotalCPUUseTime(c *C) { +func TestGetTotalCPUUseTime(t *testing.T) { querier := &mockQuerier{} instances := []instance{ { @@ -246,10 +230,10 @@ func (s *calculationTestSuite) TestGetTotalCPUUseTime(c *C) { } totalCPUUseTime, _ := getTotalCPUUseTime(querier, TiDB, instances, time.Now(), 0) expected := mockResultValue * float64(len(instances)) - c.Assert(math.Abs(expected-totalCPUUseTime) < 1e-6, IsTrue) + require.True(t, math.Abs(expected-totalCPUUseTime) < 1e-6) } -func (s *calculationTestSuite) TestGetTotalCPUQuota(c *C) { +func TestGetTotalCPUQuota(t *testing.T) { querier := &mockQuerier{} instances := []instance{ { @@ -267,10 +251,10 @@ func (s *calculationTestSuite) TestGetTotalCPUQuota(c *C) { } totalCPUQuota, _ := getTotalCPUQuota(querier, TiDB, instances, time.Now()) expected := uint64(mockResultValue * float64(len(instances)*milliCores)) - c.Assert(totalCPUQuota, Equals, expected) + require.Equal(t, expected, totalCPUQuota) } -func (s *calculationTestSuite) TestScaleOutGroupLabel(c *C) { +func TestScaleOutGroupLabel(t *testing.T) { var jsonStr = []byte(` { "rules":[ @@ -304,14 +288,14 @@ func (s *calculationTestSuite) TestScaleOutGroupLabel(c *C) { }`) strategy := &Strategy{} err := json.Unmarshal(jsonStr, strategy) - c.Assert(err, IsNil) + require.NoError(t, err) plan := findBestGroupToScaleOut(strategy, nil, TiKV) - c.Assert(plan.Labels["specialUse"], Equals, "hotRegion") + require.Equal(t, "hotRegion", plan.Labels["specialUse"]) plan = findBestGroupToScaleOut(strategy, nil, TiDB) - c.Assert(plan.Labels["specialUse"], Equals, "") + require.Equal(t, "", plan.Labels["specialUse"]) } -func (s *calculationTestSuite) TestStrategyChangeCount(c *C) { +func TestStrategyChangeCount(t *testing.T) { var count uint64 = 2 strategy := &Strategy{ Rules: []*Rule{ @@ -336,7 +320,9 @@ func (s *calculationTestSuite) TestStrategyChangeCount(c *C) { } // tikv cluster with 1 auto-scaling group existed - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) 
cluster.AddLabelsStore(1, 1, map[string]string{}) cluster.AddLabelsStore(2, 1, map[string]string{ groupLabelKey: fmt.Sprintf("%s-%s-0", autoScalingGroupLabelKeyPrefix, TiKV.String()), @@ -357,21 +343,21 @@ func (s *calculationTestSuite) TestStrategyChangeCount(c *C) { // exist two scaled TiKVs and plan does not change due to the limit of resource count groups, err := getScaledTiKVGroups(cluster, instances) - c.Assert(err, IsNil) + require.NoError(t, err) plans := calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - c.Assert(plans[0].Count, Equals, uint64(2)) + require.Equal(t, uint64(2), plans[0].Count) // change the resource count to 3 and plan increates one more tikv groups, err = getScaledTiKVGroups(cluster, instances) - c.Assert(err, IsNil) + require.NoError(t, err) *strategy.Resources[0].Count = 3 plans = calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - c.Assert(plans[0].Count, Equals, uint64(3)) + require.Equal(t, uint64(3), plans[0].Count) // change the resource count to 1 and plan decreases to 1 tikv due to the limit of resource count groups, err = getScaledTiKVGroups(cluster, instances) - c.Assert(err, IsNil) + require.NoError(t, err) *strategy.Resources[0].Count = 1 plans = calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - c.Assert(plans[0].Count, Equals, uint64(1)) + require.Equal(t, uint64(1), plans[0].Count) } diff --git a/pkg/autoscaling/prometheus_test.go b/pkg/autoscaling/prometheus_test.go index 284d780e297..2c541446d2b 100644 --- a/pkg/autoscaling/prometheus_test.go +++ b/pkg/autoscaling/prometheus_test.go @@ -24,10 +24,11 @@ import ( "net/http" "net/url" "strings" + "testing" "time" - . "github.com/pingcap/check" promClient "github.com/prometheus/client_golang/api" + "github.com/stretchr/testify/require" ) const ( @@ -41,8 +42,6 @@ const ( instanceCount = 3 ) -var _ = Suite(&testPrometheusQuerierSuite{}) - var podNameTemplate = map[ComponentType]string{ TiDB: mockTiDBInstanceNamePattern, TiKV: mockTiKVInstanceNamePattern, @@ -76,8 +75,6 @@ var podAddresses = map[ComponentType][]string{ TiKV: generateAddresses(TiKV), } -type testPrometheusQuerierSuite struct{} - // For building mock data only type response struct { Status string `json:"status"` @@ -183,7 +180,7 @@ func (c *normalClient) Do(_ context.Context, req *http.Request) (response *http. return } -func (s *testPrometheusQuerierSuite) TestRetrieveCPUMetrics(c *C) { +func TestRetrieveCPUMetrics(t *testing.T) { client := &normalClient{ mockData: make(map[string]*response), } @@ -194,15 +191,15 @@ func (s *testPrometheusQuerierSuite) TestRetrieveCPUMetrics(c *C) { for _, metric := range metrics { options := NewQueryOptions(component, metric, addresses[:len(addresses)-1], time.Now(), mockDuration) result, err := querier.Query(options) - c.Assert(err, IsNil) + require.NoError(t, err) for i := 0; i < len(addresses)-1; i++ { value, ok := result[addresses[i]] - c.Assert(ok, IsTrue) - c.Assert(math.Abs(value-mockResultValue) < 1e-6, IsTrue) + require.True(t, ok) + require.True(t, math.Abs(value-mockResultValue) < 1e-6) } _, ok := result[addresses[len(addresses)-1]] - c.Assert(ok, IsFalse) + require.False(t, ok) } } } @@ -226,13 +223,13 @@ func (c *emptyResponseClient) Do(_ context.Context, req *http.Request) (r *http. 
return } -func (s *testPrometheusQuerierSuite) TestEmptyResponse(c *C) { +func TestEmptyResponse(t *testing.T) { client := &emptyResponseClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - c.Assert(result, IsNil) - c.Assert(err, NotNil) + require.Nil(t, result) + require.Error(t, err) } type errorHTTPStatusClient struct{} @@ -252,13 +249,13 @@ func (c *errorHTTPStatusClient) Do(_ context.Context, req *http.Request) (r *htt return } -func (s *testPrometheusQuerierSuite) TestErrorHTTPStatus(c *C) { +func TestErrorHTTPStatus(t *testing.T) { client := &errorHTTPStatusClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - c.Assert(result, IsNil) - c.Assert(err, NotNil) + require.Nil(t, result) + require.Error(t, err) } type errorPrometheusStatusClient struct{} @@ -276,17 +273,17 @@ func (c *errorPrometheusStatusClient) Do(_ context.Context, req *http.Request) ( return } -func (s *testPrometheusQuerierSuite) TestErrorPrometheusStatus(c *C) { +func TestErrorPrometheusStatus(t *testing.T) { client := &errorPrometheusStatusClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - c.Assert(result, IsNil) - c.Assert(err, NotNil) + require.Nil(t, result) + require.Error(t, err) } -func (s *testPrometheusQuerierSuite) TestGetInstanceNameFromAddress(c *C) { - testcases := []struct { +func TestGetInstanceNameFromAddress(t *testing.T) { + testCases := []struct { address string expectedInstanceName string }{ @@ -311,18 +308,18 @@ func (s *testPrometheusQuerierSuite) TestGetInstanceNameFromAddress(c *C) { expectedInstanceName: "", }, } - for _, testcase := range testcases { - instanceName, err := getInstanceNameFromAddress(testcase.address) - if testcase.expectedInstanceName == "" { - c.Assert(err, NotNil) + for _, testCase := range testCases { + instanceName, err := getInstanceNameFromAddress(testCase.address) + if testCase.expectedInstanceName == "" { + require.Error(t, err) } else { - c.Assert(instanceName, Equals, testcase.expectedInstanceName) + require.Equal(t, testCase.expectedInstanceName, instanceName) } } } -func (s *testPrometheusQuerierSuite) TestGetDurationExpression(c *C) { - testcases := []struct { +func TestGetDurationExpression(t *testing.T) { + testCases := []struct { duration time.Duration expectedExpression string }{ @@ -344,8 +341,8 @@ func (s *testPrometheusQuerierSuite) TestGetDurationExpression(c *C) { }, } - for _, testcase := range testcases { - expression := getDurationExpression(testcase.duration) - c.Assert(expression, Equals, testcase.expectedExpression) + for _, testCase := range testCases { + expression := getDurationExpression(testCase.duration) + require.Equal(t, testCase.expectedExpression, expression) } } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index df508e39cd1..bd633fef525 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -16,53 +16,45 @@ package cache import ( "context" + "reflect" "sort" "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func TestCore(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRegionCacheSuite{}) - -type testRegionCacheSuite struct { -} - -func (s *testRegionCacheSuite) TestExpireRegionCache(c *C) { +func TestExpireRegionCache(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() cache := NewIDTTL(ctx, time.Second, 2*time.Second) // Test Pop cache.PutWithTTL(9, "9", 5*time.Second) cache.PutWithTTL(10, "10", 5*time.Second) - c.Assert(cache.Len(), Equals, 2) + require.Equal(t, 2, cache.Len()) k, v, success := cache.pop() - c.Assert(success, IsTrue) - c.Assert(cache.Len(), Equals, 1) + require.True(t, success) + require.Equal(t, 1, cache.Len()) k2, v2, success := cache.pop() - c.Assert(success, IsTrue) + require.True(t, success) // we can't ensure the order which the key/value pop from cache, so we save into a map kvMap := map[uint64]string{ 9: "9", 10: "10", } expV, ok := kvMap[k.(uint64)] - c.Assert(ok, IsTrue) - c.Assert(expV, Equals, v.(string)) + require.True(t, ok) + require.Equal(t, expV, v.(string)) expV, ok = kvMap[k2.(uint64)] - c.Assert(ok, IsTrue) - c.Assert(expV, Equals, v2.(string)) + require.True(t, ok) + require.Equal(t, expV, v2.(string)) cache.PutWithTTL(11, "11", 1*time.Second) time.Sleep(5 * time.Second) k, v, success = cache.pop() - c.Assert(success, IsFalse) - c.Assert(k, IsNil) - c.Assert(v, IsNil) + require.False(t, success) + require.Nil(t, k) + require.Nil(t, v) // Test Get cache.PutWithTTL(1, 1, 1*time.Second) @@ -70,50 +62,50 @@ func (s *testRegionCacheSuite) TestExpireRegionCache(c *C) { cache.PutWithTTL(3, 3.0, 5*time.Second) value, ok := cache.Get(1) - c.Assert(ok, IsTrue) - c.Assert(value, Equals, 1) + require.True(t, ok) + require.Equal(t, 1, value) value, ok = cache.Get(2) - c.Assert(ok, IsTrue) - c.Assert(value, Equals, "v2") + require.True(t, ok) + require.Equal(t, "v2", value) value, ok = cache.Get(3) - c.Assert(ok, IsTrue) - c.Assert(value, Equals, 3.0) + require.True(t, ok) + require.Equal(t, 3.0, value) - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) - c.Assert(sortIDs(cache.GetAllID()), DeepEquals, []uint64{1, 2, 3}) + require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{1, 2, 3})) time.Sleep(2 * time.Second) value, ok = cache.Get(1) - c.Assert(ok, IsFalse) - c.Assert(value, IsNil) + require.False(t, ok) + require.Nil(t, value) value, ok = cache.Get(2) - c.Assert(ok, IsTrue) - c.Assert(value, Equals, "v2") + require.True(t, ok) + require.Equal(t, "v2", value) value, ok = cache.Get(3) - c.Assert(ok, IsTrue) - c.Assert(value, Equals, 3.0) + require.True(t, ok) + require.Equal(t, 3.0, value) - c.Assert(cache.Len(), Equals, 2) - c.Assert(sortIDs(cache.GetAllID()), DeepEquals, []uint64{2, 3}) + require.Equal(t, 2, cache.Len()) + require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{2, 3})) cache.Remove(2) value, ok = cache.Get(2) - c.Assert(ok, IsFalse) - c.Assert(value, IsNil) + require.False(t, ok) + require.Nil(t, value) value, ok = cache.Get(3) - c.Assert(ok, IsTrue) - c.Assert(value, Equals, 3.0) + require.True(t, ok) + require.Equal(t, 3.0, value) - c.Assert(cache.Len(), Equals, 1) - c.Assert(sortIDs(cache.GetAllID()), DeepEquals, []uint64{3}) + require.Equal(t, 1, cache.Len()) + require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{3})) } func sortIDs(ids []uint64) []uint64 { @@ -122,7 +114,7 @@ func sortIDs(ids []uint64) []uint64 { return ids } -func (s *testRegionCacheSuite) 
TestLRUCache(c *C) { +func TestLRUCache(t *testing.T) { cache := newLRU(3) cache.Put(1, "1") @@ -130,173 +122,173 @@ func (s *testRegionCacheSuite) TestLRUCache(c *C) { cache.Put(3, "3") val, ok := cache.Get(3) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "3") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "3")) val, ok = cache.Get(2) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "2") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "2")) val, ok = cache.Get(1) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "1") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "1")) - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) cache.Put(4, "4") - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) val, ok = cache.Get(3) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(1) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "1") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "1")) val, ok = cache.Get(2) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "2") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "2")) val, ok = cache.Get(4) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "4") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "4")) - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) val, ok = cache.Peek(1) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "1") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "1")) elems := cache.Elems() - c.Assert(elems, HasLen, 3) - c.Assert(elems[0].Value, DeepEquals, "4") - c.Assert(elems[1].Value, DeepEquals, "2") - c.Assert(elems[2].Value, DeepEquals, "1") + require.Len(t, elems, 3) + require.True(t, reflect.DeepEqual(elems[0].Value, "4")) + require.True(t, reflect.DeepEqual(elems[1].Value, "2")) + require.True(t, reflect.DeepEqual(elems[2].Value, "1")) cache.Remove(1) cache.Remove(2) cache.Remove(4) - c.Assert(cache.Len(), Equals, 0) + require.Equal(t, 0, cache.Len()) val, ok = cache.Get(1) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(2) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(3) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(4) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) } -func (s *testRegionCacheSuite) TestFifoCache(c *C) { +func TestFifoCache(t *testing.T) { cache := NewFIFO(3) cache.Put(1, "1") cache.Put(2, "2") cache.Put(3, "3") - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) cache.Put(4, "4") - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) elems := cache.Elems() - c.Assert(elems, HasLen, 3) - c.Assert(elems[0].Value, DeepEquals, "2") - c.Assert(elems[1].Value, DeepEquals, "3") - c.Assert(elems[2].Value, DeepEquals, "4") + require.Len(t, elems, 3) + require.True(t, reflect.DeepEqual(elems[0].Value, "2")) + require.True(t, reflect.DeepEqual(elems[1].Value, "3")) + require.True(t, reflect.DeepEqual(elems[2].Value, "4")) elems = cache.FromElems(3) - c.Assert(elems, HasLen, 1) - c.Assert(elems[0].Value, DeepEquals, "4") + require.Len(t, elems, 1) + require.True(t, reflect.DeepEqual(elems[0].Value, "4")) cache.Remove() cache.Remove() cache.Remove() - c.Assert(cache.Len(), Equals, 0) + require.Equal(t, 0, cache.Len()) } -func (s 
*testRegionCacheSuite) TestTwoQueueCache(c *C) { +func TestTwoQueueCache(t *testing.T) { cache := newTwoQueue(3) cache.Put(1, "1") cache.Put(2, "2") cache.Put(3, "3") val, ok := cache.Get(3) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "3") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "3")) val, ok = cache.Get(2) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "2") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "2")) val, ok = cache.Get(1) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "1") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "1")) - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) cache.Put(4, "4") - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) val, ok = cache.Get(3) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(1) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "1") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "1")) val, ok = cache.Get(2) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "2") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "2")) val, ok = cache.Get(4) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "4") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "4")) - c.Assert(cache.Len(), Equals, 3) + require.Equal(t, 3, cache.Len()) val, ok = cache.Peek(1) - c.Assert(ok, IsTrue) - c.Assert(val, DeepEquals, "1") + require.True(t, ok) + require.True(t, reflect.DeepEqual(val, "1")) elems := cache.Elems() - c.Assert(elems, HasLen, 3) - c.Assert(elems[0].Value, DeepEquals, "4") - c.Assert(elems[1].Value, DeepEquals, "2") - c.Assert(elems[2].Value, DeepEquals, "1") + require.Len(t, elems, 3) + require.True(t, reflect.DeepEqual(elems[0].Value, "4")) + require.True(t, reflect.DeepEqual(elems[1].Value, "2")) + require.True(t, reflect.DeepEqual(elems[2].Value, "1")) cache.Remove(1) cache.Remove(2) cache.Remove(4) - c.Assert(cache.Len(), Equals, 0) + require.Equal(t, 0, cache.Len()) val, ok = cache.Get(1) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(2) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(3) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) val, ok = cache.Get(4) - c.Assert(ok, IsFalse) - c.Assert(val, IsNil) + require.False(t, ok) + require.Nil(t, val) } var _ PriorityQueueItem = PriorityQueueItemTest(0) @@ -307,54 +299,54 @@ func (pq PriorityQueueItemTest) ID() uint64 { return uint64(pq) } -func (s *testRegionCacheSuite) TestPriorityQueue(c *C) { +func TestPriorityQueue(t *testing.T) { testData := []PriorityQueueItemTest{0, 1, 2, 3, 4, 5} pq := NewPriorityQueue(0) - c.Assert(pq.Put(1, testData[1]), IsFalse) + require.False(t, pq.Put(1, testData[1])) // it will have priority-value pair as 1-1 2-2 3-3 pq = NewPriorityQueue(3) - c.Assert(pq.Put(1, testData[1]), IsTrue) - c.Assert(pq.Put(2, testData[2]), IsTrue) - c.Assert(pq.Put(3, testData[4]), IsTrue) - c.Assert(pq.Put(5, testData[4]), IsTrue) - c.Assert(pq.Put(5, testData[5]), IsFalse) - c.Assert(pq.Put(3, testData[3]), IsTrue) - c.Assert(pq.Put(3, testData[3]), IsTrue) - c.Assert(pq.Get(4), IsNil) - c.Assert(pq.Len(), Equals, 3) + require.True(t, pq.Put(1, testData[1])) + require.True(t, pq.Put(2, testData[2])) + require.True(t, pq.Put(3, testData[4])) + require.True(t, pq.Put(5, testData[4])) + require.False(t, pq.Put(5, 
testData[5])) + require.True(t, pq.Put(3, testData[3])) + require.True(t, pq.Put(3, testData[3])) + require.Nil(t, pq.Get(4)) + require.Equal(t, 3, pq.Len()) // case1 test getAll, the highest element should be the first entries := pq.Elems() - c.Assert(entries, HasLen, 3) - c.Assert(entries[0].Priority, Equals, 1) - c.Assert(entries[0].Value, Equals, testData[1]) - c.Assert(entries[1].Priority, Equals, 2) - c.Assert(entries[1].Value, Equals, testData[2]) - c.Assert(entries[2].Priority, Equals, 3) - c.Assert(entries[2].Value, Equals, testData[3]) + require.Len(t, entries, 3) + require.Equal(t, 1, entries[0].Priority) + require.Equal(t, testData[1], entries[0].Value) + require.Equal(t, 2, entries[1].Priority) + require.Equal(t, testData[2], entries[1].Value) + require.Equal(t, 3, entries[2].Priority) + require.Equal(t, testData[3], entries[2].Value) // case2 test remove the high element, and the second element should be the first pq.Remove(uint64(1)) - c.Assert(pq.Get(1), IsNil) - c.Assert(pq.Len(), Equals, 2) + require.Nil(t, pq.Get(1)) + require.Equal(t, 2, pq.Len()) entry := pq.Peek() - c.Assert(entry.Priority, Equals, 2) - c.Assert(entry.Value, Equals, testData[2]) + require.Equal(t, 2, entry.Priority) + require.Equal(t, testData[2], entry.Value) // case3 update 3's priority to highest pq.Put(-1, testData[3]) entry = pq.Peek() - c.Assert(entry.Priority, Equals, -1) - c.Assert(entry.Value, Equals, testData[3]) + require.Equal(t, -1, entry.Priority) + require.Equal(t, testData[3], entry.Value) pq.Remove(entry.Value.ID()) - c.Assert(pq.Peek().Value, Equals, testData[2]) - c.Assert(pq.Len(), Equals, 1) + require.Equal(t, testData[2], pq.Peek().Value) + require.Equal(t, 1, pq.Len()) // case4 remove all element pq.Remove(uint64(2)) - c.Assert(pq.Len(), Equals, 0) - c.Assert(pq.items, HasLen, 0) - c.Assert(pq.Peek(), IsNil) - c.Assert(pq.Tail(), IsNil) + require.Equal(t, 0, pq.Len()) + require.Len(t, pq.items, 0) + require.Nil(t, pq.Peek()) + require.Nil(t, pq.Tail()) } From 109719ff0875608f1a09551aa939b8b2cb2d1251 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 30 May 2022 14:38:27 +0800 Subject: [PATCH 02/82] codec, encryption, errs: testify the tests (#5066) ref tikv/pd#4813 Testify the pkg/codec, pkg/encryption, pkg/errs tests. Signed-off-by: JmPotato --- pkg/codec/codec_test.go | 28 +++---- pkg/encryption/config_test.go | 29 ++++--- pkg/encryption/crypter_test.go | 104 ++++++++++++-------------- pkg/encryption/master_key_test.go | 81 ++++++++++---------- pkg/encryption/region_crypter_test.go | 89 +++++++++++----------- pkg/errs/errs_test.go | 38 ++++------ 6 files changed, 168 insertions(+), 201 deletions(-) diff --git a/pkg/codec/codec_test.go b/pkg/codec/codec_test.go index 1c5aa0700d5..cd73c1da0cc 100644 --- a/pkg/codec/codec_test.go +++ b/pkg/codec/codec_test.go @@ -17,39 +17,31 @@ package codec import ( "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func TestTable(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testCodecSuite{}) - -type testCodecSuite struct{} - -func (s *testCodecSuite) TestDecodeBytes(c *C) { +func TestDecodeBytes(t *testing.T) { key := "abcdefghijklmnopqrstuvwxyz" for i := 0; i < len(key); i++ { _, k, err := DecodeBytes(EncodeBytes([]byte(key[:i]))) - c.Assert(err, IsNil) - c.Assert(string(k), Equals, key[:i]) + require.NoError(t, err) + require.Equal(t, key[:i], string(k)) } } -func (s *testCodecSuite) TestTableID(c *C) { +func TestTableID(t *testing.T) { key := EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff")) - c.Assert(key.TableID(), Equals, int64(0xff)) + require.Equal(t, int64(0xff), key.TableID()) key = EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff_i\x01\x02")) - c.Assert(key.TableID(), Equals, int64(0xff)) + require.Equal(t, int64(0xff), key.TableID()) key = []byte("t\x80\x00\x00\x00\x00\x00\x00\xff") - c.Assert(key.TableID(), Equals, int64(0)) + require.Equal(t, int64(0), key.TableID()) key = EncodeBytes([]byte("T\x00\x00\x00\x00\x00\x00\x00\xff")) - c.Assert(key.TableID(), Equals, int64(0)) + require.Equal(t, int64(0), key.TableID()) key = EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\xff")) - c.Assert(key.TableID(), Equals, int64(0)) + require.Equal(t, int64(0), key.TableID()) } diff --git a/pkg/encryption/config_test.go b/pkg/encryption/config_test.go index 49a04f0cd46..04e9d417686 100644 --- a/pkg/encryption/config_test.go +++ b/pkg/encryption/config_test.go @@ -15,37 +15,34 @@ package encryption import ( + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/typeutil" ) -type testConfigSuite struct{} - -var _ = Suite(&testConfigSuite{}) - -func (s *testConfigSuite) TestAdjustDefaultValue(c *C) { +func TestAdjustDefaultValue(t *testing.T) { config := &Config{} err := config.Adjust() - c.Assert(err, IsNil) - c.Assert(config.DataEncryptionMethod, Equals, methodPlaintext) + require.NoError(t, err) + require.Equal(t, methodPlaintext, config.DataEncryptionMethod) defaultRotationPeriod, _ := time.ParseDuration(defaultDataKeyRotationPeriod) - c.Assert(config.DataKeyRotationPeriod.Duration, Equals, defaultRotationPeriod) - c.Assert(config.MasterKey.Type, Equals, masterKeyTypePlaintext) + require.Equal(t, defaultRotationPeriod, config.DataKeyRotationPeriod.Duration) + require.Equal(t, masterKeyTypePlaintext, config.MasterKey.Type) } -func (s *testConfigSuite) TestAdjustInvalidDataEncryptionMethod(c *C) { +func TestAdjustInvalidDataEncryptionMethod(t *testing.T) { config := &Config{DataEncryptionMethod: "unknown"} - c.Assert(config.Adjust(), NotNil) + require.NotNil(t, config.Adjust()) } -func (s *testConfigSuite) TestAdjustNegativeRotationDuration(c *C) { +func TestAdjustNegativeRotationDuration(t *testing.T) { config := &Config{DataKeyRotationPeriod: typeutil.NewDuration(time.Duration(int64(-1)))} - c.Assert(config.Adjust(), NotNil) + require.NotNil(t, config.Adjust()) } -func (s *testConfigSuite) TestAdjustInvalidMasterKeyType(c *C) { +func TestAdjustInvalidMasterKeyType(t *testing.T) { config := &Config{MasterKey: MasterKeyConfig{Type: "unknown"}} - c.Assert(config.Adjust(), NotNil) + require.NotNil(t, config.Adjust()) } diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index 6f9b8c1a38e..716d15ecdcb 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -19,90 +19,82 @@ import ( "encoding/hex" "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) +func TestEncryptionMethodSupported(t *testing.T) { + require.NotNil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) + require.NotNil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) + require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR)) + require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR)) + require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR)) } -type testCrypterSuite struct{} - -var _ = Suite(&testCrypterSuite{}) - -func (s *testCrypterSuite) TestEncryptionMethodSupported(c *C) { - c.Assert(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT), Not(IsNil)) - c.Assert(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN), Not(IsNil)) - c.Assert(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR), IsNil) - c.Assert(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR), IsNil) - c.Assert(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR), IsNil) -} - -func (s *testCrypterSuite) TestKeyLength(c *C) { +func TestKeyLength(t *testing.T) { _, err := KeyLength(encryptionpb.EncryptionMethod_PLAINTEXT) - c.Assert(err, Not(IsNil)) + require.NotNil(t, err) _, err = KeyLength(encryptionpb.EncryptionMethod_UNKNOWN) - c.Assert(err, Not(IsNil)) + require.NotNil(t, err) length, err := KeyLength(encryptionpb.EncryptionMethod_AES128_CTR) - c.Assert(err, IsNil) - c.Assert(length, Equals, 16) + require.NoError(t, err) + require.Equal(t, 16, length) length, err = KeyLength(encryptionpb.EncryptionMethod_AES192_CTR) - c.Assert(err, IsNil) - c.Assert(length, Equals, 24) + require.NoError(t, err) + require.Equal(t, 24, length) length, err = KeyLength(encryptionpb.EncryptionMethod_AES256_CTR) - c.Assert(err, IsNil) - c.Assert(length, Equals, 32) + require.NoError(t, err) + require.Equal(t, 32, length) } -func (s *testCrypterSuite) TestNewIv(c *C) { +func TestNewIv(t *testing.T) { ivCtr, err := NewIvCTR() - c.Assert(err, IsNil) - c.Assert([]byte(ivCtr), HasLen, ivLengthCTR) + require.NoError(t, err) + require.Len(t, []byte(ivCtr), ivLengthCTR) ivGcm, err := NewIvGCM() - c.Assert(err, IsNil) - c.Assert([]byte(ivGcm), HasLen, ivLengthGCM) + require.NoError(t, err) + require.Len(t, []byte(ivGcm), ivLengthGCM) } -func testNewDataKey(c *C, method encryptionpb.EncryptionMethod) { - _, key, err := NewDataKey(method, uint64(123)) - c.Assert(err, IsNil) - length, err := KeyLength(method) - c.Assert(err, IsNil) - c.Assert(key.Key, HasLen, length) - c.Assert(key.Method, Equals, method) - c.Assert(key.WasExposed, IsFalse) - c.Assert(key.CreationTime, Equals, uint64(123)) +func TestNewDataKey(t *testing.T) { + for _, method := range []encryptionpb.EncryptionMethod{ + encryptionpb.EncryptionMethod_AES128_CTR, + encryptionpb.EncryptionMethod_AES192_CTR, + encryptionpb.EncryptionMethod_AES256_CTR, + } { + _, key, err := NewDataKey(method, uint64(123)) + require.NoError(t, err) + length, err := KeyLength(method) + require.NoError(t, err) + require.Len(t, key.Key, length) + require.Equal(t, method, key.Method) + require.False(t, key.WasExposed) + require.Equal(t, uint64(123), key.CreationTime) + } } -func (s *testCrypterSuite) TestNewDataKey(c *C) { - testNewDataKey(c, encryptionpb.EncryptionMethod_AES128_CTR) - 
testNewDataKey(c, encryptionpb.EncryptionMethod_AES192_CTR) - testNewDataKey(c, encryptionpb.EncryptionMethod_AES256_CTR) -} - -func (s *testCrypterSuite) TestAesGcmCrypter(c *C) { +func TestAesGcmCrypter(t *testing.T) { key, err := hex.DecodeString("ed568fbd8c8018ed2d042a4e5d38d6341486922d401d2022fb81e47c900d3f07") - c.Assert(err, IsNil) + require.NoError(t, err) plaintext, err := hex.DecodeString( "5c873a18af5e7c7c368cb2635e5a15c7f87282085f4b991e84b78c5967e946d4") - c.Assert(err, IsNil) + require.NoError(t, err) // encrypt ivBytes, err := hex.DecodeString("ba432b70336c40c39ba14c1b") - c.Assert(err, IsNil) + require.NoError(t, err) iv := IvGCM(ivBytes) ciphertext, err := aesGcmEncryptImpl(key, plaintext, iv) - c.Assert(err, IsNil) - c.Assert([]byte(iv), HasLen, ivLengthGCM) - c.Assert( - hex.EncodeToString(ciphertext), - Equals, + require.NoError(t, err) + require.Len(t, []byte(iv), ivLengthGCM) + require.Equal( + t, "bbb9b49546350880cf55d4e4eaccc831c506a4aeae7f6cda9c821d4cb8cfc269dcdaecb09592ef25d7a33b40d3f02208", + hex.EncodeToString(ciphertext), ) // decrypt plaintext2, err := AesGcmDecrypt(key, ciphertext, iv) - c.Assert(err, IsNil) - c.Assert(bytes.Equal(plaintext2, plaintext), IsTrue) + require.NoError(t, err) + require.True(t, bytes.Equal(plaintext2, plaintext)) // Modify ciphertext to test authentication failure. We modify the beginning of the ciphertext, // which is the real ciphertext part, not the tag. fakeCiphertext := make([]byte, len(ciphertext)) @@ -110,5 +102,5 @@ func (s *testCrypterSuite) TestAesGcmCrypter(c *C) { // ignore overflow fakeCiphertext[0] = ciphertext[0] + 1 _, err = AesGcmDecrypt(key, fakeCiphertext, iv) - c.Assert(err, Not(IsNil)) + require.NotNil(t, err) } diff --git a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index 917e4083bf5..0fc1d376ca7 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -17,69 +17,66 @@ package encryption import ( "encoding/hex" "os" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/require" ) -type testMasterKeySuite struct{} - -var _ = Suite(&testMasterKeySuite{}) - -func (s *testMasterKeySuite) TestPlaintextMasterKey(c *C) { +func TestPlaintextMasterKey(t *testing.T) { config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_Plaintext{ Plaintext: &encryptionpb.MasterKeyPlaintext{}, }, } masterKey, err := NewMasterKey(config, nil) - c.Assert(err, IsNil) - c.Assert(masterKey, Not(IsNil)) - c.Assert(masterKey.key, HasLen, 0) + require.NoError(t, err) + require.NotNil(t, masterKey) + require.Len(t, masterKey.key, 0) plaintext := "this is a plaintext" ciphertext, iv, err := masterKey.Encrypt([]byte(plaintext)) - c.Assert(err, IsNil) - c.Assert(iv, HasLen, 0) - c.Assert(string(ciphertext), Equals, plaintext) + require.NoError(t, err) + require.Len(t, iv, 0) + require.Equal(t, plaintext, string(ciphertext)) plaintext2, err := masterKey.Decrypt(ciphertext, iv) - c.Assert(err, IsNil) - c.Assert(string(plaintext2), Equals, plaintext) + require.NoError(t, err) + require.Equal(t, plaintext, string(plaintext2)) - c.Assert(masterKey.IsPlaintext(), IsTrue) + require.True(t, masterKey.IsPlaintext()) } -func (s *testMasterKeySuite) TestEncrypt(c *C) { +func TestEncrypt(t *testing.T) { keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) - c.Assert(err, IsNil) + require.NoError(t, err) masterKey := &MasterKey{key: key} plaintext := "this-is-a-plaintext" ciphertext, iv, err := masterKey.Encrypt([]byte(plaintext)) - c.Assert(err, IsNil) - c.Assert(iv, HasLen, ivLengthGCM) + require.NoError(t, err) + require.Len(t, iv, ivLengthGCM) plaintext2, err := AesGcmDecrypt(key, ciphertext, iv) - c.Assert(err, IsNil) - c.Assert(string(plaintext2), Equals, plaintext) + require.NoError(t, err) + require.Equal(t, plaintext, string(plaintext2)) } -func (s *testMasterKeySuite) TestDecrypt(c *C) { +func TestDecrypt(t *testing.T) { keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) - c.Assert(err, IsNil) + require.NoError(t, err) plaintext := "this-is-a-plaintext" iv, err := hex.DecodeString("ba432b70336c40c39ba14c1b") - c.Assert(err, IsNil) + require.NoError(t, err) ciphertext, err := aesGcmEncryptImpl(key, []byte(plaintext), iv) - c.Assert(err, IsNil) + require.NoError(t, err) masterKey := &MasterKey{key: key} plaintext2, err := masterKey.Decrypt(ciphertext, iv) - c.Assert(err, IsNil) - c.Assert(string(plaintext2), Equals, plaintext) + require.NoError(t, err) + require.Equal(t, plaintext, string(plaintext2)) } -func (s *testMasterKeySuite) TestNewFileMasterKeyMissingPath(c *C) { +func TestNewFileMasterKeyMissingPath(t *testing.T) { config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ File: &encryptionpb.MasterKeyFile{ @@ -88,12 +85,12 @@ func (s *testMasterKeySuite) TestNewFileMasterKeyMissingPath(c *C) { }, } _, err := NewMasterKey(config, nil) - c.Assert(err, Not(IsNil)) + require.Error(t, err) } -func (s *testMasterKeySuite) TestNewFileMasterKeyMissingFile(c *C) { +func TestNewFileMasterKeyMissingFile(t *testing.T) { dir, err := os.MkdirTemp("", "test_key_files") - c.Assert(err, IsNil) + require.NoError(t, err) path := dir + "/key" config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -103,12 +100,12 @@ func (s *testMasterKeySuite) TestNewFileMasterKeyMissingFile(c *C) { }, } _, err = NewMasterKey(config, nil) - 
c.Assert(err, Not(IsNil)) + require.Error(t, err) } -func (s *testMasterKeySuite) TestNewFileMasterKeyNotHexString(c *C) { +func TestNewFileMasterKeyNotHexString(t *testing.T) { dir, err := os.MkdirTemp("", "test_key_files") - c.Assert(err, IsNil) + require.NoError(t, err) path := dir + "/key" os.WriteFile(path, []byte("not-a-hex-string"), 0600) config := &encryptionpb.MasterKey{ @@ -119,12 +116,12 @@ func (s *testMasterKeySuite) TestNewFileMasterKeyNotHexString(c *C) { }, } _, err = NewMasterKey(config, nil) - c.Assert(err, Not(IsNil)) + require.Error(t, err) } -func (s *testMasterKeySuite) TestNewFileMasterKeyLengthMismatch(c *C) { +func TestNewFileMasterKeyLengthMismatch(t *testing.T) { dir, err := os.MkdirTemp("", "test_key_files") - c.Assert(err, IsNil) + require.NoError(t, err) path := dir + "/key" os.WriteFile(path, []byte("2f07ec61e5a50284f47f2b402a962ec6"), 0600) config := &encryptionpb.MasterKey{ @@ -135,13 +132,13 @@ func (s *testMasterKeySuite) TestNewFileMasterKeyLengthMismatch(c *C) { }, } _, err = NewMasterKey(config, nil) - c.Assert(err, Not(IsNil)) + require.Error(t, err) } -func (s *testMasterKeySuite) TestNewFileMasterKey(c *C) { +func TestNewFileMasterKey(t *testing.T) { key := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" dir, err := os.MkdirTemp("", "test_key_files") - c.Assert(err, IsNil) + require.NoError(t, err) path := dir + "/key" os.WriteFile(path, []byte(key), 0600) config := &encryptionpb.MasterKey{ @@ -152,6 +149,6 @@ func (s *testMasterKeySuite) TestNewFileMasterKey(c *C) { }, } masterKey, err := NewMasterKey(config, nil) - c.Assert(err, IsNil) - c.Assert(hex.EncodeToString(masterKey.key), Equals, key) + require.NoError(t, err) + require.Equal(t, key, hex.EncodeToString(masterKey.key)) } diff --git a/pkg/encryption/region_crypter_test.go b/pkg/encryption/region_crypter_test.go index ac7bd34a7b1..06398ebc7ff 100644 --- a/pkg/encryption/region_crypter_test.go +++ b/pkg/encryption/region_crypter_test.go @@ -17,17 +17,14 @@ package encryption import ( "crypto/aes" "crypto/cipher" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" ) -type testRegionCrypterSuite struct{} - -var _ = Suite(&testRegionCrypterSuite{}) - type testKeyManager struct { Keys *encryptionpb.KeyDictionary EncryptionEnabled bool @@ -72,16 +69,16 @@ func (m *testKeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { return key, nil } -func (s *testRegionCrypterSuite) TestNilRegion(c *C) { +func TestNilRegion(t *testing.T) { m := newTestKeyManager() region, err := EncryptRegion(nil, m) - c.Assert(err, NotNil) - c.Assert(region, IsNil) + require.Error(t, err) + require.Nil(t, region) err = DecryptRegion(nil, m) - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *testRegionCrypterSuite) TestEncryptRegionWithoutKeyManager(c *C) { +func TestEncryptRegionWithoutKeyManager(t *testing.T) { region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -89,14 +86,14 @@ func (s *testRegionCrypterSuite) TestEncryptRegionWithoutKeyManager(c *C) { EncryptionMeta: nil, } region, err := EncryptRegion(region, nil) - c.Assert(err, IsNil) + require.NoError(t, err) // check the region isn't changed - c.Assert(string(region.StartKey), Equals, "abc") - c.Assert(string(region.EndKey), Equals, "xyz") - c.Assert(region.EncryptionMeta, IsNil) + require.Equal(t, "abc", string(region.StartKey)) + require.Equal(t, "xyz", string(region.EndKey)) + require.Nil(t, region.EncryptionMeta) } -func (s *testRegionCrypterSuite) TestEncryptRegionWhileEncryptionDisabled(c *C) { +func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -106,14 +103,14 @@ func (s *testRegionCrypterSuite) TestEncryptRegionWhileEncryptionDisabled(c *C) m := newTestKeyManager() m.EncryptionEnabled = false region, err := EncryptRegion(region, m) - c.Assert(err, IsNil) + require.NoError(t, err) // check the region isn't changed - c.Assert(string(region.StartKey), Equals, "abc") - c.Assert(string(region.EndKey), Equals, "xyz") - c.Assert(region.EncryptionMeta, IsNil) + require.Equal(t, "abc", string(region.StartKey)) + require.Equal(t, "xyz", string(region.EndKey)) + require.Nil(t, region.EncryptionMeta) } -func (s *testRegionCrypterSuite) TestEncryptRegion(c *C) { +func TestEncryptRegion(t *testing.T) { startKey := []byte("abc") endKey := []byte("xyz") region := &metapb.Region{ @@ -126,27 +123,27 @@ func (s *testRegionCrypterSuite) TestEncryptRegion(c *C) { copy(region.EndKey, endKey) m := newTestKeyManager() outRegion, err := EncryptRegion(region, m) - c.Assert(err, IsNil) - c.Assert(outRegion, Not(Equals), region) + require.NoError(t, err) + require.NotEqual(t, region, outRegion) // check region is encrypted - c.Assert(outRegion.EncryptionMeta, Not(IsNil)) - c.Assert(outRegion.EncryptionMeta.KeyId, Equals, uint64(2)) - c.Assert(outRegion.EncryptionMeta.Iv, HasLen, ivLengthCTR) + require.NotNil(t, outRegion.EncryptionMeta) + require.Equal(t, uint64(2), outRegion.EncryptionMeta.KeyId) + require.Len(t, outRegion.EncryptionMeta.Iv, ivLengthCTR) // Check encrypted content _, currentKey, err := m.GetCurrentKey() - c.Assert(err, IsNil) + require.NoError(t, err) block, err := aes.NewCipher(currentKey.Key) - c.Assert(err, IsNil) + require.NoError(t, err) stream := cipher.NewCTR(block, outRegion.EncryptionMeta.Iv) ciphertextStartKey := make([]byte, len(startKey)) stream.XORKeyStream(ciphertextStartKey, startKey) - c.Assert(string(outRegion.StartKey), 
Equals, string(ciphertextStartKey)) + require.Equal(t, string(ciphertextStartKey), string(outRegion.StartKey)) ciphertextEndKey := make([]byte, len(endKey)) stream.XORKeyStream(ciphertextEndKey, endKey) - c.Assert(string(outRegion.EndKey), Equals, string(ciphertextEndKey)) + require.Equal(t, string(ciphertextEndKey), string(outRegion.EndKey)) } -func (s *testRegionCrypterSuite) TestDecryptRegionNotEncrypted(c *C) { +func TestDecryptRegionNotEncrypted(t *testing.T) { region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -155,14 +152,14 @@ func (s *testRegionCrypterSuite) TestDecryptRegionNotEncrypted(c *C) { } m := newTestKeyManager() err := DecryptRegion(region, m) - c.Assert(err, IsNil) + require.NoError(t, err) // check the region isn't changed - c.Assert(string(region.StartKey), Equals, "abc") - c.Assert(string(region.EndKey), Equals, "xyz") - c.Assert(region.EncryptionMeta, IsNil) + require.Equal(t, "abc", string(region.StartKey)) + require.Equal(t, "xyz", string(region.EndKey)) + require.Nil(t, region.EncryptionMeta) } -func (s *testRegionCrypterSuite) TestDecryptRegionWithoutKeyManager(c *C) { +func TestDecryptRegionWithoutKeyManager(t *testing.T) { region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -173,14 +170,14 @@ func (s *testRegionCrypterSuite) TestDecryptRegionWithoutKeyManager(c *C) { }, } err := DecryptRegion(region, nil) - c.Assert(err, Not(IsNil)) + require.Error(t, err) } -func (s *testRegionCrypterSuite) TestDecryptRegionWhileKeyMissing(c *C) { +func TestDecryptRegionWhileKeyMissing(t *testing.T) { keyID := uint64(3) m := newTestKeyManager() _, err := m.GetKey(3) - c.Assert(err, Not(IsNil)) + require.Error(t, err) region := &metapb.Region{ Id: 10, @@ -192,10 +189,10 @@ func (s *testRegionCrypterSuite) TestDecryptRegionWhileKeyMissing(c *C) { }, } err = DecryptRegion(region, m) - c.Assert(err, Not(IsNil)) + require.Error(t, err) } -func (s *testRegionCrypterSuite) TestDecryptRegion(c *C) { +func TestDecryptRegion(t *testing.T) { keyID := uint64(1) startKey := []byte("abc") endKey := []byte("xyz") @@ -214,19 +211,19 @@ func (s *testRegionCrypterSuite) TestDecryptRegion(c *C) { copy(region.EncryptionMeta.Iv, iv) m := newTestKeyManager() err := DecryptRegion(region, m) - c.Assert(err, IsNil) + require.NoError(t, err) // check region is decrypted - c.Assert(region.EncryptionMeta, IsNil) + require.Nil(t, region.EncryptionMeta) // Check decrypted content key, err := m.GetKey(keyID) - c.Assert(err, IsNil) + require.NoError(t, err) block, err := aes.NewCipher(key.Key) - c.Assert(err, IsNil) + require.NoError(t, err) stream := cipher.NewCTR(block, iv) plaintextStartKey := make([]byte, len(startKey)) stream.XORKeyStream(plaintextStartKey, startKey) - c.Assert(string(region.StartKey), Equals, string(plaintextStartKey)) + require.Equal(t, string(plaintextStartKey), string(region.StartKey)) plaintextEndKey := make([]byte, len(endKey)) stream.XORKeyStream(plaintextEndKey, endKey) - c.Assert(string(region.EndKey), Equals, string(plaintextEndKey)) + require.Equal(t, string(plaintextEndKey), string(region.EndKey)) } diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index a2ffd299fb5..65ebb6460d0 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -20,9 +20,9 @@ import ( "strings" "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/log" + "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -71,56 +71,48 @@ func newZapTestLogger(cfg *log.Config, opts ...zap.Option) verifyLogger { } } -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testErrorSuite{}) - -type testErrorSuite struct{} - -func (s *testErrorSuite) TestError(c *C) { +func TestError(t *testing.T) { conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) log.ReplaceGlobals(lg.Logger, nil) rfc := `[error="[PD:member:ErrEtcdLeaderNotFound]etcd leader not found` log.Error("test", zap.Error(ErrEtcdLeaderNotFound.FastGenByArgs())) - c.Assert(strings.Contains(lg.Message(), rfc), IsTrue) + require.Contains(t, lg.Message(), rfc) err := errors.New("test error") log.Error("test", ZapError(ErrEtcdLeaderNotFound, err)) rfc = `[error="[PD:member:ErrEtcdLeaderNotFound]test error` - c.Assert(strings.Contains(lg.Message(), rfc), IsTrue) + require.Contains(t, lg.Message(), rfc) } -func (s *testErrorSuite) TestErrorEqual(c *C) { +func TestErrorEqual(t *testing.T) { err1 := ErrSchedulerNotFound.FastGenByArgs() err2 := ErrSchedulerNotFound.FastGenByArgs() - c.Assert(errors.ErrorEqual(err1, err2), IsTrue) + require.True(t, errors.ErrorEqual(err1, err2)) err := errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - c.Assert(errors.ErrorEqual(err1, err2), IsTrue) + require.True(t, errors.ErrorEqual(err1, err2)) err1 = ErrSchedulerNotFound.FastGenByArgs() err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - c.Assert(errors.ErrorEqual(err1, err2), IsFalse) + require.False(t, errors.ErrorEqual(err1, err2)) err3 := errors.New("test") err4 := errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() - c.Assert(errors.ErrorEqual(err1, err2), IsTrue) + require.True(t, errors.ErrorEqual(err1, err2)) err3 = errors.New("test1") err4 = errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() - c.Assert(errors.ErrorEqual(err1, err2), IsFalse) + require.False(t, errors.ErrorEqual(err1, err2)) } -func (s *testErrorSuite) TestZapError(c *C) { +func TestZapError(t *testing.T) { err := errors.New("test") log.Info("test", ZapError(err)) err1 := ErrSchedulerNotFound @@ -128,7 +120,7 @@ func (s *testErrorSuite) TestZapError(c *C) { log.Info("test", ZapError(err1, err)) } -func (s *testErrorSuite) TestErrorWithStack(c *C) { +func TestErrorWithStack(t *testing.T) { conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) log.ReplaceGlobals(lg.Logger, nil) @@ -141,8 +133,8 @@ func (s *testErrorSuite) TestErrorWithStack(c *C) { // This test is based on line number and the first log is in line 141, the second is in line 142. // So they have the same length stack. Move this test to another place need to change the corresponding length. 
idx1 := strings.Index(m1, "[stack=") - c.Assert(idx1, GreaterEqual, -1) + require.GreaterOrEqual(t, idx1, -1) idx2 := strings.Index(m2, "[stack=") - c.Assert(idx2, GreaterEqual, -1) - c.Assert(len(m1[idx1:]), Equals, len(m2[idx2:])) + require.GreaterOrEqual(t, idx2, -1) + require.Equal(t, len(m1[idx1:]), len(m2[idx2:])) } From 294a016d0e95ef6af1d1e6c68517cc6c158dd76f Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 30 May 2022 20:38:27 +0800 Subject: [PATCH 03/82] assertutil, etcdutil, grpcutil, keyutil, logutil, metricutil: testify the tests (#5067) ref tikv/pd#4813 Testify the pkg/assertutil, pkg/etcdutil, pkg/grpcutil, pkg/keyutil, pkg/logutil, pkg/metricutil tests. Signed-off-by: JmPotato --- pkg/assertutil/assertutil_test.go | 16 ++---- pkg/etcdutil/etcdutil_test.go | 96 ++++++++++++++----------------- pkg/grpcutil/grpcutil_test.go | 30 ++++------ pkg/keyutil/util_test.go | 15 +---- pkg/logutil/log_test.go | 45 ++++++--------- pkg/metricutil/metricutil_test.go | 17 ++---- 6 files changed, 85 insertions(+), 134 deletions(-) diff --git a/pkg/assertutil/assertutil_test.go b/pkg/assertutil/assertutil_test.go index 754af4509e2..6cdfd591937 100644 --- a/pkg/assertutil/assertutil_test.go +++ b/pkg/assertutil/assertutil_test.go @@ -18,23 +18,15 @@ import ( "errors" "testing" - "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&testAssertUtilSuite{}) - -type testAssertUtilSuite struct{} - -func (s *testAssertUtilSuite) TestNilFail(c *check.C) { +func TestNilFail(t *testing.T) { var failErr error checker := NewChecker(func() { failErr = errors.New("called assert func not exist") }) - c.Assert(checker.IsNil, check.IsNil) + require.Nil(t, checker.IsNil) checker.AssertNil(nil) - c.Assert(failErr, check.NotNil) + require.NotNil(t, failErr) } diff --git a/pkg/etcdutil/etcdutil_test.go b/pkg/etcdutil/etcdutil_test.go index c3dc327949d..7bc73f12cbe 100644 --- a/pkg/etcdutil/etcdutil_test.go +++ b/pkg/etcdutil/etcdutil_test.go @@ -21,43 +21,35 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/embed" "go.etcd.io/etcd/pkg/types" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testEtcdutilSuite{}) - -type testEtcdutilSuite struct{} - -func (s *testEtcdutilSuite) TestMemberHelpers(c *C) { +func TestMemberHelpers(t *testing.T) { cfg1 := NewTestSingleConfig() etcd1, err := embed.StartEtcd(cfg1) defer func() { etcd1.Close() CleanConfig(cfg1) }() - c.Assert(err, IsNil) + require.NoError(t, err) ep1 := cfg1.LCUrls[0].String() client1, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep1}, }) - c.Assert(err, IsNil) + require.NoError(t, err) <-etcd1.Server.ReadyNotify() // Test ListEtcdMembers listResp1, err := ListEtcdMembers(client1) - c.Assert(err, IsNil) - c.Assert(listResp1.Members, HasLen, 1) + require.NoError(t, err) + require.Len(t, listResp1.Members, 1) // types.ID is an alias of uint64. - c.Assert(listResp1.Members[0].ID, Equals, uint64(etcd1.Server.ID())) + require.Equal(t, uint64(etcd1.Server.ID()), listResp1.Members[0].ID) // Test AddEtcdMember // Make a new etcd config. @@ -69,67 +61,67 @@ func (s *testEtcdutilSuite) TestMemberHelpers(c *C) { // Add it to the cluster above. 
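The conversion pattern applied throughout this patch: drop the gocheck `TestingT`/`Suite` registration, turn each suite method into a plain `go test` function, and assert through `testify/require`, which takes `*testing.T` first and the expected value before the actual one. A minimal sketch of the resulting shape, using a hypothetical `mathutil` package and `Sum` helper (not from the repository, included only so the snippet compiles):

package mathutil

import (
    "testing"

    "github.com/stretchr/testify/require"
)

// Sum is a stand-in helper so the sketch compiles; it is not from the repository.
func Sum(a, b int) int { return a + b }

// Under gocheck this would have been a suite method such as
// func (s *testMathSuite) TestSum(c *C) asserting with c.Assert(Sum(1, 2), Equals, 3).
// After the conversion it is a plain test function using testify's require helpers.
func TestSum(t *testing.T) {
    require.Equal(t, 3, Sum(1, 2))
    require.True(t, Sum(1, 2) > 0)
}
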
peerURL := cfg2.LPUrls[0].String() addResp, err := AddEtcdMember(client1, []string{peerURL}) - c.Assert(err, IsNil) + require.NoError(t, err) etcd2, err := embed.StartEtcd(cfg2) defer func() { etcd2.Close() CleanConfig(cfg2) }() - c.Assert(err, IsNil) - c.Assert(addResp.Member.ID, Equals, uint64(etcd2.Server.ID())) + require.NoError(t, err) + require.Equal(t, uint64(etcd2.Server.ID()), addResp.Member.ID) ep2 := cfg2.LCUrls[0].String() client2, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep2}, }) - c.Assert(err, IsNil) + require.NoError(t, err) <-etcd2.Server.ReadyNotify() - c.Assert(err, IsNil) + require.NoError(t, err) listResp2, err := ListEtcdMembers(client2) - c.Assert(err, IsNil) - c.Assert(listResp2.Members, HasLen, 2) + require.NoError(t, err) + require.Len(t, listResp2.Members, 2) for _, m := range listResp2.Members { switch m.ID { case uint64(etcd1.Server.ID()): case uint64(etcd2.Server.ID()): default: - c.Fatalf("unknown member: %v", m) + t.Fatalf("unknown member: %v", m) } } // Test CheckClusterID urlsMap, err := types.NewURLsMap(cfg2.InitialCluster) - c.Assert(err, IsNil) + require.NoError(t, err) err = CheckClusterID(etcd1.Server.Cluster().ID(), urlsMap, &tls.Config{MinVersion: tls.VersionTLS12}) - c.Assert(err, IsNil) + require.NoError(t, err) // Test RemoveEtcdMember _, err = RemoveEtcdMember(client1, uint64(etcd2.Server.ID())) - c.Assert(err, IsNil) + require.NoError(t, err) listResp3, err := ListEtcdMembers(client1) - c.Assert(err, IsNil) - c.Assert(listResp3.Members, HasLen, 1) - c.Assert(listResp3.Members[0].ID, Equals, uint64(etcd1.Server.ID())) + require.NoError(t, err) + require.Len(t, listResp3.Members, 1) + require.Equal(t, uint64(etcd1.Server.ID()), listResp3.Members[0].ID) } -func (s *testEtcdutilSuite) TestEtcdKVGet(c *C) { +func TestEtcdKVGet(t *testing.T) { cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() CleanConfig(cfg) }() - c.Assert(err, IsNil) + require.NoError(t, err) ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - c.Assert(err, IsNil) + require.NoError(t, err) <-etcd.Server.ReadyNotify() @@ -139,69 +131,69 @@ func (s *testEtcdutilSuite) TestEtcdKVGet(c *C) { kv := clientv3.NewKV(client) for i := range keys { _, err = kv.Put(context.TODO(), keys[i], vals[i]) - c.Assert(err, IsNil) + require.NoError(t, err) } // Test simple point get resp, err := EtcdKVGet(client, "test/key1") - c.Assert(err, IsNil) - c.Assert(string(resp.Kvs[0].Value), Equals, "val1") + require.NoError(t, err) + require.Equal(t, "val1", string(resp.Kvs[0].Value)) // Test range get withRange := clientv3.WithRange("test/zzzz") withLimit := clientv3.WithLimit(3) resp, err = EtcdKVGet(client, "test/", withRange, withLimit, clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) - c.Assert(err, IsNil) - c.Assert(resp.Kvs, HasLen, 3) + require.NoError(t, err) + require.Len(t, resp.Kvs, 3) for i := range resp.Kvs { - c.Assert(string(resp.Kvs[i].Key), Equals, keys[i]) - c.Assert(string(resp.Kvs[i].Value), Equals, vals[i]) + require.Equal(t, keys[i], string(resp.Kvs[i].Key)) + require.Equal(t, vals[i], string(resp.Kvs[i].Value)) } lastKey := string(resp.Kvs[len(resp.Kvs)-1].Key) next := clientv3.GetPrefixRangeEnd(lastKey) resp, err = EtcdKVGet(client, next, withRange, withLimit, clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) - c.Assert(err, IsNil) - c.Assert(resp.Kvs, HasLen, 2) + require.NoError(t, err) + require.Len(t, resp.Kvs, 2) } -func (s *testEtcdutilSuite) 
TestEtcdKVPutWithTTL(c *C) { +func TestEtcdKVPutWithTTL(t *testing.T) { cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() CleanConfig(cfg) }() - c.Assert(err, IsNil) + require.NoError(t, err) ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - c.Assert(err, IsNil) + require.NoError(t, err) <-etcd.Server.ReadyNotify() _, err = EtcdKVPutWithTTL(context.TODO(), client, "test/ttl1", "val1", 2) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = EtcdKVPutWithTTL(context.TODO(), client, "test/ttl2", "val2", 4) - c.Assert(err, IsNil) + require.NoError(t, err) time.Sleep(3 * time.Second) // test/ttl1 is outdated resp, err := EtcdKVGet(client, "test/ttl1") - c.Assert(err, IsNil) - c.Assert(resp.Count, Equals, int64(0)) + require.NoError(t, err) + require.Equal(t, int64(0), resp.Count) // but test/ttl2 is not resp, err = EtcdKVGet(client, "test/ttl2") - c.Assert(err, IsNil) - c.Assert(string(resp.Kvs[0].Value), Equals, "val2") + require.NoError(t, err) + require.Equal(t, "val2", string(resp.Kvs[0].Value)) time.Sleep(2 * time.Second) // test/ttl2 is also outdated resp, err = EtcdKVGet(client, "test/ttl2") - c.Assert(err, IsNil) - c.Assert(resp.Count, Equals, int64(0)) + require.NoError(t, err) + require.Equal(t, int64(0), resp.Count) } diff --git a/pkg/grpcutil/grpcutil_test.go b/pkg/grpcutil/grpcutil_test.go index 4d30bf6ed1d..d1b9d3a8830 100644 --- a/pkg/grpcutil/grpcutil_test.go +++ b/pkg/grpcutil/grpcutil_test.go @@ -4,31 +4,23 @@ import ( "os" "testing" - . "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/errs" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&gRPCUtilSuite{}) - -type gRPCUtilSuite struct{} - -func loadTLSContent(c *C, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { +func loadTLSContent(t *testing.T, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { var err error caData, err = os.ReadFile(caPath) - c.Assert(err, IsNil) + require.NoError(t, err) certData, err = os.ReadFile(certPath) - c.Assert(err, IsNil) + require.NoError(t, err) keyData, err = os.ReadFile(keyPath) - c.Assert(err, IsNil) + require.NoError(t, err) return } -func (s *gRPCUtilSuite) TestToTLSConfig(c *C) { +func TestToTLSConfig(t *testing.T) { tlsConfig := TLSConfig{ KeyPath: "../../tests/client/cert/pd-server-key.pem", CertPath: "../../tests/client/cert/pd-server.pem", @@ -36,24 +28,24 @@ func (s *gRPCUtilSuite) TestToTLSConfig(c *C) { } // test without bytes _, err := tlsConfig.ToTLSConfig() - c.Assert(err, IsNil) + require.NoError(t, err) // test with bytes - caData, certData, keyData := loadTLSContent(c, tlsConfig.CAPath, tlsConfig.CertPath, tlsConfig.KeyPath) + caData, certData, keyData := loadTLSContent(t, tlsConfig.CAPath, tlsConfig.CertPath, tlsConfig.KeyPath) tlsConfig.SSLCABytes = caData tlsConfig.SSLCertBytes = certData tlsConfig.SSLKEYBytes = keyData _, err = tlsConfig.ToTLSConfig() - c.Assert(err, IsNil) + require.NoError(t, err) // test wrong cert bytes tlsConfig.SSLCertBytes = []byte("invalid cert") _, err = tlsConfig.ToTLSConfig() - c.Assert(errors.ErrorEqual(err, errs.ErrCryptoX509KeyPair), IsTrue) + require.True(t, errors.ErrorEqual(err, errs.ErrCryptoX509KeyPair)) // test wrong ca bytes tlsConfig.SSLCertBytes = certData tlsConfig.SSLCABytes = []byte("invalid ca") _, err = tlsConfig.ToTLSConfig() - c.Assert(errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM), IsTrue) + 
require.True(t, errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) } diff --git a/pkg/keyutil/util_test.go b/pkg/keyutil/util_test.go index 59708177f5b..6603c61b131 100644 --- a/pkg/keyutil/util_test.go +++ b/pkg/keyutil/util_test.go @@ -17,21 +17,12 @@ package keyutil import ( "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testKeyUtilSuite{}) - -type testKeyUtilSuite struct { -} - -func (s *testKeyUtilSuite) TestKeyUtil(c *C) { +func TestKeyUtil(t *testing.T) { startKey := []byte("a") endKey := []byte("b") key := BuildKeyRangeKey(startKey, endKey) - c.Assert(key, Equals, "61-62") + require.Equal(t, "61-62", key) } diff --git a/pkg/logutil/log_test.go b/pkg/logutil/log_test.go index 33b0320b6f8..42a9126ea33 100644 --- a/pkg/logutil/log_test.go +++ b/pkg/logutil/log_test.go @@ -16,32 +16,25 @@ package logutil import ( "fmt" + "reflect" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" ) -func Test(t *testing.T) { - TestingT(t) +func TestStringToZapLogLevel(t *testing.T) { + require.Equal(t, zapcore.FatalLevel, StringToZapLogLevel("fatal")) + require.Equal(t, zapcore.ErrorLevel, StringToZapLogLevel("ERROR")) + require.Equal(t, zapcore.WarnLevel, StringToZapLogLevel("warn")) + require.Equal(t, zapcore.WarnLevel, StringToZapLogLevel("warning")) + require.Equal(t, zapcore.DebugLevel, StringToZapLogLevel("debug")) + require.Equal(t, zapcore.InfoLevel, StringToZapLogLevel("info")) + require.Equal(t, zapcore.InfoLevel, StringToZapLogLevel("whatever")) } -var _ = Suite(&testLogSuite{}) - -type testLogSuite struct{} - -func (s *testLogSuite) TestStringToZapLogLevel(c *C) { - c.Assert(StringToZapLogLevel("fatal"), Equals, zapcore.FatalLevel) - c.Assert(StringToZapLogLevel("ERROR"), Equals, zapcore.ErrorLevel) - c.Assert(StringToZapLogLevel("warn"), Equals, zapcore.WarnLevel) - c.Assert(StringToZapLogLevel("warning"), Equals, zapcore.WarnLevel) - c.Assert(StringToZapLogLevel("debug"), Equals, zapcore.DebugLevel) - c.Assert(StringToZapLogLevel("info"), Equals, zapcore.InfoLevel) - c.Assert(StringToZapLogLevel("whatever"), Equals, zapcore.InfoLevel) -} - -func (s *testLogSuite) TestRedactLog(c *C) { - testcases := []struct { +func TestRedactLog(t *testing.T) { + testCases := []struct { name string arg interface{} enableRedactLog bool @@ -73,16 +66,16 @@ func (s *testLogSuite) TestRedactLog(c *C) { }, } - for _, testcase := range testcases { - c.Log(testcase.name) - SetRedactLog(testcase.enableRedactLog) - switch r := testcase.arg.(type) { + for _, testCase := range testCases { + t.Log(testCase.name) + SetRedactLog(testCase.enableRedactLog) + switch r := testCase.arg.(type) { case []byte: - c.Assert(RedactBytes(r), DeepEquals, testcase.expect) + require.True(t, reflect.DeepEqual(testCase.expect, RedactBytes(r))) case string: - c.Assert(RedactString(r), DeepEquals, testcase.expect) + require.True(t, reflect.DeepEqual(testCase.expect, RedactString(r))) case fmt.Stringer: - c.Assert(RedactStringer(r), DeepEquals, testcase.expect) + require.True(t, reflect.DeepEqual(testCase.expect, RedactStringer(r))) default: panic("unmatched case") } diff --git a/pkg/metricutil/metricutil_test.go b/pkg/metricutil/metricutil_test.go index 1fbfacf58cf..512732c7f7e 100644 --- a/pkg/metricutil/metricutil_test.go +++ b/pkg/metricutil/metricutil_test.go @@ -18,20 +18,11 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/typeutil" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testMetricsSuite{}) - -type testMetricsSuite struct { -} - -func (s *testMetricsSuite) TestCamelCaseToSnakeCase(c *C) { +func TestCamelCaseToSnakeCase(t *testing.T) { inputs := []struct { name string newName string @@ -59,11 +50,11 @@ func (s *testMetricsSuite) TestCamelCaseToSnakeCase(c *C) { } for _, input := range inputs { - c.Assert(camelCaseToSnakeCase(input.name), Equals, input.newName) + require.Equal(t, input.newName, camelCaseToSnakeCase(input.name)) } } -func (s *testMetricsSuite) TestCoverage(c *C) { +func TestCoverage(t *testing.T) { cfgs := []*MetricConfig{ { PushJob: "j1", From bd79cc2b95a24cb9ba9dacb7d9baf915e651c44d Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 31 May 2022 16:06:27 +0800 Subject: [PATCH 04/82] *: use require.New to reduce code (#5076) ref tikv/pd#4813 Use `require.New` to reduce code. Signed-off-by: JmPotato --- pkg/apiutil/apiutil_test.go | 18 +- pkg/assertutil/assertutil_test.go | 5 +- pkg/audit/audit_test.go | 28 +-- pkg/autoscaling/calculation_test.go | 33 ++-- pkg/autoscaling/prometheus_test.go | 32 +-- pkg/cache/cache_test.go | 275 +++++++++++++------------- pkg/codec/codec_test.go | 16 +- pkg/encryption/config_test.go | 18 +- pkg/encryption/crypter_test.go | 70 +++---- pkg/encryption/master_key_test.go | 66 ++++--- pkg/encryption/region_crypter_test.go | 74 +++---- pkg/errs/errs_test.go | 23 ++- pkg/etcdutil/etcdutil_test.go | 81 ++++---- pkg/grpcutil/grpcutil_test.go | 19 +- pkg/keyutil/util_test.go | 3 +- pkg/logutil/log_test.go | 22 ++- pkg/metricutil/metricutil_test.go | 3 +- 17 files changed, 422 insertions(+), 364 deletions(-) diff --git a/pkg/apiutil/apiutil_test.go b/pkg/apiutil/apiutil_test.go index 94cf96c3f26..8a79edc9784 100644 --- a/pkg/apiutil/apiutil_test.go +++ b/pkg/apiutil/apiutil_test.go @@ -25,6 +25,7 @@ import ( ) func TestJsonRespondErrorOk(t *testing.T) { + re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, }) @@ -33,15 +34,16 @@ func TestJsonRespondErrorOk(t *testing.T) { var input map[string]string output := map[string]string{"zone": "cn", "host": "local"} err := ReadJSONRespondError(rd, response, body, &input) - require.NoError(t, err) - require.Equal(t, output["zone"], input["zone"]) - require.Equal(t, output["host"], input["host"]) + re.NoError(err) + re.Equal(output["zone"], input["zone"]) + re.Equal(output["host"], input["host"]) result := response.Result() defer result.Body.Close() - require.Equal(t, 200, result.StatusCode) + re.Equal(200, result.StatusCode) } func TestJsonRespondErrorBadInput(t *testing.T) { + re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, }) @@ -49,18 +51,18 @@ func TestJsonRespondErrorBadInput(t *testing.T) { body := io.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\", \"host\":\"local\"}")) var input []string err := ReadJSONRespondError(rd, response, body, &input) - require.EqualError(t, err, "json: cannot unmarshal object into Go value of type []string") + re.EqualError(err, "json: cannot unmarshal object into Go value of type []string") result := response.Result() defer result.Body.Close() - require.Equal(t, 400, result.StatusCode) + re.Equal(400, result.StatusCode) { body := io.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\",")) var input []string err := ReadJSONRespondError(rd, response, body, &input) - require.EqualError(t, err, "unexpected end of JSON input") + 
re.EqualError(err, "unexpected end of JSON input") result := response.Result() defer result.Body.Close() - require.Equal(t, 400, result.StatusCode) + re.Equal(400, result.StatusCode) } } diff --git a/pkg/assertutil/assertutil_test.go b/pkg/assertutil/assertutil_test.go index 6cdfd591937..324e403f7b6 100644 --- a/pkg/assertutil/assertutil_test.go +++ b/pkg/assertutil/assertutil_test.go @@ -22,11 +22,12 @@ import ( ) func TestNilFail(t *testing.T) { + re := require.New(t) var failErr error checker := NewChecker(func() { failErr = errors.New("called assert func not exist") }) - require.Nil(t, checker.IsNil) + re.Nil(checker.IsNil) checker.AssertNil(nil) - require.NotNil(t, failErr) + re.NotNil(failErr) } diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index 66df5298b8b..2b33b62ca55 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -32,14 +32,16 @@ import ( ) func TestLabelMatcher(t *testing.T) { + re := require.New(t) matcher := &LabelMatcher{"testSuccess"} labels1 := &BackendLabels{Labels: []string{"testFail", "testSuccess"}} - require.True(t, matcher.Match(labels1)) + re.True(matcher.Match(labels1)) labels2 := &BackendLabels{Labels: []string{"testFail"}} - require.False(t, matcher.Match(labels2)) + re.False(matcher.Match(labels2)) } func TestPrometheusHistogramBackend(t *testing.T) { + re := require.New(t) serviceAuditHistogramTest := prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "pd", @@ -60,43 +62,43 @@ func TestPrometheusHistogramBackend(t *testing.T) { info.ServiceLabel = "test" info.Component = "user1" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - require.False(t, backend.ProcessHTTPRequest(req)) + re.False(backend.ProcessHTTPRequest(req)) endTime := time.Now().Unix() + 20 req = req.WithContext(requestutil.WithEndTime(req.Context(), endTime)) - require.True(t, backend.ProcessHTTPRequest(req)) - require.True(t, backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) info.Component = "user2" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - require.True(t, backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) // For test, sleep time needs longer than the push interval time.Sleep(1 * time.Second) req, _ = http.NewRequest("GET", ts.URL, nil) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + re.NoError(err) defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - require.Contains(t, output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",method=\"HTTP\",service=\"test\"} 2") - require.Contains(t, output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",method=\"HTTP\",service=\"test\"} 1") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",method=\"HTTP\",service=\"test\"} 2") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",method=\"HTTP\",service=\"test\"} 1") } func TestLocalLogBackendUsingFile(t *testing.T) { + re := require.New(t) backend := NewLocalLogBackend(true) fname := initLog() defer os.Remove(fname) req, _ := http.NewRequest("GET", "http://127.0.0.1:2379/test?test=test", strings.NewReader("testBody")) - require.False(t, backend.ProcessHTTPRequest(req)) + re.False(backend.ProcessHTTPRequest(req)) info := requestutil.GetRequestInfo(req) req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - require.True(t, 
backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) - require.Equal( - t, + re.Equal( fmt.Sprintf(" [\"Audit Log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+ "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", time.Unix(info.StartTimeStamp, 0).String()), diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index 8334c96ecc5..f5ac3313ba4 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -30,6 +30,7 @@ import ( ) func TestGetScaledTiKVGroups(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // case1 indicates the tikv cluster with not any group existed @@ -193,10 +194,10 @@ func TestGetScaledTiKVGroups(t *testing.T) { t.Log(testCase.name) plans, err := getScaledTiKVGroups(testCase.informer, testCase.healthyInstances) if testCase.expectedPlan == nil { - require.Len(t, plans, 0) - require.Equal(t, testCase.noError, err == nil) + re.Len(plans, 0) + re.Equal(testCase.noError, err == nil) } else { - require.True(t, reflect.DeepEqual(testCase.expectedPlan, plans)) + re.True(reflect.DeepEqual(testCase.expectedPlan, plans)) } } } @@ -213,6 +214,7 @@ func (q *mockQuerier) Query(options *QueryOptions) (QueryResult, error) { } func TestGetTotalCPUUseTime(t *testing.T) { + re := require.New(t) querier := &mockQuerier{} instances := []instance{ { @@ -230,10 +232,11 @@ func TestGetTotalCPUUseTime(t *testing.T) { } totalCPUUseTime, _ := getTotalCPUUseTime(querier, TiDB, instances, time.Now(), 0) expected := mockResultValue * float64(len(instances)) - require.True(t, math.Abs(expected-totalCPUUseTime) < 1e-6) + re.True(math.Abs(expected-totalCPUUseTime) < 1e-6) } func TestGetTotalCPUQuota(t *testing.T) { + re := require.New(t) querier := &mockQuerier{} instances := []instance{ { @@ -251,10 +254,11 @@ func TestGetTotalCPUQuota(t *testing.T) { } totalCPUQuota, _ := getTotalCPUQuota(querier, TiDB, instances, time.Now()) expected := uint64(mockResultValue * float64(len(instances)*milliCores)) - require.Equal(t, expected, totalCPUQuota) + re.Equal(expected, totalCPUQuota) } func TestScaleOutGroupLabel(t *testing.T) { + re := require.New(t) var jsonStr = []byte(` { "rules":[ @@ -288,14 +292,15 @@ func TestScaleOutGroupLabel(t *testing.T) { }`) strategy := &Strategy{} err := json.Unmarshal(jsonStr, strategy) - require.NoError(t, err) + re.NoError(err) plan := findBestGroupToScaleOut(strategy, nil, TiKV) - require.Equal(t, "hotRegion", plan.Labels["specialUse"]) + re.Equal("hotRegion", plan.Labels["specialUse"]) plan = findBestGroupToScaleOut(strategy, nil, TiDB) - require.Equal(t, "", plan.Labels["specialUse"]) + re.Equal("", plan.Labels["specialUse"]) } func TestStrategyChangeCount(t *testing.T) { + re := require.New(t) var count uint64 = 2 strategy := &Strategy{ Rules: []*Rule{ @@ -343,21 +348,21 @@ func TestStrategyChangeCount(t *testing.T) { // exist two scaled TiKVs and plan does not change due to the limit of resource count groups, err := getScaledTiKVGroups(cluster, instances) - require.NoError(t, err) + re.NoError(err) plans := calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - require.Equal(t, uint64(2), plans[0].Count) + re.Equal(uint64(2), plans[0].Count) // change the resource count to 3 and plan increates one more tikv groups, err = getScaledTiKVGroups(cluster, instances) - 
require.NoError(t, err) + re.NoError(err) *strategy.Resources[0].Count = 3 plans = calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - require.Equal(t, uint64(3), plans[0].Count) + re.Equal(uint64(3), plans[0].Count) // change the resource count to 1 and plan decreases to 1 tikv due to the limit of resource count groups, err = getScaledTiKVGroups(cluster, instances) - require.NoError(t, err) + re.NoError(err) *strategy.Resources[0].Count = 1 plans = calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - require.Equal(t, uint64(1), plans[0].Count) + re.Equal(uint64(1), plans[0].Count) } diff --git a/pkg/autoscaling/prometheus_test.go b/pkg/autoscaling/prometheus_test.go index 2c541446d2b..2906645b180 100644 --- a/pkg/autoscaling/prometheus_test.go +++ b/pkg/autoscaling/prometheus_test.go @@ -181,6 +181,7 @@ func (c *normalClient) Do(_ context.Context, req *http.Request) (response *http. } func TestRetrieveCPUMetrics(t *testing.T) { + re := require.New(t) client := &normalClient{ mockData: make(map[string]*response), } @@ -191,15 +192,15 @@ func TestRetrieveCPUMetrics(t *testing.T) { for _, metric := range metrics { options := NewQueryOptions(component, metric, addresses[:len(addresses)-1], time.Now(), mockDuration) result, err := querier.Query(options) - require.NoError(t, err) + re.NoError(err) for i := 0; i < len(addresses)-1; i++ { value, ok := result[addresses[i]] - require.True(t, ok) - require.True(t, math.Abs(value-mockResultValue) < 1e-6) + re.True(ok) + re.True(math.Abs(value-mockResultValue) < 1e-6) } _, ok := result[addresses[len(addresses)-1]] - require.False(t, ok) + re.False(ok) } } } @@ -224,12 +225,13 @@ func (c *emptyResponseClient) Do(_ context.Context, req *http.Request) (r *http. } func TestEmptyResponse(t *testing.T) { + re := require.New(t) client := &emptyResponseClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - require.Nil(t, result) - require.Error(t, err) + re.Nil(result) + re.Error(err) } type errorHTTPStatusClient struct{} @@ -250,12 +252,13 @@ func (c *errorHTTPStatusClient) Do(_ context.Context, req *http.Request) (r *htt } func TestErrorHTTPStatus(t *testing.T) { + re := require.New(t) client := &errorHTTPStatusClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - require.Nil(t, result) - require.Error(t, err) + re.Nil(result) + re.Error(err) } type errorPrometheusStatusClient struct{} @@ -274,15 +277,17 @@ func (c *errorPrometheusStatusClient) Do(_ context.Context, req *http.Request) ( } func TestErrorPrometheusStatus(t *testing.T) { + re := require.New(t) client := &errorPrometheusStatusClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - require.Nil(t, result) - require.Error(t, err) + re.Nil(result) + re.Error(err) } func TestGetInstanceNameFromAddress(t *testing.T) { + re := require.New(t) testCases := []struct { address string expectedInstanceName string @@ -311,14 +316,15 @@ func TestGetInstanceNameFromAddress(t *testing.T) { for _, testCase := range testCases { instanceName, err := getInstanceNameFromAddress(testCase.address) if testCase.expectedInstanceName == "" { - require.Error(t, err) + re.Error(err) } else { - require.Equal(t, testCase.expectedInstanceName, 
instanceName) + re.Equal(testCase.expectedInstanceName, instanceName) } } } func TestGetDurationExpression(t *testing.T) { + re := require.New(t) testCases := []struct { duration time.Duration expectedExpression string @@ -343,6 +349,6 @@ func TestGetDurationExpression(t *testing.T) { for _, testCase := range testCases { expression := getDurationExpression(testCase.duration) - require.Equal(t, testCase.expectedExpression, expression) + re.Equal(testCase.expectedExpression, expression) } } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index bd633fef525..bf1626450f7 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -25,36 +25,37 @@ import ( ) func TestExpireRegionCache(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cache := NewIDTTL(ctx, time.Second, 2*time.Second) // Test Pop cache.PutWithTTL(9, "9", 5*time.Second) cache.PutWithTTL(10, "10", 5*time.Second) - require.Equal(t, 2, cache.Len()) + re.Equal(2, cache.Len()) k, v, success := cache.pop() - require.True(t, success) - require.Equal(t, 1, cache.Len()) + re.True(success) + re.Equal(1, cache.Len()) k2, v2, success := cache.pop() - require.True(t, success) + re.True(success) // we can't ensure the order which the key/value pop from cache, so we save into a map kvMap := map[uint64]string{ 9: "9", 10: "10", } expV, ok := kvMap[k.(uint64)] - require.True(t, ok) - require.Equal(t, expV, v.(string)) + re.True(ok) + re.Equal(expV, v.(string)) expV, ok = kvMap[k2.(uint64)] - require.True(t, ok) - require.Equal(t, expV, v2.(string)) + re.True(ok) + re.Equal(expV, v2.(string)) cache.PutWithTTL(11, "11", 1*time.Second) time.Sleep(5 * time.Second) k, v, success = cache.pop() - require.False(t, success) - require.Nil(t, k) - require.Nil(t, v) + re.False(success) + re.Nil(k) + re.Nil(v) // Test Get cache.PutWithTTL(1, 1, 1*time.Second) @@ -62,50 +63,50 @@ func TestExpireRegionCache(t *testing.T) { cache.PutWithTTL(3, 3.0, 5*time.Second) value, ok := cache.Get(1) - require.True(t, ok) - require.Equal(t, 1, value) + re.True(ok) + re.Equal(1, value) value, ok = cache.Get(2) - require.True(t, ok) - require.Equal(t, "v2", value) + re.True(ok) + re.Equal("v2", value) value, ok = cache.Get(3) - require.True(t, ok) - require.Equal(t, 3.0, value) + re.True(ok) + re.Equal(3.0, value) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) - require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{1, 2, 3})) + re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{1, 2, 3})) time.Sleep(2 * time.Second) value, ok = cache.Get(1) - require.False(t, ok) - require.Nil(t, value) + re.False(ok) + re.Nil(value) value, ok = cache.Get(2) - require.True(t, ok) - require.Equal(t, "v2", value) + re.True(ok) + re.Equal("v2", value) value, ok = cache.Get(3) - require.True(t, ok) - require.Equal(t, 3.0, value) + re.True(ok) + re.Equal(3.0, value) - require.Equal(t, 2, cache.Len()) - require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{2, 3})) + re.Equal(2, cache.Len()) + re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{2, 3})) cache.Remove(2) value, ok = cache.Get(2) - require.False(t, ok) - require.Nil(t, value) + re.False(ok) + re.Nil(value) value, ok = cache.Get(3) - require.True(t, ok) - require.Equal(t, 3.0, value) + re.True(ok) + re.Equal(3.0, value) - require.Equal(t, 1, cache.Len()) - require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{3})) + re.Equal(1, cache.Len()) + 
re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{3})) } func sortIDs(ids []uint64) []uint64 { @@ -115,6 +116,7 @@ func sortIDs(ids []uint64) []uint64 { } func TestLRUCache(t *testing.T) { + re := require.New(t) cache := newLRU(3) cache.Put(1, "1") @@ -122,173 +124,175 @@ func TestLRUCache(t *testing.T) { cache.Put(3, "3") val, ok := cache.Get(3) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "3")) + re.True(ok) + re.True(reflect.DeepEqual(val, "3")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) cache.Put(4, "4") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(4) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "4")) + re.True(ok) + re.True(reflect.DeepEqual(val, "4")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Peek(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) elems := cache.Elems() - require.Len(t, elems, 3) - require.True(t, reflect.DeepEqual(elems[0].Value, "4")) - require.True(t, reflect.DeepEqual(elems[1].Value, "2")) - require.True(t, reflect.DeepEqual(elems[2].Value, "1")) + re.Len(elems, 3) + re.True(reflect.DeepEqual(elems[0].Value, "4")) + re.True(reflect.DeepEqual(elems[1].Value, "2")) + re.True(reflect.DeepEqual(elems[2].Value, "1")) cache.Remove(1) cache.Remove(2) cache.Remove(4) - require.Equal(t, 0, cache.Len()) + re.Equal(0, cache.Len()) val, ok = cache.Get(1) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(2) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(4) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) } func TestFifoCache(t *testing.T) { + re := require.New(t) cache := NewFIFO(3) cache.Put(1, "1") cache.Put(2, "2") cache.Put(3, "3") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) cache.Put(4, "4") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) elems := cache.Elems() - require.Len(t, elems, 3) - require.True(t, reflect.DeepEqual(elems[0].Value, "2")) - require.True(t, reflect.DeepEqual(elems[1].Value, "3")) - require.True(t, reflect.DeepEqual(elems[2].Value, "4")) + re.Len(elems, 3) + re.True(reflect.DeepEqual(elems[0].Value, "2")) + re.True(reflect.DeepEqual(elems[1].Value, "3")) + re.True(reflect.DeepEqual(elems[2].Value, "4")) elems = cache.FromElems(3) - require.Len(t, elems, 1) - require.True(t, reflect.DeepEqual(elems[0].Value, "4")) + re.Len(elems, 1) + re.True(reflect.DeepEqual(elems[0].Value, "4")) cache.Remove() cache.Remove() cache.Remove() - require.Equal(t, 0, cache.Len()) + re.Equal(0, cache.Len()) } func TestTwoQueueCache(t *testing.T) { + re := 
require.New(t) cache := newTwoQueue(3) cache.Put(1, "1") cache.Put(2, "2") cache.Put(3, "3") val, ok := cache.Get(3) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "3")) + re.True(ok) + re.True(reflect.DeepEqual(val, "3")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) cache.Put(4, "4") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(4) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "4")) + re.True(ok) + re.True(reflect.DeepEqual(val, "4")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Peek(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) elems := cache.Elems() - require.Len(t, elems, 3) - require.True(t, reflect.DeepEqual(elems[0].Value, "4")) - require.True(t, reflect.DeepEqual(elems[1].Value, "2")) - require.True(t, reflect.DeepEqual(elems[2].Value, "1")) + re.Len(elems, 3) + re.True(reflect.DeepEqual(elems[0].Value, "4")) + re.True(reflect.DeepEqual(elems[1].Value, "2")) + re.True(reflect.DeepEqual(elems[2].Value, "1")) cache.Remove(1) cache.Remove(2) cache.Remove(4) - require.Equal(t, 0, cache.Len()) + re.Equal(0, cache.Len()) val, ok = cache.Get(1) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(2) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(4) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) } var _ PriorityQueueItem = PriorityQueueItemTest(0) @@ -300,53 +304,54 @@ func (pq PriorityQueueItemTest) ID() uint64 { } func TestPriorityQueue(t *testing.T) { + re := require.New(t) testData := []PriorityQueueItemTest{0, 1, 2, 3, 4, 5} pq := NewPriorityQueue(0) - require.False(t, pq.Put(1, testData[1])) + re.False(pq.Put(1, testData[1])) // it will have priority-value pair as 1-1 2-2 3-3 pq = NewPriorityQueue(3) - require.True(t, pq.Put(1, testData[1])) - require.True(t, pq.Put(2, testData[2])) - require.True(t, pq.Put(3, testData[4])) - require.True(t, pq.Put(5, testData[4])) - require.False(t, pq.Put(5, testData[5])) - require.True(t, pq.Put(3, testData[3])) - require.True(t, pq.Put(3, testData[3])) - require.Nil(t, pq.Get(4)) - require.Equal(t, 3, pq.Len()) + re.True(pq.Put(1, testData[1])) + re.True(pq.Put(2, testData[2])) + re.True(pq.Put(3, testData[4])) + re.True(pq.Put(5, testData[4])) + re.False(pq.Put(5, testData[5])) + re.True(pq.Put(3, testData[3])) + re.True(pq.Put(3, testData[3])) + re.Nil(pq.Get(4)) + re.Equal(3, pq.Len()) // case1 test getAll, the highest element should be the first entries := pq.Elems() - require.Len(t, entries, 3) - require.Equal(t, 1, entries[0].Priority) - require.Equal(t, testData[1], 
entries[0].Value) - require.Equal(t, 2, entries[1].Priority) - require.Equal(t, testData[2], entries[1].Value) - require.Equal(t, 3, entries[2].Priority) - require.Equal(t, testData[3], entries[2].Value) + re.Len(entries, 3) + re.Equal(1, entries[0].Priority) + re.Equal(testData[1], entries[0].Value) + re.Equal(2, entries[1].Priority) + re.Equal(testData[2], entries[1].Value) + re.Equal(3, entries[2].Priority) + re.Equal(testData[3], entries[2].Value) // case2 test remove the high element, and the second element should be the first pq.Remove(uint64(1)) - require.Nil(t, pq.Get(1)) - require.Equal(t, 2, pq.Len()) + re.Nil(pq.Get(1)) + re.Equal(2, pq.Len()) entry := pq.Peek() - require.Equal(t, 2, entry.Priority) - require.Equal(t, testData[2], entry.Value) + re.Equal(2, entry.Priority) + re.Equal(testData[2], entry.Value) // case3 update 3's priority to highest pq.Put(-1, testData[3]) entry = pq.Peek() - require.Equal(t, -1, entry.Priority) - require.Equal(t, testData[3], entry.Value) + re.Equal(-1, entry.Priority) + re.Equal(testData[3], entry.Value) pq.Remove(entry.Value.ID()) - require.Equal(t, testData[2], pq.Peek().Value) - require.Equal(t, 1, pq.Len()) + re.Equal(testData[2], pq.Peek().Value) + re.Equal(1, pq.Len()) // case4 remove all element pq.Remove(uint64(2)) - require.Equal(t, 0, pq.Len()) - require.Len(t, pq.items, 0) - require.Nil(t, pq.Peek()) - require.Nil(t, pq.Tail()) + re.Equal(0, pq.Len()) + re.Len(pq.items, 0) + re.Nil(pq.Peek()) + re.Nil(pq.Tail()) } diff --git a/pkg/codec/codec_test.go b/pkg/codec/codec_test.go index cd73c1da0cc..50bf552a60d 100644 --- a/pkg/codec/codec_test.go +++ b/pkg/codec/codec_test.go @@ -21,27 +21,29 @@ import ( ) func TestDecodeBytes(t *testing.T) { + re := require.New(t) key := "abcdefghijklmnopqrstuvwxyz" for i := 0; i < len(key); i++ { _, k, err := DecodeBytes(EncodeBytes([]byte(key[:i]))) - require.NoError(t, err) - require.Equal(t, key[:i], string(k)) + re.NoError(err) + re.Equal(key[:i], string(k)) } } func TestTableID(t *testing.T) { + re := require.New(t) key := EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff")) - require.Equal(t, int64(0xff), key.TableID()) + re.Equal(int64(0xff), key.TableID()) key = EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff_i\x01\x02")) - require.Equal(t, int64(0xff), key.TableID()) + re.Equal(int64(0xff), key.TableID()) key = []byte("t\x80\x00\x00\x00\x00\x00\x00\xff") - require.Equal(t, int64(0), key.TableID()) + re.Equal(int64(0), key.TableID()) key = EncodeBytes([]byte("T\x00\x00\x00\x00\x00\x00\x00\xff")) - require.Equal(t, int64(0), key.TableID()) + re.Equal(int64(0), key.TableID()) key = EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\xff")) - require.Equal(t, int64(0), key.TableID()) + re.Equal(int64(0), key.TableID()) } diff --git a/pkg/encryption/config_test.go b/pkg/encryption/config_test.go index 04e9d417686..1e3231b0903 100644 --- a/pkg/encryption/config_test.go +++ b/pkg/encryption/config_test.go @@ -23,26 +23,30 @@ import ( ) func TestAdjustDefaultValue(t *testing.T) { + re := require.New(t) config := &Config{} err := config.Adjust() - require.NoError(t, err) - require.Equal(t, methodPlaintext, config.DataEncryptionMethod) + re.NoError(err) + re.Equal(methodPlaintext, config.DataEncryptionMethod) defaultRotationPeriod, _ := time.ParseDuration(defaultDataKeyRotationPeriod) - require.Equal(t, defaultRotationPeriod, config.DataKeyRotationPeriod.Duration) - require.Equal(t, masterKeyTypePlaintext, config.MasterKey.Type) + re.Equal(defaultRotationPeriod, 
config.DataKeyRotationPeriod.Duration) + re.Equal(masterKeyTypePlaintext, config.MasterKey.Type) } func TestAdjustInvalidDataEncryptionMethod(t *testing.T) { + re := require.New(t) config := &Config{DataEncryptionMethod: "unknown"} - require.NotNil(t, config.Adjust()) + re.NotNil(config.Adjust()) } func TestAdjustNegativeRotationDuration(t *testing.T) { + re := require.New(t) config := &Config{DataKeyRotationPeriod: typeutil.NewDuration(time.Duration(int64(-1)))} - require.NotNil(t, config.Adjust()) + re.NotNil(config.Adjust()) } func TestAdjustInvalidMasterKeyType(t *testing.T) { + re := require.New(t) config := &Config{MasterKey: MasterKeyConfig{Type: "unknown"}} - require.NotNil(t, config.Adjust()) + re.NotNil(config.Adjust()) } diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index 716d15ecdcb..c29ed6a8725 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -24,77 +24,81 @@ import ( ) func TestEncryptionMethodSupported(t *testing.T) { - require.NotNil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) - require.NotNil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) - require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR)) - require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR)) - require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR)) + re := require.New(t) + re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) + re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) + re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR)) + re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR)) + re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR)) } func TestKeyLength(t *testing.T) { + re := require.New(t) _, err := KeyLength(encryptionpb.EncryptionMethod_PLAINTEXT) - require.NotNil(t, err) + re.NotNil(err) _, err = KeyLength(encryptionpb.EncryptionMethod_UNKNOWN) - require.NotNil(t, err) + re.NotNil(err) length, err := KeyLength(encryptionpb.EncryptionMethod_AES128_CTR) - require.NoError(t, err) - require.Equal(t, 16, length) + re.NoError(err) + re.Equal(16, length) length, err = KeyLength(encryptionpb.EncryptionMethod_AES192_CTR) - require.NoError(t, err) - require.Equal(t, 24, length) + re.NoError(err) + re.Equal(24, length) length, err = KeyLength(encryptionpb.EncryptionMethod_AES256_CTR) - require.NoError(t, err) - require.Equal(t, 32, length) + re.NoError(err) + re.Equal(32, length) } func TestNewIv(t *testing.T) { + re := require.New(t) ivCtr, err := NewIvCTR() - require.NoError(t, err) - require.Len(t, []byte(ivCtr), ivLengthCTR) + re.NoError(err) + re.Len([]byte(ivCtr), ivLengthCTR) ivGcm, err := NewIvGCM() - require.NoError(t, err) - require.Len(t, []byte(ivGcm), ivLengthGCM) + re.NoError(err) + re.Len([]byte(ivGcm), ivLengthGCM) } func TestNewDataKey(t *testing.T) { + re := require.New(t) for _, method := range []encryptionpb.EncryptionMethod{ encryptionpb.EncryptionMethod_AES128_CTR, encryptionpb.EncryptionMethod_AES192_CTR, encryptionpb.EncryptionMethod_AES256_CTR, } { _, key, err := NewDataKey(method, uint64(123)) - require.NoError(t, err) + re.NoError(err) length, err := KeyLength(method) - require.NoError(t, err) - require.Len(t, key.Key, length) - require.Equal(t, method, key.Method) - require.False(t, key.WasExposed) - require.Equal(t, 
uint64(123), key.CreationTime) + re.NoError(err) + re.Len(key.Key, length) + re.Equal(method, key.Method) + re.False(key.WasExposed) + re.Equal(uint64(123), key.CreationTime) } } func TestAesGcmCrypter(t *testing.T) { + re := require.New(t) key, err := hex.DecodeString("ed568fbd8c8018ed2d042a4e5d38d6341486922d401d2022fb81e47c900d3f07") - require.NoError(t, err) + re.NoError(err) plaintext, err := hex.DecodeString( "5c873a18af5e7c7c368cb2635e5a15c7f87282085f4b991e84b78c5967e946d4") - require.NoError(t, err) + re.NoError(err) // encrypt ivBytes, err := hex.DecodeString("ba432b70336c40c39ba14c1b") - require.NoError(t, err) + re.NoError(err) iv := IvGCM(ivBytes) ciphertext, err := aesGcmEncryptImpl(key, plaintext, iv) - require.NoError(t, err) - require.Len(t, []byte(iv), ivLengthGCM) - require.Equal( - t, + re.NoError(err) + re.Len([]byte(iv), ivLengthGCM) + re.Equal( "bbb9b49546350880cf55d4e4eaccc831c506a4aeae7f6cda9c821d4cb8cfc269dcdaecb09592ef25d7a33b40d3f02208", hex.EncodeToString(ciphertext), ) // decrypt plaintext2, err := AesGcmDecrypt(key, ciphertext, iv) - require.NoError(t, err) - require.True(t, bytes.Equal(plaintext2, plaintext)) + re.NoError(err) + re.True(bytes.Equal(plaintext2, plaintext)) // Modify ciphertext to test authentication failure. We modify the beginning of the ciphertext, // which is the real ciphertext part, not the tag. fakeCiphertext := make([]byte, len(ciphertext)) @@ -102,5 +106,5 @@ func TestAesGcmCrypter(t *testing.T) { // ignore overflow fakeCiphertext[0] = ciphertext[0] + 1 _, err = AesGcmDecrypt(key, fakeCiphertext, iv) - require.NotNil(t, err) + re.NotNil(err) } diff --git a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index 0fc1d376ca7..990d6322c3e 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -24,59 +24,63 @@ import ( ) func TestPlaintextMasterKey(t *testing.T) { + re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_Plaintext{ Plaintext: &encryptionpb.MasterKeyPlaintext{}, }, } masterKey, err := NewMasterKey(config, nil) - require.NoError(t, err) - require.NotNil(t, masterKey) - require.Len(t, masterKey.key, 0) + re.NoError(err) + re.NotNil(masterKey) + re.Len(masterKey.key, 0) plaintext := "this is a plaintext" ciphertext, iv, err := masterKey.Encrypt([]byte(plaintext)) - require.NoError(t, err) - require.Len(t, iv, 0) - require.Equal(t, plaintext, string(ciphertext)) + re.NoError(err) + re.Len(iv, 0) + re.Equal(plaintext, string(ciphertext)) plaintext2, err := masterKey.Decrypt(ciphertext, iv) - require.NoError(t, err) - require.Equal(t, plaintext, string(plaintext2)) + re.NoError(err) + re.Equal(plaintext, string(plaintext2)) - require.True(t, masterKey.IsPlaintext()) + re.True(masterKey.IsPlaintext()) } func TestEncrypt(t *testing.T) { + re := require.New(t) keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) - require.NoError(t, err) + re.NoError(err) masterKey := &MasterKey{key: key} plaintext := "this-is-a-plaintext" ciphertext, iv, err := masterKey.Encrypt([]byte(plaintext)) - require.NoError(t, err) - require.Len(t, iv, ivLengthGCM) + re.NoError(err) + re.Len(iv, ivLengthGCM) plaintext2, err := AesGcmDecrypt(key, ciphertext, iv) - require.NoError(t, err) - require.Equal(t, plaintext, string(plaintext2)) + re.NoError(err) + re.Equal(plaintext, string(plaintext2)) } func TestDecrypt(t *testing.T) { + re := require.New(t) keyHex := 
"2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) - require.NoError(t, err) + re.NoError(err) plaintext := "this-is-a-plaintext" iv, err := hex.DecodeString("ba432b70336c40c39ba14c1b") - require.NoError(t, err) + re.NoError(err) ciphertext, err := aesGcmEncryptImpl(key, []byte(plaintext), iv) - require.NoError(t, err) + re.NoError(err) masterKey := &MasterKey{key: key} plaintext2, err := masterKey.Decrypt(ciphertext, iv) - require.NoError(t, err) - require.Equal(t, plaintext, string(plaintext2)) + re.NoError(err) + re.Equal(plaintext, string(plaintext2)) } func TestNewFileMasterKeyMissingPath(t *testing.T) { + re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ File: &encryptionpb.MasterKeyFile{ @@ -85,12 +89,13 @@ func TestNewFileMasterKeyMissingPath(t *testing.T) { }, } _, err := NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKeyMissingFile(t *testing.T) { + re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -100,12 +105,13 @@ func TestNewFileMasterKeyMissingFile(t *testing.T) { }, } _, err = NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKeyNotHexString(t *testing.T) { + re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" os.WriteFile(path, []byte("not-a-hex-string"), 0600) config := &encryptionpb.MasterKey{ @@ -116,12 +122,13 @@ func TestNewFileMasterKeyNotHexString(t *testing.T) { }, } _, err = NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKeyLengthMismatch(t *testing.T) { + re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" os.WriteFile(path, []byte("2f07ec61e5a50284f47f2b402a962ec6"), 0600) config := &encryptionpb.MasterKey{ @@ -132,13 +139,14 @@ func TestNewFileMasterKeyLengthMismatch(t *testing.T) { }, } _, err = NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKey(t *testing.T) { + re := require.New(t) key := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" os.WriteFile(path, []byte(key), 0600) config := &encryptionpb.MasterKey{ @@ -149,6 +157,6 @@ func TestNewFileMasterKey(t *testing.T) { }, } masterKey, err := NewMasterKey(config, nil) - require.NoError(t, err) - require.Equal(t, key, hex.EncodeToString(masterKey.key)) + re.NoError(err) + re.Equal(key, hex.EncodeToString(masterKey.key)) } diff --git a/pkg/encryption/region_crypter_test.go b/pkg/encryption/region_crypter_test.go index 06398ebc7ff..b1ca558063c 100644 --- a/pkg/encryption/region_crypter_test.go +++ b/pkg/encryption/region_crypter_test.go @@ -70,15 +70,17 @@ func (m *testKeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { } func TestNilRegion(t *testing.T) { + re := require.New(t) m := newTestKeyManager() region, err := EncryptRegion(nil, m) - require.Error(t, err) - require.Nil(t, region) + re.Error(err) + re.Nil(region) err = DecryptRegion(nil, m) - require.Error(t, err) + re.Error(err) } func TestEncryptRegionWithoutKeyManager(t *testing.T) { + re := require.New(t) region := &metapb.Region{ 
Id: 10, StartKey: []byte("abc"), @@ -86,14 +88,15 @@ func TestEncryptRegionWithoutKeyManager(t *testing.T) { EncryptionMeta: nil, } region, err := EncryptRegion(region, nil) - require.NoError(t, err) + re.NoError(err) // check the region isn't changed - require.Equal(t, "abc", string(region.StartKey)) - require.Equal(t, "xyz", string(region.EndKey)) - require.Nil(t, region.EncryptionMeta) + re.Equal("abc", string(region.StartKey)) + re.Equal("xyz", string(region.EndKey)) + re.Nil(region.EncryptionMeta) } func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { + re := require.New(t) region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -103,14 +106,15 @@ func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { m := newTestKeyManager() m.EncryptionEnabled = false region, err := EncryptRegion(region, m) - require.NoError(t, err) + re.NoError(err) // check the region isn't changed - require.Equal(t, "abc", string(region.StartKey)) - require.Equal(t, "xyz", string(region.EndKey)) - require.Nil(t, region.EncryptionMeta) + re.Equal("abc", string(region.StartKey)) + re.Equal("xyz", string(region.EndKey)) + re.Nil(region.EncryptionMeta) } func TestEncryptRegion(t *testing.T) { + re := require.New(t) startKey := []byte("abc") endKey := []byte("xyz") region := &metapb.Region{ @@ -123,27 +127,28 @@ func TestEncryptRegion(t *testing.T) { copy(region.EndKey, endKey) m := newTestKeyManager() outRegion, err := EncryptRegion(region, m) - require.NoError(t, err) - require.NotEqual(t, region, outRegion) + re.NoError(err) + re.NotEqual(region, outRegion) // check region is encrypted - require.NotNil(t, outRegion.EncryptionMeta) - require.Equal(t, uint64(2), outRegion.EncryptionMeta.KeyId) - require.Len(t, outRegion.EncryptionMeta.Iv, ivLengthCTR) + re.NotNil(outRegion.EncryptionMeta) + re.Equal(uint64(2), outRegion.EncryptionMeta.KeyId) + re.Len(outRegion.EncryptionMeta.Iv, ivLengthCTR) // Check encrypted content _, currentKey, err := m.GetCurrentKey() - require.NoError(t, err) + re.NoError(err) block, err := aes.NewCipher(currentKey.Key) - require.NoError(t, err) + re.NoError(err) stream := cipher.NewCTR(block, outRegion.EncryptionMeta.Iv) ciphertextStartKey := make([]byte, len(startKey)) stream.XORKeyStream(ciphertextStartKey, startKey) - require.Equal(t, string(ciphertextStartKey), string(outRegion.StartKey)) + re.Equal(string(ciphertextStartKey), string(outRegion.StartKey)) ciphertextEndKey := make([]byte, len(endKey)) stream.XORKeyStream(ciphertextEndKey, endKey) - require.Equal(t, string(ciphertextEndKey), string(outRegion.EndKey)) + re.Equal(string(ciphertextEndKey), string(outRegion.EndKey)) } func TestDecryptRegionNotEncrypted(t *testing.T) { + re := require.New(t) region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -152,14 +157,15 @@ func TestDecryptRegionNotEncrypted(t *testing.T) { } m := newTestKeyManager() err := DecryptRegion(region, m) - require.NoError(t, err) + re.NoError(err) // check the region isn't changed - require.Equal(t, "abc", string(region.StartKey)) - require.Equal(t, "xyz", string(region.EndKey)) - require.Nil(t, region.EncryptionMeta) + re.Equal("abc", string(region.StartKey)) + re.Equal("xyz", string(region.EndKey)) + re.Nil(region.EncryptionMeta) } func TestDecryptRegionWithoutKeyManager(t *testing.T) { + re := require.New(t) region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -170,14 +176,15 @@ func TestDecryptRegionWithoutKeyManager(t *testing.T) { }, } err := DecryptRegion(region, nil) - require.Error(t, err) + re.Error(err) } func 
TestDecryptRegionWhileKeyMissing(t *testing.T) { + re := require.New(t) keyID := uint64(3) m := newTestKeyManager() _, err := m.GetKey(3) - require.Error(t, err) + re.Error(err) region := &metapb.Region{ Id: 10, @@ -189,10 +196,11 @@ func TestDecryptRegionWhileKeyMissing(t *testing.T) { }, } err = DecryptRegion(region, m) - require.Error(t, err) + re.Error(err) } func TestDecryptRegion(t *testing.T) { + re := require.New(t) keyID := uint64(1) startKey := []byte("abc") endKey := []byte("xyz") @@ -211,19 +219,19 @@ func TestDecryptRegion(t *testing.T) { copy(region.EncryptionMeta.Iv, iv) m := newTestKeyManager() err := DecryptRegion(region, m) - require.NoError(t, err) + re.NoError(err) // check region is decrypted - require.Nil(t, region.EncryptionMeta) + re.Nil(region.EncryptionMeta) // Check decrypted content key, err := m.GetKey(keyID) - require.NoError(t, err) + re.NoError(err) block, err := aes.NewCipher(key.Key) - require.NoError(t, err) + re.NoError(err) stream := cipher.NewCTR(block, iv) plaintextStartKey := make([]byte, len(startKey)) stream.XORKeyStream(plaintextStartKey, startKey) - require.Equal(t, string(plaintextStartKey), string(region.StartKey)) + re.Equal(string(plaintextStartKey), string(region.StartKey)) plaintextEndKey := make([]byte, len(endKey)) stream.XORKeyStream(plaintextEndKey, endKey) - require.Equal(t, string(plaintextEndKey), string(region.EndKey)) + re.Equal(string(plaintextEndKey), string(region.EndKey)) } diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index 65ebb6460d0..74e55257d70 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -72,44 +72,46 @@ func newZapTestLogger(cfg *log.Config, opts ...zap.Option) verifyLogger { } func TestError(t *testing.T) { + re := require.New(t) conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) log.ReplaceGlobals(lg.Logger, nil) rfc := `[error="[PD:member:ErrEtcdLeaderNotFound]etcd leader not found` log.Error("test", zap.Error(ErrEtcdLeaderNotFound.FastGenByArgs())) - require.Contains(t, lg.Message(), rfc) + re.Contains(lg.Message(), rfc) err := errors.New("test error") log.Error("test", ZapError(ErrEtcdLeaderNotFound, err)) rfc = `[error="[PD:member:ErrEtcdLeaderNotFound]test error` - require.Contains(t, lg.Message(), rfc) + re.Contains(lg.Message(), rfc) } func TestErrorEqual(t *testing.T) { + re := require.New(t) err1 := ErrSchedulerNotFound.FastGenByArgs() err2 := ErrSchedulerNotFound.FastGenByArgs() - require.True(t, errors.ErrorEqual(err1, err2)) + re.True(errors.ErrorEqual(err1, err2)) err := errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - require.True(t, errors.ErrorEqual(err1, err2)) + re.True(errors.ErrorEqual(err1, err2)) err1 = ErrSchedulerNotFound.FastGenByArgs() err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - require.False(t, errors.ErrorEqual(err1, err2)) + re.False(errors.ErrorEqual(err1, err2)) err3 := errors.New("test") err4 := errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() - require.True(t, errors.ErrorEqual(err1, err2)) + re.True(errors.ErrorEqual(err1, err2)) err3 = errors.New("test1") err4 = errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() - require.False(t, errors.ErrorEqual(err1, err2)) + re.False(errors.ErrorEqual(err1, err2)) } func 
TestZapError(t *testing.T) { @@ -121,6 +123,7 @@ func TestZapError(t *testing.T) { } func TestErrorWithStack(t *testing.T) { + re := require.New(t) conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) log.ReplaceGlobals(lg.Logger, nil) @@ -133,8 +136,8 @@ func TestErrorWithStack(t *testing.T) { // This test is based on line number and the first log is in line 141, the second is in line 142. // So they have the same length stack. Move this test to another place need to change the corresponding length. idx1 := strings.Index(m1, "[stack=") - require.GreaterOrEqual(t, idx1, -1) + re.GreaterOrEqual(idx1, -1) idx2 := strings.Index(m2, "[stack=") - require.GreaterOrEqual(t, idx2, -1) - require.Equal(t, len(m1[idx1:]), len(m2[idx2:])) + re.GreaterOrEqual(idx2, -1) + re.Equal(len(m1[idx1:]), len(m2[idx2:])) } diff --git a/pkg/etcdutil/etcdutil_test.go b/pkg/etcdutil/etcdutil_test.go index 7bc73f12cbe..7731a319a94 100644 --- a/pkg/etcdutil/etcdutil_test.go +++ b/pkg/etcdutil/etcdutil_test.go @@ -28,28 +28,29 @@ import ( ) func TestMemberHelpers(t *testing.T) { + re := require.New(t) cfg1 := NewTestSingleConfig() etcd1, err := embed.StartEtcd(cfg1) defer func() { etcd1.Close() CleanConfig(cfg1) }() - require.NoError(t, err) + re.NoError(err) ep1 := cfg1.LCUrls[0].String() client1, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep1}, }) - require.NoError(t, err) + re.NoError(err) <-etcd1.Server.ReadyNotify() // Test ListEtcdMembers listResp1, err := ListEtcdMembers(client1) - require.NoError(t, err) - require.Len(t, listResp1.Members, 1) + re.NoError(err) + re.Len(listResp1.Members, 1) // types.ID is an alias of uint64. - require.Equal(t, uint64(etcd1.Server.ID()), listResp1.Members[0].ID) + re.Equal(uint64(etcd1.Server.ID()), listResp1.Members[0].ID) // Test AddEtcdMember // Make a new etcd config. @@ -61,28 +62,28 @@ func TestMemberHelpers(t *testing.T) { // Add it to the cluster above. 
peerURL := cfg2.LPUrls[0].String() addResp, err := AddEtcdMember(client1, []string{peerURL}) - require.NoError(t, err) + re.NoError(err) etcd2, err := embed.StartEtcd(cfg2) defer func() { etcd2.Close() CleanConfig(cfg2) }() - require.NoError(t, err) - require.Equal(t, uint64(etcd2.Server.ID()), addResp.Member.ID) + re.NoError(err) + re.Equal(uint64(etcd2.Server.ID()), addResp.Member.ID) ep2 := cfg2.LCUrls[0].String() client2, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep2}, }) - require.NoError(t, err) + re.NoError(err) <-etcd2.Server.ReadyNotify() - require.NoError(t, err) + re.NoError(err) listResp2, err := ListEtcdMembers(client2) - require.NoError(t, err) - require.Len(t, listResp2.Members, 2) + re.NoError(err) + re.Len(listResp2.Members, 2) for _, m := range listResp2.Members { switch m.ID { case uint64(etcd1.Server.ID()): @@ -94,34 +95,35 @@ func TestMemberHelpers(t *testing.T) { // Test CheckClusterID urlsMap, err := types.NewURLsMap(cfg2.InitialCluster) - require.NoError(t, err) + re.NoError(err) err = CheckClusterID(etcd1.Server.Cluster().ID(), urlsMap, &tls.Config{MinVersion: tls.VersionTLS12}) - require.NoError(t, err) + re.NoError(err) // Test RemoveEtcdMember _, err = RemoveEtcdMember(client1, uint64(etcd2.Server.ID())) - require.NoError(t, err) + re.NoError(err) listResp3, err := ListEtcdMembers(client1) - require.NoError(t, err) - require.Len(t, listResp3.Members, 1) - require.Equal(t, uint64(etcd1.Server.ID()), listResp3.Members[0].ID) + re.NoError(err) + re.Len(listResp3.Members, 1) + re.Equal(uint64(etcd1.Server.ID()), listResp3.Members[0].ID) } func TestEtcdKVGet(t *testing.T) { + re := require.New(t) cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() CleanConfig(cfg) }() - require.NoError(t, err) + re.NoError(err) ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - require.NoError(t, err) + re.NoError(err) <-etcd.Server.ReadyNotify() @@ -131,69 +133,70 @@ func TestEtcdKVGet(t *testing.T) { kv := clientv3.NewKV(client) for i := range keys { _, err = kv.Put(context.TODO(), keys[i], vals[i]) - require.NoError(t, err) + re.NoError(err) } // Test simple point get resp, err := EtcdKVGet(client, "test/key1") - require.NoError(t, err) - require.Equal(t, "val1", string(resp.Kvs[0].Value)) + re.NoError(err) + re.Equal("val1", string(resp.Kvs[0].Value)) // Test range get withRange := clientv3.WithRange("test/zzzz") withLimit := clientv3.WithLimit(3) resp, err = EtcdKVGet(client, "test/", withRange, withLimit, clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) - require.NoError(t, err) - require.Len(t, resp.Kvs, 3) + re.NoError(err) + re.Len(resp.Kvs, 3) for i := range resp.Kvs { - require.Equal(t, keys[i], string(resp.Kvs[i].Key)) - require.Equal(t, vals[i], string(resp.Kvs[i].Value)) + re.Equal(keys[i], string(resp.Kvs[i].Key)) + re.Equal(vals[i], string(resp.Kvs[i].Value)) } lastKey := string(resp.Kvs[len(resp.Kvs)-1].Key) next := clientv3.GetPrefixRangeEnd(lastKey) resp, err = EtcdKVGet(client, next, withRange, withLimit, clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) - require.NoError(t, err) - require.Len(t, resp.Kvs, 2) + re.NoError(err) + re.Len(resp.Kvs, 2) } func TestEtcdKVPutWithTTL(t *testing.T) { + re := require.New(t) cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() CleanConfig(cfg) }() - require.NoError(t, err) + re.NoError(err) ep := cfg.LCUrls[0].String() client, err := 
clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - require.NoError(t, err) + re.NoError(err) <-etcd.Server.ReadyNotify() _, err = EtcdKVPutWithTTL(context.TODO(), client, "test/ttl1", "val1", 2) - require.NoError(t, err) + re.NoError(err) _, err = EtcdKVPutWithTTL(context.TODO(), client, "test/ttl2", "val2", 4) - require.NoError(t, err) + re.NoError(err) time.Sleep(3 * time.Second) // test/ttl1 is outdated resp, err := EtcdKVGet(client, "test/ttl1") - require.NoError(t, err) - require.Equal(t, int64(0), resp.Count) + re.NoError(err) + re.Equal(int64(0), resp.Count) // but test/ttl2 is not resp, err = EtcdKVGet(client, "test/ttl2") - require.NoError(t, err) - require.Equal(t, "val2", string(resp.Kvs[0].Value)) + re.NoError(err) + re.Equal("val2", string(resp.Kvs[0].Value)) time.Sleep(2 * time.Second) // test/ttl2 is also outdated resp, err = EtcdKVGet(client, "test/ttl2") - require.NoError(t, err) - require.Equal(t, int64(0), resp.Count) + re.NoError(err) + re.Equal(int64(0), resp.Count) } diff --git a/pkg/grpcutil/grpcutil_test.go b/pkg/grpcutil/grpcutil_test.go index d1b9d3a8830..44eee64b85e 100644 --- a/pkg/grpcutil/grpcutil_test.go +++ b/pkg/grpcutil/grpcutil_test.go @@ -9,18 +9,19 @@ import ( "github.com/tikv/pd/pkg/errs" ) -func loadTLSContent(t *testing.T, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { +func loadTLSContent(re *require.Assertions, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { var err error caData, err = os.ReadFile(caPath) - require.NoError(t, err) + re.NoError(err) certData, err = os.ReadFile(certPath) - require.NoError(t, err) + re.NoError(err) keyData, err = os.ReadFile(keyPath) - require.NoError(t, err) + re.NoError(err) return } func TestToTLSConfig(t *testing.T) { + re := require.New(t) tlsConfig := TLSConfig{ KeyPath: "../../tests/client/cert/pd-server-key.pem", CertPath: "../../tests/client/cert/pd-server.pem", @@ -28,24 +29,24 @@ func TestToTLSConfig(t *testing.T) { } // test without bytes _, err := tlsConfig.ToTLSConfig() - require.NoError(t, err) + re.NoError(err) // test with bytes - caData, certData, keyData := loadTLSContent(t, tlsConfig.CAPath, tlsConfig.CertPath, tlsConfig.KeyPath) + caData, certData, keyData := loadTLSContent(re, tlsConfig.CAPath, tlsConfig.CertPath, tlsConfig.KeyPath) tlsConfig.SSLCABytes = caData tlsConfig.SSLCertBytes = certData tlsConfig.SSLKEYBytes = keyData _, err = tlsConfig.ToTLSConfig() - require.NoError(t, err) + re.NoError(err) // test wrong cert bytes tlsConfig.SSLCertBytes = []byte("invalid cert") _, err = tlsConfig.ToTLSConfig() - require.True(t, errors.ErrorEqual(err, errs.ErrCryptoX509KeyPair)) + re.True(errors.ErrorEqual(err, errs.ErrCryptoX509KeyPair)) // test wrong ca bytes tlsConfig.SSLCertBytes = certData tlsConfig.SSLCABytes = []byte("invalid ca") _, err = tlsConfig.ToTLSConfig() - require.True(t, errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) + re.True(errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) } diff --git a/pkg/keyutil/util_test.go b/pkg/keyutil/util_test.go index 6603c61b131..f69463c5060 100644 --- a/pkg/keyutil/util_test.go +++ b/pkg/keyutil/util_test.go @@ -21,8 +21,9 @@ import ( ) func TestKeyUtil(t *testing.T) { + re := require.New(t) startKey := []byte("a") endKey := []byte("b") key := BuildKeyRangeKey(startKey, endKey) - require.Equal(t, "61-62", key) + re.Equal("61-62", key) } diff --git a/pkg/logutil/log_test.go b/pkg/logutil/log_test.go index 42a9126ea33..270a8e5b0ba 100644 --- a/pkg/logutil/log_test.go +++ 
b/pkg/logutil/log_test.go @@ -24,16 +24,18 @@ import ( ) func TestStringToZapLogLevel(t *testing.T) { - require.Equal(t, zapcore.FatalLevel, StringToZapLogLevel("fatal")) - require.Equal(t, zapcore.ErrorLevel, StringToZapLogLevel("ERROR")) - require.Equal(t, zapcore.WarnLevel, StringToZapLogLevel("warn")) - require.Equal(t, zapcore.WarnLevel, StringToZapLogLevel("warning")) - require.Equal(t, zapcore.DebugLevel, StringToZapLogLevel("debug")) - require.Equal(t, zapcore.InfoLevel, StringToZapLogLevel("info")) - require.Equal(t, zapcore.InfoLevel, StringToZapLogLevel("whatever")) + re := require.New(t) + re.Equal(zapcore.FatalLevel, StringToZapLogLevel("fatal")) + re.Equal(zapcore.ErrorLevel, StringToZapLogLevel("ERROR")) + re.Equal(zapcore.WarnLevel, StringToZapLogLevel("warn")) + re.Equal(zapcore.WarnLevel, StringToZapLogLevel("warning")) + re.Equal(zapcore.DebugLevel, StringToZapLogLevel("debug")) + re.Equal(zapcore.InfoLevel, StringToZapLogLevel("info")) + re.Equal(zapcore.InfoLevel, StringToZapLogLevel("whatever")) } func TestRedactLog(t *testing.T) { + re := require.New(t) testCases := []struct { name string arg interface{} @@ -71,11 +73,11 @@ func TestRedactLog(t *testing.T) { SetRedactLog(testCase.enableRedactLog) switch r := testCase.arg.(type) { case []byte: - require.True(t, reflect.DeepEqual(testCase.expect, RedactBytes(r))) + re.True(reflect.DeepEqual(testCase.expect, RedactBytes(r))) case string: - require.True(t, reflect.DeepEqual(testCase.expect, RedactString(r))) + re.True(reflect.DeepEqual(testCase.expect, RedactString(r))) case fmt.Stringer: - require.True(t, reflect.DeepEqual(testCase.expect, RedactStringer(r))) + re.True(reflect.DeepEqual(testCase.expect, RedactStringer(r))) default: panic("unmatched case") } diff --git a/pkg/metricutil/metricutil_test.go b/pkg/metricutil/metricutil_test.go index 512732c7f7e..a72eb7ee5f5 100644 --- a/pkg/metricutil/metricutil_test.go +++ b/pkg/metricutil/metricutil_test.go @@ -23,6 +23,7 @@ import ( ) func TestCamelCaseToSnakeCase(t *testing.T) { + re := require.New(t) inputs := []struct { name string newName string @@ -50,7 +51,7 @@ func TestCamelCaseToSnakeCase(t *testing.T) { } for _, input := range inputs { - require.Equal(t, input.newName, camelCaseToSnakeCase(input.name)) + re.Equal(input.newName, camelCaseToSnakeCase(input.name)) } } From a7ac85daa078f913610e5bd1ac101f824d084608 Mon Sep 17 00:00:00 2001 From: buffer <1045931706@qq.com> Date: Tue, 31 May 2022 16:22:27 +0800 Subject: [PATCH 05/82] config: fix the bug that the type of bucket size is not right. 
(#5074) close tikv/pd#5073 Signed-off-by: bufferflies <1045931706@qq.com> Co-authored-by: Ti Chi Robot --- server/config/store_config.go | 11 ++++++++--- server/config/store_config_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/server/config/store_config.go b/server/config/store_config.go index 6e8ba7c22f7..27fc456dd08 100644 --- a/server/config/store_config.go +++ b/server/config/store_config.go @@ -35,6 +35,8 @@ var ( defaultRegionMaxSize = uint64(144) // default region split size is 96MB defaultRegionSplitSize = uint64(96) + // default bucket size is 96MB + defaultBucketSize = uint64(96) // default region max key is 144000 defaultRegionMaxKey = uint64(1440000) // default region split key is 960000 @@ -58,7 +60,7 @@ type Coprocessor struct { RegionMaxKeys int `json:"region-max-keys"` RegionSplitKeys int `json:"region-split-keys"` EnableRegionBucket bool `json:"enable-region-bucket"` - RegionBucketSize int `json:"region-bucket-size"` + RegionBucketSize string `json:"region-bucket-size"` } // String implements fmt.Stringer interface. @@ -111,11 +113,14 @@ func (c *StoreConfig) IsEnableRegionBucket() bool { } // GetRegionBucketSize returns region bucket size if enable region buckets. -func (c *StoreConfig) GetRegionBucketSize() int { +func (c *StoreConfig) GetRegionBucketSize() uint64 { if c == nil || !c.Coprocessor.EnableRegionBucket { return 0 } - return c.Coprocessor.RegionBucketSize + if len(c.Coprocessor.RegionBucketSize) == 0 { + return defaultBucketSize + } + return typeutil.ParseMBFromText(c.Coprocessor.RegionBucketSize, defaultBucketSize) } // CheckRegionSize return error if the smallest region's size is less than mergeSize diff --git a/server/config/store_config_test.go b/server/config/store_config_test.go index 106d8b7bf4e..478e1ebb3d7 100644 --- a/server/config/store_config_test.go +++ b/server/config/store_config_test.go @@ -77,6 +77,30 @@ func (t *testTiKVConfigSuite) TestUpdateConfig(c *C) { c.Assert(manager.source.(*TiKVConfigSource).schema, Equals, "http") } +func (t *testTiKVConfigSuite) TestParseConfig(c *C) { + body := ` +{ +"coprocessor":{ +"split-region-on-table":false, +"batch-split-limit":10, +"region-max-size":"384MiB", +"region-split-size":"256MiB", +"region-max-keys":3840000, +"region-split-keys":2560000, +"consistency-check-method":"mvcc", +"enable-region-bucket":true, +"region-bucket-size":"96MiB", +"region-size-threshold-for-approximate":"384MiB", +"region-bucket-merge-size-ratio":0.33 +} +} +` + + var config StoreConfig + c.Assert(json.Unmarshal([]byte(body), &config), IsNil) + c.Assert(config.GetRegionBucketSize(), Equals, uint64(96)) +} + func (t *testTiKVConfigSuite) TestMergeCheck(c *C) { testdata := []struct { size uint64 From b92303c6a0395e0a231d4b00e89dd9a105e196b4 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Tue, 31 May 2022 18:04:27 +0800 Subject: [PATCH 06/82] workflow: change timeout-minutes of `statics` to 8 (#5079) close tikv/pd#5078 change timeout-minutes of `statics` to 8 Signed-off-by: Cabinfever_B --- .github/workflows/check.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml index 4fbed3641e5..47ef287d73f 100644 --- a/.github/workflows/check.yaml +++ b/.github/workflows/check.yaml @@ -6,6 +6,7 @@ concurrency: jobs: statics: runs-on: ubuntu-latest + timeout-minutes: 8 steps: - uses: actions/setup-go@v2 with: From 52dd58715804faa5e6ee88bf11183b61b5273570 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Tue, 31 May 2022 
18:16:28 +0800 Subject: [PATCH 07/82] *: Add Limiter Config (#4839) ref tikv/pd#4666 Add Rate Limiter Config for server Signed-off-by: Cabinfever_B Co-authored-by: Ryan Leung Co-authored-by: Ti Chi Robot --- pkg/ratelimit/limiter.go | 27 ++++++++--- pkg/ratelimit/limiter_test.go | 45 +++++++++++------- pkg/ratelimit/option.go | 88 +++++++++++++++++++++++++++++------ 3 files changed, 124 insertions(+), 36 deletions(-) diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index 43f01cea41b..4bf930ed6c5 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -20,6 +20,15 @@ import ( "golang.org/x/time/rate" ) +// DimensionConfig is the limit dimension config of one label +type DimensionConfig struct { + // qps conifg + QPS float64 + QPSBurst int + // concurrency config + ConcurrencyLimit uint64 +} + // Limiter is a controller for the request rate. type Limiter struct { qpsLimiter sync.Map @@ -30,7 +39,9 @@ type Limiter struct { // NewLimiter returns a global limiter which can be updated in the later. func NewLimiter() *Limiter { - return &Limiter{labelAllowList: make(map[string]struct{})} + return &Limiter{ + labelAllowList: make(map[string]struct{}), + } } // Allow is used to check whether it has enough token. @@ -65,10 +76,12 @@ func (l *Limiter) Release(label string) { } // Update is used to update Ratelimiter with Options -func (l *Limiter) Update(label string, opts ...Option) { +func (l *Limiter) Update(label string, opts ...Option) UpdateStatus { + var status UpdateStatus for _, opt := range opts { - opt(label, l) + status |= opt(label, l) } + return status } // GetQPSLimiterStatus returns the status of a given label's QPS limiter. @@ -80,8 +93,8 @@ func (l *Limiter) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int return 0, 0 } -// DeleteQPSLimiter deletes QPS limiter of given label -func (l *Limiter) DeleteQPSLimiter(label string) { +// QPSUnlimit deletes QPS limiter of the given label +func (l *Limiter) QPSUnlimit(label string) { l.qpsLimiter.Delete(label) } @@ -94,8 +107,8 @@ func (l *Limiter) GetConcurrencyLimiterStatus(label string) (limit uint64, curre return 0, 0 } -// DeleteConcurrencyLimiter deletes concurrency limiter of given label -func (l *Limiter) DeleteConcurrencyLimiter(label string) { +// ConcurrencyUnlimit deletes concurrency limiter of the given label +func (l *Limiter) ConcurrencyUnlimit(label string) { l.concurrencyLimiter.Delete(label) } diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index bd095543a05..cf75d76152a 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -34,9 +34,8 @@ func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { limiter := NewLimiter() label := "test" - for _, opt := range opts { - opt(label, limiter) - } + status := limiter.Update(label, opts...) 
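// Note (illustrative, not part of the original commit): Update now returns an
// UpdateStatus bitmask instead of nothing, so the assertions below check individual
// flags with a bitwise AND: status&ConcurrencyChanged != 0 means the concurrency
// limit was (re)configured, while the *NoChange and *Deleted flags report the other
// possible outcomes of the same call.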
+ c.Assert(status&ConcurrencyChanged != 0, IsTrue) var lock sync.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup @@ -57,7 +56,11 @@ func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { c.Assert(limit, Equals, uint64(10)) c.Assert(current, Equals, uint64(0)) - limiter.Update(label, UpdateConcurrencyLimiter(5)) + status = limiter.Update(label, UpdateConcurrencyLimiter(10)) + c.Assert(status&ConcurrencyNoChange != 0, IsTrue) + + status = limiter.Update(label, UpdateConcurrencyLimiter(5)) + c.Assert(status&ConcurrencyChanged != 0, IsTrue) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { @@ -71,7 +74,8 @@ func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { limiter.Release(label) } - limiter.DeleteConcurrencyLimiter(label) + status = limiter.Update(label, UpdateConcurrencyLimiter(0)) + c.Assert(status&ConcurrencyDeleted != 0, IsTrue) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { @@ -99,7 +103,8 @@ func (s *testRatelimiterSuite) TestBlockList(c *C) { } c.Assert(limiter.IsInAllowList(label), Equals, true) - UpdateQPSLimiter(rate.Every(time.Second), 1)(label, limiter) + status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) + c.Assert(status&InAllowList != 0, Equals, true) for i := 0; i < 10; i++ { c.Assert(limiter.Allow(label), Equals, true) } @@ -107,13 +112,12 @@ func (s *testRatelimiterSuite) TestBlockList(c *C) { func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { c.Parallel() - opts := []Option{UpdateQPSLimiter(rate.Every(time.Second), 1)} + opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} limiter := NewLimiter() label := "test" - for _, opt := range opts { - opt(label, limiter) - } + status := limiter.Update(label, opts...) + c.Assert(status&QPSChanged != 0, IsTrue) var lock sync.Mutex successCount, failedCount := 0, 0 @@ -130,7 +134,11 @@ func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { c.Assert(limit, Equals, rate.Limit(1)) c.Assert(burst, Equals, 1) - limiter.Update(label, UpdateQPSLimiter(5, 5)) + status = limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)) + c.Assert(status&QPSNoChange != 0, IsTrue) + + status = limiter.Update(label, UpdateQPSLimiter(5, 5)) + c.Assert(status&QPSChanged != 0, IsTrue) limit, burst = limiter.GetQPSLimiterStatus(label) c.Assert(limit, Equals, rate.Limit(5)) c.Assert(burst, Equals, 5) @@ -144,7 +152,9 @@ func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { } } time.Sleep(time.Second) - limiter.DeleteQPSLimiter(label) + + status = limiter.Update(label, UpdateQPSLimiter(0, 0)) + c.Assert(status&QPSDeleted != 0, IsTrue) for i := 0; i < 10; i++ { c.Assert(limiter.Allow(label), Equals, true) } @@ -155,7 +165,7 @@ func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { c.Parallel() - opts := []Option{UpdateQPSLimiter(rate.Every(3*time.Second), 100)} + opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} limiter := NewLimiter() label := "test" @@ -184,9 +194,12 @@ func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { c.Parallel() - opts := []Option{UpdateQPSLimiter(100, 100), - UpdateConcurrencyLimiter(100), + cfg := &DimensionConfig{ + QPS: 100, + QPSBurst: 100, + ConcurrencyLimit: 100, } + opts := []Option{UpdateDimensionConfig(cfg)} limiter := NewLimiter() label := "test" @@ -217,7 +230,7 @@ func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { for i := 0; i < 
100; i++ { limiter.Release(label) } - limiter.Update(label, UpdateQPSLimiter(rate.Every(10*time.Second), 1)) + limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(10*time.Second)), 1)) wg.Add(100) for i := 0; i < 100; i++ { go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) diff --git a/pkg/ratelimit/option.go b/pkg/ratelimit/option.go index af98eddb827..53afb9926d4 100644 --- a/pkg/ratelimit/option.go +++ b/pkg/ratelimit/option.go @@ -16,39 +16,101 @@ package ratelimit import "golang.org/x/time/rate" +// UpdateStatus is flags for updating limiter config. +type UpdateStatus uint32 + +// Flags for limiter. +const ( + eps float64 = 1e-8 + // QPSNoChange shows that limiter's config isn't changed. + QPSNoChange UpdateStatus = 1 << iota + // QPSChanged shows that limiter's config is changed and not deleted. + QPSChanged + // QPSDeleted shows that limiter's config is deleted. + QPSDeleted + // ConcurrencyNoChange shows that limiter's config isn't changed. + ConcurrencyNoChange + // ConcurrencyChanged shows that limiter's config is changed and not deleted. + ConcurrencyChanged + // ConcurrencyDeleted shows that limiter's config is deleted. + ConcurrencyDeleted + // InAllowList shows that limiter's config isn't changed because it is in in allow list. + InAllowList +) + // Option is used to create a limiter with the optional settings. // these setting is used to add a kind of limiter for a service -type Option func(string, *Limiter) +type Option func(string, *Limiter) UpdateStatus // AddLabelAllowList adds a label into allow list. // It means the given label will not be limited func AddLabelAllowList() Option { - return func(label string, l *Limiter) { + return func(label string, l *Limiter) UpdateStatus { l.labelAllowList[label] = struct{}{} + return 0 + } +} + +func updateConcurrencyConfig(l *Limiter, label string, limit uint64) UpdateStatus { + oldConcurrencyLimit, _ := l.GetConcurrencyLimiterStatus(label) + if oldConcurrencyLimit == limit { + return ConcurrencyNoChange + } + if limit < 1 { + l.ConcurrencyUnlimit(label) + return ConcurrencyDeleted + } + if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { + limiter.(*concurrencyLimiter).setLimit(limit) + } + return ConcurrencyChanged +} + +func updateQPSConfig(l *Limiter, label string, limit float64, burst int) UpdateStatus { + oldQPSLimit, oldBurst := l.GetQPSLimiterStatus(label) + + if (float64(oldQPSLimit)-limit < eps && float64(oldQPSLimit)-limit > -eps) && oldBurst == burst { + return QPSNoChange + } + if limit <= eps || burst < 1 { + l.QPSUnlimit(label) + return QPSDeleted } + if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(limit, burst)); exist { + limiter.(*RateLimiter).SetLimit(rate.Limit(limit)) + limiter.(*RateLimiter).SetBurst(burst) + } + return QPSChanged } // UpdateConcurrencyLimiter creates a concurrency limiter for a given label if it doesn't exist. func UpdateConcurrencyLimiter(limit uint64) Option { - return func(label string, l *Limiter) { + return func(label string, l *Limiter) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { - return - } - if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { - limiter.(*concurrencyLimiter).setLimit(limit) + return InAllowList } + return updateConcurrencyConfig(l, label, limit) } } // UpdateQPSLimiter creates a QPS limiter for a given label if it doesn't exist. 
-func UpdateQPSLimiter(limit rate.Limit, burst int) Option { - return func(label string, l *Limiter) { +func UpdateQPSLimiter(limit float64, burst int) Option { + return func(label string, l *Limiter) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { - return + return InAllowList } - if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(float64(limit), burst)); exist { - limiter.(*RateLimiter).SetLimit(limit) - limiter.(*RateLimiter).SetBurst(burst) + return updateQPSConfig(l, label, limit, burst) + } +} + +// UpdateDimensionConfig creates QPS limiter and concurrency limiter for a given label by config if it doesn't exist. +func UpdateDimensionConfig(cfg *DimensionConfig) Option { + return func(label string, l *Limiter) UpdateStatus { + if _, allow := l.labelAllowList[label]; allow { + return InAllowList } + status := updateQPSConfig(l, label, cfg.QPS, cfg.QPSBurst) + status |= updateConcurrencyConfig(l, label, cfg.ConcurrencyLimit) + return status } } From 1fa1d4f55ba5b37acdca7f191dff97a1f966cc71 Mon Sep 17 00:00:00 2001 From: buffer <1045931706@qq.com> Date: Wed, 1 Jun 2022 11:44:27 +0800 Subject: [PATCH 08/82] grafana: add some metrics about bucket (#5069) close tikv/pd#5068 Signed-off-by: bufferflies <1045931706@qq.com> --- metrics/grafana/pd.json | 492 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 484 insertions(+), 8 deletions(-) diff --git a/metrics/grafana/pd.json b/metrics/grafana/pd.json index 32747d85965..1a35f91bddf 100644 --- a/metrics/grafana/pd.json +++ b/metrics/grafana/pd.json @@ -7187,6 +7187,194 @@ "align": false, "alignLevel": null } + }, { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The inner status of balance Hot Region scheduler", + "fill": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 51 + }, + "id": 1458, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "sort": "current", + "sortDesc": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(pd_scheduler_event_count{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=\"$instance\", type=\"balance-hot-region-scheduler\"}[5m])) by (name)", + "format": "time_series", + "intervalFactor": 2, + "legendFormat": "{{name}}", + "metric": "pd_scheduler_event_count", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Balance Hot Region scheduler", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ops", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + },{ + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The inner 
status of split bucket scheduler", + "fill": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 51 + }, + "id": 1459, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "sort": "current", + "sortDesc": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(pd_scheduler_event_count{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=\"$instance\", type=\"split-bucket-scheduler\"}[5m])) by (name)", + "format": "time_series", + "intervalFactor": 2, + "legendFormat": "{{name}}", + "metric": "pd_scheduler_event_count", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Split Bucket scheduler", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ops", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } }, { "aliasColors": {}, @@ -7199,7 +7387,7 @@ "h": 8, "w": 24, "x": 0, - "y": 51 + "y": 59 }, "id": 108, "legend": { @@ -7304,7 +7492,7 @@ "h": 8, "w": 12, "x": 0, - "y": 59 + "y": 67 }, "id": 1424, "interval": null, @@ -7377,7 +7565,7 @@ "h": 8, "w": 12, "x": 12, - "y": 59 + "y": 67 }, "id": 141, "legend": { @@ -7469,7 +7657,7 @@ "h": 8, "w": 12, "x": 0, - "y": 67 + "y": 75 }, "id": 70, "legend": { @@ -7561,7 +7749,7 @@ "h": 8, "w": 12, "x": 12, - "y": 67 + "y": 75 }, "id": 71, "legend": { @@ -7652,7 +7840,7 @@ "h": 8, "w": 12, "x": 0, - "y": 75 + "y": 83 }, "id": 109, "legend": { @@ -7746,7 +7934,7 @@ "h": 8, "w": 12, "x": 12, - "y": 75 + "y": 83 }, "id": 110, "legend": { @@ -10616,7 +10804,7 @@ "h": 8, "w": 12, "x": 12, - "y": 55 + "y": 47 }, "id": 1403, "legend": { @@ -10697,6 +10885,198 @@ "align": false, "alignLevel": null } + }, { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The Interval of TIkv bucket report interval", + "editable": true, + "error": false, + "fill": 0, + "grid": {}, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 55 + }, + "id": 1451, + "legend": { + "alignAsTable": true, + "avg": true, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null as zero", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "histogram_quantile(0.99, sum(rate(pd_server_bucket_report_interval_seconds_bucket{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", store=~\"$store\"}[1m])) by (address, store, le))", + "format": "time_series", + "hide": 
false, + "intervalFactor": 2, + "legendFormat": "{{address}}-store-{{store}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "99% Bucket Report Interval", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "s", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "s", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + },{ + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The State of Bucket Report", + "editable": true, + "error": false, + "fill": 0, + "grid": {}, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 63 + }, + "id": 1452, + "legend": { + "alignAsTable": true, + "avg": true, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null as zero", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(pd_server_bucket_report{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", store=~\"$store\", instance=\"$instance\"}[1m])) by (address, store, type,status)", + "format": "time_series", + "hide": false, + "intervalFactor": 2, + "legendFormat": "{{address}}-store-{{store}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Bucket Report State", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "opm", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "s", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, @@ -11072,6 +11452,102 @@ "title": "Region Heartbeat Interval", "transparent": true, "type": "bargauge" + },{ + "datasource": "${DS_TEST-CLUSTER}", + "fieldConfig": { + "defaults": { + "custom": { + "align": null + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 48 + }, + "id": 1454, + "interval": "", + "options": { + "displayMode": "lcd", + "orientation": "horizontal", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "showUnfilled": true + }, + "pluginVersion": "7.1.5", + "repeatDirection": "h", + "targets": [ + { + "expr": "sum(delta(pd_server_bucket_report_interval_seconds_bucket{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", instance=~\"$instance\"}[1m])) by (le)", + "format": "heatmap", + "hide": false, + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Bucket Report Interval", + "transparent": 
true, + "type": "bargauge" + },{ + "datasource": "${DS_TEST-CLUSTER}", + "fieldConfig": { + "defaults": { + "custom": { + "align": null + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 48 + }, + "id": 1455, + "interval": "", + "options": { + "displayMode": "lcd", + "orientation": "horizontal", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "showUnfilled": true + }, + "pluginVersion": "7.1.5", + "repeatDirection": "h", + "targets": [ + { + "expr": "sum(delta(pd_scheduler_buckets_hot_degree_hist_bucket{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", instance=~\"$instance\"}[1m])) by (le)", + "format": "heatmap", + "hide": false, + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Hot Degree of Bucket", + "transparent": true, + "type": "bargauge" } ], "title": "Heartbeat distribution ", From ce6b35d4c57c645181316fc07c40a47387bc5403 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 1 Jun 2022 16:04:27 +0800 Subject: [PATCH 09/82] client, tests: testify the client tests (#5081) ref tikv/pd#4813 Testify the client tests. Signed-off-by: JmPotato --- client/client_test.go | 73 +-- client/go.mod | 2 +- client/go.sum | 24 +- client/option_test.go | 28 +- client/testutil/testutil.go | 9 +- pkg/testutil/testutil.go | 31 + scripts/check-testing-t.sh | 4 +- tests/client/client_test.go | 1002 ++++++++++++++++--------------- tests/client/client_tls_test.go | 87 ++- tests/client/go.mod | 2 +- tests/cluster.go | 21 + 11 files changed, 652 insertions(+), 631 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 7b9470ace5b..c80b78bb96b 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,38 +16,33 @@ package pd import ( "context" + "reflect" "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/client/testutil" "go.uber.org/goleak" "google.golang.org/grpc" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) 
} -var _ = Suite(&testClientSuite{}) - -type testClientSuite struct{} - -func (s *testClientSuite) TestTsLessEqual(c *C) { - c.Assert(tsLessEqual(9, 9, 9, 9), IsTrue) - c.Assert(tsLessEqual(8, 9, 9, 8), IsTrue) - c.Assert(tsLessEqual(9, 8, 8, 9), IsFalse) - c.Assert(tsLessEqual(9, 8, 9, 6), IsFalse) - c.Assert(tsLessEqual(9, 6, 9, 8), IsTrue) +func TestTsLessEqual(t *testing.T) { + re := require.New(t) + re.True(tsLessEqual(9, 9, 9, 9)) + re.True(tsLessEqual(8, 9, 9, 8)) + re.False(tsLessEqual(9, 8, 8, 9)) + re.False(tsLessEqual(9, 8, 9, 6)) + re.True(tsLessEqual(9, 6, 9, 8)) } -func (s *testClientSuite) TestUpdateURLs(c *C) { +func TestUpdateURLs(t *testing.T) { + re := require.New(t) members := []*pdpb.Member{ {Name: "pd4", ClientUrls: []string{"tmp://pd4"}}, {Name: "pd1", ClientUrls: []string{"tmp://pd1"}}, @@ -63,40 +58,35 @@ func (s *testClientSuite) TestUpdateURLs(c *C) { cli := &baseClient{option: newOption()} cli.urls.Store([]string{}) cli.updateURLs(members[1:]) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) cli.updateURLs(members[1:]) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) cli.updateURLs(members) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]}), cli.GetURLs())) } const testClientURL = "tmp://test.url:5255" -var _ = Suite(&testClientCtxSuite{}) - -type testClientCtxSuite struct{} - -func (s *testClientCtxSuite) TestClientCtx(c *C) { +func TestClientCtx(t *testing.T) { + re := require.New(t) start := time.Now() ctx, cancel := context.WithTimeout(context.TODO(), time.Second*3) defer cancel() _, err := NewClientWithContext(ctx, []string{testClientURL}, SecurityOption{}) - c.Assert(err, NotNil) - c.Assert(time.Since(start), Less, time.Second*5) + re.Error(err) + re.Less(time.Since(start), time.Second*5) } -func (s *testClientCtxSuite) TestClientWithRetry(c *C) { +func TestClientWithRetry(t *testing.T) { + re := require.New(t) start := time.Now() _, err := NewClientWithContext(context.TODO(), []string{testClientURL}, SecurityOption{}, WithMaxErrorRetry(5)) - c.Assert(err, NotNil) - c.Assert(time.Since(start), Less, time.Second*10) + re.Error(err) + re.Less(time.Since(start), time.Second*10) } -var _ = Suite(&testClientDialOptionSuite{}) - -type testClientDialOptionSuite struct{} - -func (s *testClientDialOptionSuite) TestGRPCDialOption(c *C) { +func TestGRPCDialOption(t *testing.T) { + re := require.New(t) start := time.Now() ctx, cancel := context.WithTimeout(context.TODO(), 500*time.Millisecond) defer cancel() @@ -111,15 +101,12 @@ func (s *testClientDialOptionSuite) TestGRPCDialOption(c *C) { cli.urls.Store([]string{testClientURL}) cli.option.gRPCDialOptions = []grpc.DialOption{grpc.WithBlock()} err := cli.updateMember() - c.Assert(err, NotNil) - c.Assert(time.Since(start), Greater, 500*time.Millisecond) + re.Error(err) + re.Greater(time.Since(start), 500*time.Millisecond) } -var _ = Suite(&testTsoRequestSuite{}) - -type testTsoRequestSuite struct{} - -func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { +func TestTsoRequestWait(t *testing.T) { + re := require.New(t) ctx, cancel := 
context.WithCancel(context.Background()) req := &tsoRequest{ done: make(chan error, 1), @@ -130,7 +117,7 @@ func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { } cancel() _, _, err := req.Wait() - c.Assert(errors.Cause(err), Equals, context.Canceled) + re.ErrorIs(errors.Cause(err), context.Canceled) ctx, cancel = context.WithCancel(context.Background()) req = &tsoRequest{ @@ -142,5 +129,5 @@ func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { } cancel() _, _, err = req.Wait() - c.Assert(errors.Cause(err), Equals, context.Canceled) + re.ErrorIs(errors.Cause(err), context.Canceled) } diff --git a/client/go.mod b/client/go.mod index 893cada680d..56380b0b51c 100644 --- a/client/go.mod +++ b/client/go.mod @@ -4,12 +4,12 @@ go 1.16 require ( github.com/opentracing/opentracing-go v1.2.0 - github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee github.com/prometheus/client_golang v1.11.0 + github.com/stretchr/testify v1.7.0 go.uber.org/goleak v1.1.11 go.uber.org/zap v1.20.0 google.golang.org/grpc v1.43.0 diff --git a/client/go.sum b/client/go.sum index 6682bdb2893..90019b9b382 100644 --- a/client/go.sum +++ b/client/go.sum @@ -69,7 +69,6 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.12.1/go.mod h1:8XEsbTttt/W+VvjtQhLACqCisSPWTxCZ7sBRjU6iH9c= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= @@ -84,8 +83,10 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -97,9 +98,6 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod 
h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= -github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= -github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 h1:HVl5539r48eA+uDuX/ziBmQCxzT1pGrzWbKuXT46Bq0= -github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0/go.mod h1:PYMCGwN0JHjoqGr3HrZoD+b8Tgx8bKnArhSq8YVzUMc= github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTmyFqUwr+jcCvpVkK7sumiz+ko5H9eq4= @@ -108,7 +106,6 @@ github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 h1:C3N3itkduZXDZ github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00/go.mod h1:4qGtCB0QK0wBzKtFEGDhxXnSnbQApw1gc9siScUl8ew= github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a h1:TxdHGOFeNa1q1mVv6TgReayf26iI4F8PQUm6RnZ/V/E= github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= -github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee h1:VO2t6IBpfvW34TdtD/G10VvnGqjLic1jzOuHjUb5VqM= github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -137,7 +134,6 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -154,8 +150,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= @@ -163,21 +157,14 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= 
-go.uber.org/multierr v1.4.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.12.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.20.0 h1:N4oPlghZwYG55MlU6LXk/Zp00FVNE9X9wrYO8CEs4lc= go.uber.org/zap v1.20.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -186,7 +173,6 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -244,10 +230,7 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191107010934-f79515f33823/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -290,8 +273,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -307,4 +290,3 @@ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/client/option_test.go b/client/option_test.go index b3d044bbd1b..2a7f7824e12 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -15,43 +15,41 @@ package pd import ( + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/client/testutil" ) -var _ = Suite(&testClientOptionSuite{}) - -type testClientOptionSuite struct{} - -func (s *testClientSuite) TestDynamicOptionChange(c *C) { +func TestDynamicOptionChange(t *testing.T) { + re := require.New(t) o := newOption() // Check the default value setting. - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, defaultMaxTSOBatchWaitInterval) - c.Assert(o.getEnableTSOFollowerProxy(), Equals, defaultEnableTSOFollowerProxy) + re.Equal(defaultMaxTSOBatchWaitInterval, o.getMaxTSOBatchWaitInterval()) + re.Equal(defaultEnableTSOFollowerProxy, o.getEnableTSOFollowerProxy()) // Check the invalid value setting. - c.Assert(o.setMaxTSOBatchWaitInterval(time.Second), NotNil) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, defaultMaxTSOBatchWaitInterval) + re.NotNil(o.setMaxTSOBatchWaitInterval(time.Second)) + re.Equal(defaultMaxTSOBatchWaitInterval, o.getMaxTSOBatchWaitInterval()) expectInterval := time.Millisecond o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectInterval = time.Duration(float64(time.Millisecond) * 0.5) o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectInterval = time.Duration(float64(time.Millisecond) * 1.5) o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectBool := true o.setEnableTSOFollowerProxy(expectBool) // Check the value changing notification. 
- testutil.WaitUntil(c, func() bool { + testutil.WaitUntil(t, func() bool { <-o.enableTSOFollowerProxyCh return true }) - c.Assert(o.getEnableTSOFollowerProxy(), Equals, expectBool) + re.Equal(expectBool, o.getEnableTSOFollowerProxy()) // Check whether any data will be sent to the channel. // It will panic if the test fails. close(o.enableTSOFollowerProxyCh) diff --git a/client/testutil/testutil.go b/client/testutil/testutil.go index 3627566ecfb..095a31ae74a 100644 --- a/client/testutil/testutil.go +++ b/client/testutil/testutil.go @@ -15,9 +15,8 @@ package testutil import ( + "testing" "time" - - "github.com/pingcap/check" ) const ( @@ -45,8 +44,8 @@ func WithSleepInterval(sleep time.Duration) WaitOption { } // WaitUntil repeatedly evaluates f() for a period of time, util it returns true. -func WaitUntil(c *check.C, f func() bool, opts ...WaitOption) { - c.Log("wait start") +func WaitUntil(t *testing.T, f func() bool, opts ...WaitOption) { + t.Log("wait start") option := &WaitOp{ retryTimes: waitMaxRetry, sleepInterval: waitRetrySleep, @@ -60,5 +59,5 @@ func WaitUntil(c *check.C, f func() bool, opts ...WaitOption) { } time.Sleep(option.sleepInterval) } - c.Fatal("wait timeout") + t.Fatal("wait timeout") } diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index c3c917d7b3a..dfb209c648d 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -17,10 +17,12 @@ package testutil import ( "os" "strings" + "testing" "time" "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "google.golang.org/grpc" ) @@ -71,6 +73,26 @@ func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { c.Fatal("wait timeout") } +// WaitUntilWithTestingT repeatedly evaluates f() for a period of time, util it returns true. +// NOTICE: this is a temporary function that we will be used to replace `WaitUntil` later. +func WaitUntilWithTestingT(t *testing.T, f CheckFunc, opts ...WaitOption) { + t.Log("wait start") + option := &WaitOp{ + retryTimes: waitMaxRetry, + sleepInterval: waitRetrySleep, + } + for _, opt := range opts { + opt(option) + } + for i := 0; i < option.retryTimes; i++ { + if f() { + return + } + time.Sleep(option.sleepInterval) + } + t.Fatal("wait timeout") +} + // NewRequestHeader creates a new request header. func NewRequestHeader(clusterID uint64) *pdpb.RequestHeader { return &pdpb.RequestHeader{ @@ -86,6 +108,15 @@ func MustNewGrpcClient(c *check.C, addr string) pdpb.PDClient { return pdpb.NewPDClient(conn) } +// MustNewGrpcClientWithTestify must create a new grpc client. +// NOTICE: this is a temporary function that we will be used to replace `MustNewGrpcClient` later. +func MustNewGrpcClientWithTestify(re *require.Assertions, addr string) pdpb.PDClient { + conn, err := grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) + + re.NoError(err) + return pdpb.NewPDClient(conn) +} + // CleanServer is used to clean data directory. func CleanServer(dataDir string) { // Clean data directory diff --git a/scripts/check-testing-t.sh b/scripts/check-testing-t.sh index 0697a007480..6d107b5a0d1 100755 --- a/scripts/check-testing-t.sh +++ b/scripts/check-testing-t.sh @@ -1,5 +1,7 @@ #!/bin/bash +# TODO: remove this script after migrating all tests to the new test framework. + # Check if there are any packages foget to add `TestingT` when use "github.com/pingcap/check". res=$(diff <(grep -rl --include=\*_test.go "github.com/pingcap/check" . 
| xargs -L 1 dirname | sort -u) \ @@ -13,7 +15,7 @@ fi # Check if there are duplicated `TestingT` in package. -res=$(grep -r --include=\*_test.go "TestingT(" . | cut -f1 | xargs -L 1 dirname | sort | uniq -d) +res=$(grep -r --include=\*_test.go "TestingT(t)" . | cut -f1 | xargs -L 1 dirname | sort | uniq -d) if [ "$res" ]; then echo "following packages may have duplicated TestingT:" diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 8b195eaa587..3afda979c44 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -21,16 +21,18 @@ import ( "fmt" "math" "path" + "reflect" "sort" "sync" "testing" "time" "github.com/gogo/protobuf/proto" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" pd "github.com/tikv/pd/client" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/mock/mockid" @@ -50,30 +52,10 @@ const ( tsoRequestRound = 30 ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&clientTestSuite{}) - -type clientTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clientTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *clientTestSuite) TearDownSuite(c *C) { - s.cancel() -} - type client interface { GetLeaderAddr() string ScheduleCheckLeader() @@ -81,75 +63,81 @@ type client interface { GetAllocatorLeaderURLs() map[string]string } -func (s *clientTestSuite) TestClientLeaderChange(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) +func TestClientLeaderChange(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) var ts1, ts2 uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { p1, l1, err := cli.GetTS(context.TODO()) if err == nil { ts1 = tsoutil.ComposeTS(p1, l1) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(ts1), IsTrue) + re.True(cluster.CheckTSOUnique(ts1)) leader := cluster.GetLeader() - waitLeader(c, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) err = cluster.GetServer(leader).Stop() - c.Assert(err, IsNil) + re.NoError(err) leader = cluster.WaitLeader() - c.Assert(leader, Not(Equals), "") - waitLeader(c, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + re.NotEmpty(leader) + waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) // Check TS won't fall back after leader changed. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { p2, l2, err := cli.GetTS(context.TODO()) if err == nil { ts2 = tsoutil.ComposeTS(p2, l2) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(ts2), IsTrue) - c.Assert(ts1, Less, ts2) + re.True(cluster.CheckTSOUnique(ts2)) + re.Less(ts1, ts2) // Check URL list. 
cli.Close() urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - c.Assert(urls, DeepEquals, endpoints) + re.True(reflect.DeepEqual(endpoints, urls)) } -func (s *clientTestSuite) TestLeaderTransfer(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestLeaderTransfer(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(lastTS), IsTrue) + re.True(cluster.CheckTSOUnique(lastTS)) // Start a goroutine the make sure TS won't fall back. quit := make(chan struct{}) @@ -167,8 +155,8 @@ func (s *clientTestSuite) TestLeaderTransfer(c *C) { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { ts := tsoutil.ComposeTS(physical, logical) - c.Assert(cluster.CheckTSOUnique(ts), IsTrue) - c.Assert(lastTS, Less, ts) + re.True(cluster.CheckTSOUnique(ts)) + re.Less(lastTS, ts) lastTS = ts } time.Sleep(time.Millisecond) @@ -179,69 +167,75 @@ func (s *clientTestSuite) TestLeaderTransfer(c *C) { for i := 0; i < 5; i++ { oldLeaderName := cluster.WaitLeader() err := cluster.GetServer(oldLeaderName).ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) newLeaderName := cluster.WaitLeader() - c.Assert(newLeaderName, Not(Equals), oldLeaderName) + re.NotEqual(oldLeaderName, newLeaderName) } close(quit) wg.Wait() } // More details can be found in this issue: https://github.com/tikv/pd/issues/4884 -func (s *clientTestSuite) TestUpdateAfterResetTSO(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestUpdateAfterResetTSO(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) // Transfer leader to trigger the TSO resetting. - c.Assert(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)")) oldLeaderName := cluster.WaitLeader() err = cluster.GetServer(oldLeaderName).ResignLeader() - c.Assert(err, IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO"), IsNil) + re.NoError(err) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO")) newLeaderName := cluster.WaitLeader() - c.Assert(newLeaderName, Not(Equals), oldLeaderName) + re.NotEqual(oldLeaderName, newLeaderName) // Request a new TSO. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) // Transfer leader back. 
- c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`)) err = cluster.GetServer(newLeaderName).ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) // Should NOT panic here. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp")) } -func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { +func TestTSOAllocatorLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) - cluster.WaitAllLeaders(c, dcLocationConfig) + re.NoError(err) + cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig) var ( testServers = cluster.GetServers() @@ -255,13 +249,13 @@ func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { var allocatorLeaderMap = make(map[string]string) for _, dcLocation := range dcLocationConfig { var pdName string - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { pdName = cluster.WaitAllocatorLeader(dcLocation) return len(pdName) > 0 }) allocatorLeaderMap[dcLocation] = pdName } - cli := setupCli(c, s.ctx, endpoints) + cli := setupCli(re, ctx, endpoints) // Check allocator leaders URL map. 
cli.Close() @@ -270,27 +264,30 @@ func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - c.Assert(urls, DeepEquals, endpoints) + re.True(reflect.DeepEqual(endpoints, urls)) continue } pdName, exist := allocatorLeaderMap[dcLocation] - c.Assert(exist, IsTrue) - c.Assert(len(pdName), Greater, 0) + re.True(exist) + re.Greater(len(pdName), 0) pdURL, exist := endpointsMap[pdName] - c.Assert(exist, IsTrue) - c.Assert(len(pdURL), Greater, 0) - c.Assert(url, Equals, pdURL) + re.True(exist) + re.Greater(len(pdURL), 0) + re.Equal(pdURL, url) } } -func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) +func TestTSOFollowerProxy(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli1 := setupCli(c, s.ctx, endpoints) - cli2 := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli1 := setupCli(re, ctx, endpoints) + cli2 := setupCli(re, ctx, endpoints) cli2.UpdateOption(pd.EnableTSOFollowerProxy, true) var wg sync.WaitGroup @@ -301,15 +298,15 @@ func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli2.GetTS(context.Background()) - c.Assert(err, IsNil) + re.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts // After requesting with the follower proxy, request with the leader directly. physical, logical, err = cli1.GetTS(context.Background()) - c.Assert(err, IsNil) + re.NoError(err) ts = tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts } }() @@ -317,71 +314,79 @@ func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { wg.Wait() } -func (s *clientTestSuite) TestGlobalAndLocalTSO(c *C) { +func TestGlobalAndLocalTSO(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) // Wait for all nodes becoming healthy. 
 	time.Sleep(time.Second * 5)
 
 	// Join a new dc-location
-	pd4, err := cluster.Join(s.ctx, func(conf *config.Config, serverName string) {
+	pd4, err := cluster.Join(ctx, func(conf *config.Config, serverName string) {
 		conf.EnableLocalTSO = true
 		conf.Labels[config.ZoneLabel] = "dc-4"
 	})
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	err = pd4.Run()
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	dcLocationConfig["pd4"] = "dc-4"
 	cluster.CheckClusterDCLocation()
-	cluster.WaitAllLeaders(c, dcLocationConfig)
+	cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig)
 
 	// Test a nonexistent dc-location for Local TSO
 	p, l, err := cli.GetLocalTS(context.TODO(), "nonexistent-dc")
-	c.Assert(p, Equals, int64(0))
-	c.Assert(l, Equals, int64(0))
-	c.Assert(err, NotNil)
-	c.Assert(err, ErrorMatches, ".*unknown dc-location.*")
+	re.Equal(int64(0), p)
+	re.Equal(int64(0), l)
+	re.Error(err)
+	re.Contains(err.Error(), "unknown dc-location")
 	wg := &sync.WaitGroup{}
-	requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli)
+	requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli)
 	// assert global tso after resign leader
-	c.Assert(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`), IsNil)
+	re.Nil(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`))
 	err = cluster.ResignLeader()
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	cluster.WaitLeader()
-	_, _, err = cli.GetTS(s.ctx)
-	c.Assert(err, NotNil)
-	c.Assert(pd.IsLeaderChange(err), IsTrue)
-	_, _, err = cli.GetTS(s.ctx)
-	c.Assert(err, IsNil)
-	c.Assert(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember"), IsNil)
+	_, _, err = cli.GetTS(ctx)
+	re.Error(err)
+	re.True(pd.IsLeaderChange(err))
+	_, _, err = cli.GetTS(ctx)
+	re.NoError(err)
+	re.Nil(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember"))
 
 	// Test the TSO follower proxy while enabling the Local TSO.
 	cli.UpdateOption(pd.EnableTSOFollowerProxy, true)
 	// Sleep a while here to prevent from canceling the ongoing TSO request.
time.Sleep(time.Millisecond * 50) - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) cli.UpdateOption(pd.EnableTSOFollowerProxy, false) time.Sleep(time.Millisecond * 50) - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) } -func requestGlobalAndLocalTSO(c *C, wg *sync.WaitGroup, dcLocationConfig map[string]string, cli pd.Client) { +func requestGlobalAndLocalTSO( + re *require.Assertions, + wg *sync.WaitGroup, + dcLocationConfig map[string]string, + cli pd.Client, +) { for _, dcLocation := range dcLocationConfig { wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -390,131 +395,143 @@ func requestGlobalAndLocalTSO(c *C, wg *sync.WaitGroup, dcLocationConfig map[str var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { globalPhysical1, globalLogical1, err := cli.GetTS(context.TODO()) - c.Assert(err, IsNil) + re.NoError(err) globalTS1 := tsoutil.ComposeTS(globalPhysical1, globalLogical1) localPhysical, localLogical, err := cli.GetLocalTS(context.TODO(), dc) - c.Assert(err, IsNil) + re.NoError(err) localTS := tsoutil.ComposeTS(localPhysical, localLogical) globalPhysical2, globalLogical2, err := cli.GetTS(context.TODO()) - c.Assert(err, IsNil) + re.NoError(err) globalTS2 := tsoutil.ComposeTS(globalPhysical2, globalLogical2) - c.Assert(lastTS, Less, globalTS1) - c.Assert(globalTS1, Less, localTS) - c.Assert(localTS, Less, globalTS2) + re.Less(lastTS, globalTS1) + re.Less(globalTS1, localTS) + re.Less(localTS, globalTS2) lastTS = globalTS2 } - c.Assert(lastTS, Greater, uint64(0)) + re.Greater(lastTS, uint64(0)) }(dcLocation) } } wg.Wait() } -func (s *clientTestSuite) TestCustomTimeout(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestCustomTimeout(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) start := time.Now() - c.Assert(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)")) _, err = cli.GetAllStores(context.TODO()) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/customTimeout"), IsNil) - c.Assert(err, NotNil) - c.Assert(time.Since(start), GreaterEqual, 1*time.Second) - c.Assert(time.Since(start), Less, 2*time.Second) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/customTimeout")) + re.Error(err) + re.GreaterOrEqual(time.Since(start), 1*time.Second) + re.Less(time.Since(start), 2*time.Second) } -func (s *clientTestSuite) TestGetRegionFromFollowerClient(c *C) { +func TestGetRegionFromFollowerClient(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := 
setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) time.Sleep(200 * time.Millisecond) r, err := cli.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) - c.Assert(r, NotNil) + re.NoError(err) + re.NotNil(r) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) time.Sleep(200 * time.Millisecond) r, err = cli.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) - c.Assert(r, NotNil) + re.NoError(err) + re.NotNil(r) } // case 1: unreachable -> normal -func (s *clientTestSuite) TestGetTsoFromFollowerClient1(c *C) { +func TestGetTsoFromFollowerClient1(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - lastTS = checkTS(c, cli, lastTS) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork"), IsNil) + lastTS = checkTS(re, cli, lastTS) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(2 * time.Second) - checkTS(c, cli, lastTS) + checkTS(re, cli, lastTS) } // case 2: unreachable -> leader transfer -> normal -func (s *clientTestSuite) TestGetTsoFromFollowerClient2(c *C) { +func TestGetTsoFromFollowerClient2(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - lastTS = checkTS(c, cli, lastTS) - c.Assert(cluster.GetServer(cluster.GetLeader()).ResignLeader(), IsNil) + lastTS = checkTS(re, cli, lastTS) + 
re.NoError(cluster.GetServer(cluster.GetLeader()).ResignLeader()) cluster.WaitLeader() - lastTS = checkTS(c, cli, lastTS) + lastTS = checkTS(re, cli, lastTS) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(5 * time.Second) - checkTS(c, cli, lastTS) + checkTS(re, cli, lastTS) } -func checkTS(c *C, cli pd.Client, lastTS uint64) uint64 { +func checkTS(re *require.Assertions, cli pd.Client, lastTS uint64) uint64 { for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts } time.Sleep(time.Millisecond) @@ -522,12 +539,12 @@ func checkTS(c *C, cli pd.Client, lastTS uint64) uint64 { return lastTS } -func (s *clientTestSuite) runServer(c *C, cluster *tests.TestCluster) []string { +func runServer(re *require.Assertions, cluster *tests.TestCluster) []string { err := cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) testServers := cluster.GetServers() endpoints := make([]string, 0, len(testServers)) @@ -537,32 +554,80 @@ func (s *clientTestSuite) runServer(c *C, cluster *tests.TestCluster) []string { return endpoints } -func setupCli(c *C, ctx context.Context, endpoints []string, opts ...pd.ClientOption) pd.Client { +func setupCli(re *require.Assertions, ctx context.Context, endpoints []string, opts ...pd.ClientOption) pd.Client { cli, err := pd.NewClientWithContext(ctx, endpoints, pd.SecurityOption{}, opts...) 
- c.Assert(err, IsNil) + re.NoError(err) return cli } -func waitLeader(c *C, cli client, leader string) { - testutil.WaitUntil(c, func() bool { +func waitLeader(t *testing.T, cli client, leader string) { + testutil.WaitUntilWithTestingT(t, func() bool { cli.ScheduleCheckLeader() return cli.GetLeaderAddr() == leader }) } -func (s *clientTestSuite) TestCloseClient(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestConfigTTLAfterTransferLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) + defer cluster.Destroy() + err = cluster.RunInitialServers() + re.NoError(err) + leader := cluster.GetServer(cluster.WaitLeader()) + re.NoError(leader.BootstrapCluster()) + addr := fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=5", leader.GetAddr()) + postData, err := json.Marshal(map[string]interface{}{ + "schedule.max-snapshot-count": 999, + "schedule.enable-location-replacement": false, + "schedule.max-merge-region-size": 999, + "schedule.max-merge-region-keys": 999, + "schedule.scheduler-max-waiting-operator": 999, + "schedule.leader-schedule-limit": 999, + "schedule.region-schedule-limit": 999, + "schedule.hot-region-schedule-limit": 999, + "schedule.replica-schedule-limit": 999, + "schedule.merge-schedule-limit": 999, + }) + re.NoError(err) + resp, err := leader.GetHTTPClient().Post(addr, "application/json", bytes.NewBuffer(postData)) + resp.Body.Close() + re.NoError(err) + time.Sleep(2 * time.Second) + re.NoError(leader.Destroy()) + time.Sleep(2 * time.Second) + leader = cluster.GetServer(cluster.WaitLeader()) + re.NotNil(leader) + options := leader.GetPersistOptions() + re.NotNil(options) + re.Equal(uint64(999), options.GetMaxSnapshotCount()) + re.False(options.IsLocationReplacementEnabled()) + re.Equal(uint64(999), options.GetMaxMergeRegionSize()) + re.Equal(uint64(999), options.GetMaxMergeRegionKeys()) + re.Equal(uint64(999), options.GetSchedulerMaxWaitingOperator()) + re.Equal(uint64(999), options.GetLeaderScheduleLimit()) + re.Equal(uint64(999), options.GetRegionScheduleLimit()) + re.Equal(uint64(999), options.GetHotRegionScheduleLimit()) + re.Equal(uint64(999), options.GetReplicaScheduleLimit()) + re.Equal(uint64(999), options.GetMergeScheduleLimit()) +} + +func TestCloseClient(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) cli.GetTSAsync(context.TODO()) time.Sleep(time.Second) cli.Close() } -var _ = Suite(&testClientSuite{}) - type idAllocator struct { allocator *mockid.IDAllocator } @@ -605,7 +670,8 @@ var ( } ) -type testClientSuite struct { +type clientTestSuite struct { + suite.Suite cleanup server.CleanupFunc ctx context.Context clean context.CancelFunc @@ -617,38 +683,34 @@ type testClientSuite struct { reportBucket pdpb.PD_ReportBucketsClient } -func checkerWithNilAssert(c *C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) - checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, IsNil) - } - return checker +func TestClientTestSuite(t *testing.T) { + suite.Run(t, new(clientTestSuite)) } -func (s *testClientSuite) SetUpSuite(c *C) { +func (suite *clientTestSuite) SetupSuite() { var err 
error - s.srv, s.cleanup, err = server.NewTestServer(checkerWithNilAssert(c)) - c.Assert(err, IsNil) - s.grpcPDClient = testutil.MustNewGrpcClient(c, s.srv.GetAddr()) - s.grpcSvr = &server.GrpcServer{Server: s.srv} - - mustWaitLeader(c, map[string]*server.Server{s.srv.GetAddr(): s.srv}) - bootstrapServer(c, newHeader(s.srv), s.grpcPDClient) - - s.ctx, s.clean = context.WithCancel(context.Background()) - s.client = setupCli(c, s.ctx, s.srv.GetEndpoints()) - - c.Assert(err, IsNil) - s.regionHeartbeat, err = s.grpcPDClient.RegionHeartbeat(s.ctx) - c.Assert(err, IsNil) - s.reportBucket, err = s.grpcPDClient.ReportBuckets(s.ctx) - c.Assert(err, IsNil) - cluster := s.srv.GetRaftCluster() - c.Assert(cluster, NotNil) + re := suite.Require() + suite.srv, suite.cleanup, err = server.NewTestServer(suite.checkerWithNilAssert()) + suite.NoError(err) + suite.grpcPDClient = testutil.MustNewGrpcClientWithTestify(re, suite.srv.GetAddr()) + suite.grpcSvr = &server.GrpcServer{Server: suite.srv} + + suite.mustWaitLeader(map[string]*server.Server{suite.srv.GetAddr(): suite.srv}) + suite.bootstrapServer(newHeader(suite.srv), suite.grpcPDClient) + + suite.ctx, suite.clean = context.WithCancel(context.Background()) + suite.client = setupCli(re, suite.ctx, suite.srv.GetEndpoints()) + + suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx) + suite.NoError(err) + suite.reportBucket, err = suite.grpcPDClient.ReportBuckets(suite.ctx) + suite.NoError(err) + cluster := suite.srv.GetRaftCluster() + suite.NotNil(cluster) now := time.Now().UnixNano() for _, store := range stores { - s.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ - Header: newHeader(s.srv), + suite.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ + Header: newHeader(suite.srv), Store: &metapb.Store{ Id: store.Id, Address: store.Address, @@ -660,13 +722,23 @@ func (s *testClientSuite) SetUpSuite(c *C) { config.EnableRegionBucket = true } -func (s *testClientSuite) TearDownSuite(c *C) { - s.client.Close() - s.clean() - s.cleanup() +func (suite *clientTestSuite) TearDownSuite() { + suite.client.Close() + suite.clean() + suite.cleanup() } -func mustWaitLeader(c *C, svrs map[string]*server.Server) *server.Server { +func (suite *clientTestSuite) checkerWithNilAssert() *assertutil.Checker { + checker := assertutil.NewChecker(func() { + suite.FailNow("should be nil") + }) + checker.IsNil = func(obtained interface{}) { + suite.Nil(obtained) + } + return checker +} + +func (suite *clientTestSuite) mustWaitLeader(svrs map[string]*server.Server) *server.Server { for i := 0; i < 500; i++ { for _, s := range svrs { if !s.IsClosed() && s.GetMember().IsLeader() { @@ -675,7 +747,7 @@ func mustWaitLeader(c *C, svrs map[string]*server.Server) *server.Server { } time.Sleep(100 * time.Millisecond) } - c.Fatal("no leader") + suite.FailNow("no leader") return nil } @@ -685,7 +757,7 @@ func newHeader(srv *server.Server) *pdpb.RequestHeader { } } -func bootstrapServer(c *C, header *pdpb.RequestHeader, client pdpb.PDClient) { +func (suite *clientTestSuite) bootstrapServer(header *pdpb.RequestHeader, client pdpb.PDClient) { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -701,10 +773,10 @@ func bootstrapServer(c *C, header *pdpb.RequestHeader, client pdpb.PDClient) { Region: region, } _, err := client.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) + suite.NoError(err) } -func (s *testClientSuite) TestNormalTSO(c *C) { +func (suite *clientTestSuite) TestNormalTSO() { var wg sync.WaitGroup 
wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -712,10 +784,10 @@ func (s *testClientSuite) TestNormalTSO(c *C) { defer wg.Done() var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { - physical, logical, err := s.client.GetTS(context.Background()) - c.Assert(err, IsNil) + physical, logical, err := suite.client.GetTS(context.Background()) + suite.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + suite.Less(lastTS, ts) lastTS = ts } }() @@ -723,7 +795,7 @@ func (s *testClientSuite) TestNormalTSO(c *C) { wg.Wait() } -func (s *testClientSuite) TestGetTSAsync(c *C) { +func (suite *clientTestSuite) TestGetTSAsync() { var wg sync.WaitGroup wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -731,14 +803,14 @@ func (s *testClientSuite) TestGetTSAsync(c *C) { defer wg.Done() tsFutures := make([]pd.TSFuture, tsoRequestRound) for i := range tsFutures { - tsFutures[i] = s.client.GetTSAsync(context.Background()) + tsFutures[i] = suite.client.GetTSAsync(context.Background()) } var lastTS uint64 = math.MaxUint64 for i := len(tsFutures) - 1; i >= 0; i-- { physical, logical, err := tsFutures[i].Wait() - c.Assert(err, IsNil) + suite.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Greater, ts) + suite.Greater(lastTS, ts) lastTS = ts } }() @@ -746,7 +818,7 @@ func (s *testClientSuite) TestGetTSAsync(c *C) { wg.Wait() } -func (s *testClientSuite) TestGetRegion(c *C) { +func (suite *clientTestSuite) TestGetRegion() { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -757,24 +829,25 @@ func (s *testClientSuite) TestGetRegion(c *C) { Peers: peers, } req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: region, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a")) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Meta, DeepEquals, region) && - c.Check(r.Leader, DeepEquals, peers[0]) && - c.Check(r.Buckets, IsNil) + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil }) breq := &pdpb.ReportBucketsRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Buckets: &metapb.Buckets{ RegionId: regionID, Version: 1, @@ -790,30 +863,29 @@ func (s *testClientSuite) TestGetRegion(c *C) { }, }, } - c.Assert(s.reportBucket.Send(breq), IsNil) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) - c.Assert(err, IsNil) + suite.NoError(suite.reportBucket.Send(breq)) + testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Buckets, NotNil) + return r.Buckets != nil }) - config := s.srv.GetRaftCluster().GetStoreConfig() + config := suite.srv.GetRaftCluster().GetStoreConfig() config.EnableRegionBucket = false - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) - c.Assert(err, IsNil) + 
testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Buckets, IsNil) + return r.Buckets == nil }) config.EnableRegionBucket = true - c.Succeed() } -func (s *testClientSuite) TestGetPrevRegion(c *C) { +func (suite *clientTestSuite) TestGetPrevRegion() { regionLen := 10 regions := make([]*metapb.Region, 0, regionLen) for i := 0; i < regionLen; i++ { @@ -830,29 +902,28 @@ func (s *testClientSuite) TestGetPrevRegion(c *C) { } regions = append(regions, r) req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: r, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) } time.Sleep(500 * time.Millisecond) for i := 0; i < 20; i++ { - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetPrevRegion(context.Background(), []byte{byte(i)}) - c.Assert(err, IsNil) + testutil.WaitUntilWithTestingT(suite.T(), func() bool { + r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) + suite.NoError(err) if i > 0 && i < regionLen { - return c.Check(r.Leader, DeepEquals, peers[0]) && - c.Check(r.Meta, DeepEquals, regions[i-1]) + return reflect.DeepEqual(peers[0], r.Leader) && + reflect.DeepEqual(regions[i-1], r.Meta) } - return c.Check(r, IsNil) + return r == nil }) } - c.Succeed() } -func (s *testClientSuite) TestScanRegions(c *C) { +func (suite *clientTestSuite) TestScanRegions() { regionLen := 10 regions := make([]*metapb.Region, 0, regionLen) for i := 0; i < regionLen; i++ { @@ -869,53 +940,54 @@ func (s *testClientSuite) TestScanRegions(c *C) { } regions = append(regions, r) req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: r, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) } // Wait for region heartbeats. - testutil.WaitUntil(c, func() bool { - scanRegions, err := s.client.ScanRegions(context.Background(), []byte{0}, nil, 10) + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + scanRegions, err := suite.client.ScanRegions(context.Background(), []byte{0}, nil, 10) return err == nil && len(scanRegions) == 10 }) // Set leader of region3 to nil. region3 := core.NewRegionInfo(regions[3], nil) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region3) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region3) // Add down peer for region4. region4 := core.NewRegionInfo(regions[4], regions[4].Peers[0], core.WithDownPeers([]*pdpb.PeerStats{{Peer: regions[4].Peers[1]}})) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region4) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region4) // Add pending peers for region5. 
region5 := core.NewRegionInfo(regions[5], regions[5].Peers[0], core.WithPendingPeers([]*metapb.Peer{regions[5].Peers[1], regions[5].Peers[2]})) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region5) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region5) check := func(start, end []byte, limit int, expect []*metapb.Region) { - scanRegions, err := s.client.ScanRegions(context.Background(), start, end, limit) - c.Assert(err, IsNil) - c.Assert(scanRegions, HasLen, len(expect)) - c.Log("scanRegions", scanRegions) - c.Log("expect", expect) + scanRegions, err := suite.client.ScanRegions(context.Background(), start, end, limit) + suite.NoError(err) + suite.Len(scanRegions, len(expect)) + t.Log("scanRegions", scanRegions) + t.Log("expect", expect) for i := range expect { - c.Assert(scanRegions[i].Meta, DeepEquals, expect[i]) + suite.True(reflect.DeepEqual(expect[i], scanRegions[i].Meta)) if scanRegions[i].Meta.GetId() == region3.GetID() { - c.Assert(scanRegions[i].Leader, DeepEquals, &metapb.Peer{}) + suite.True(reflect.DeepEqual(&metapb.Peer{}, scanRegions[i].Leader)) } else { - c.Assert(scanRegions[i].Leader, DeepEquals, expect[i].Peers[0]) + suite.True(reflect.DeepEqual(expect[i].Peers[0], scanRegions[i].Leader)) } if scanRegions[i].Meta.GetId() == region4.GetID() { - c.Assert(scanRegions[i].DownPeers, DeepEquals, []*metapb.Peer{expect[i].Peers[1]}) + suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1]}, scanRegions[i].DownPeers)) } if scanRegions[i].Meta.GetId() == region5.GetID() { - c.Assert(scanRegions[i].PendingPeers, DeepEquals, []*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]}) + suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]}, scanRegions[i].PendingPeers)) } } } @@ -927,7 +999,7 @@ func (s *testClientSuite) TestScanRegions(c *C) { check([]byte{1}, []byte{6}, 2, regions[1:3]) } -func (s *testClientSuite) TestGetRegionByID(c *C) { +func (suite *clientTestSuite) TestGetRegionByID() { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -938,125 +1010,125 @@ func (s *testClientSuite) TestGetRegionByID(c *C) { Peers: peers, } req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: region, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegionByID(context.Background(), regionID) - c.Assert(err, IsNil) + testutil.WaitUntilWithTestingT(suite.T(), func() bool { + r, err := suite.client.GetRegionByID(context.Background(), regionID) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Meta, DeepEquals, region) && - c.Check(r.Leader, DeepEquals, peers[0]) + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) }) - c.Succeed() } -func (s *testClientSuite) TestGetStore(c *C) { - cluster := s.srv.GetRaftCluster() - c.Assert(cluster, NotNil) +func (suite *clientTestSuite) TestGetStore() { + cluster := suite.srv.GetRaftCluster() + suite.NotNil(cluster) store := stores[0] // Get an up store should be OK. 
- n, err := s.client.GetStore(context.Background(), store.GetId()) - c.Assert(err, IsNil) - c.Assert(n, DeepEquals, store) + n, err := suite.client.GetStore(context.Background(), store.GetId()) + suite.NoError(err) + suite.True(reflect.DeepEqual(store, n)) - stores, err := s.client.GetAllStores(context.Background()) - c.Assert(err, IsNil) - c.Assert(stores, DeepEquals, stores) + actualStores, err := suite.client.GetAllStores(context.Background()) + suite.NoError(err) + suite.Len(actualStores, len(stores)) + stores = actualStores // Mark the store as offline. err = cluster.RemoveStore(store.GetId(), false) - c.Assert(err, IsNil) + suite.NoError(err) offlineStore := proto.Clone(store).(*metapb.Store) offlineStore.State = metapb.StoreState_Offline offlineStore.NodeState = metapb.NodeState_Removing // Get an offline store should be OK. - n, err = s.client.GetStore(context.Background(), store.GetId()) - c.Assert(err, IsNil) - c.Assert(n, DeepEquals, offlineStore) + n, err = suite.client.GetStore(context.Background(), store.GetId()) + suite.NoError(err) + suite.True(reflect.DeepEqual(offlineStore, n)) // Should return offline stores. contains := false - stores, err = s.client.GetAllStores(context.Background()) - c.Assert(err, IsNil) + stores, err = suite.client.GetAllStores(context.Background()) + suite.NoError(err) for _, store := range stores { if store.GetId() == offlineStore.GetId() { contains = true - c.Assert(store, DeepEquals, offlineStore) + suite.True(reflect.DeepEqual(offlineStore, store)) } } - c.Assert(contains, IsTrue) + suite.True(contains) // Mark the store as physically destroyed and offline. err = cluster.RemoveStore(store.GetId(), true) - c.Assert(err, IsNil) + suite.NoError(err) physicallyDestroyedStoreID := store.GetId() // Get a physically destroyed and offline store // It should be Tombstone(become Tombstone automically) or Offline - n, err = s.client.GetStore(context.Background(), physicallyDestroyedStoreID) - c.Assert(err, IsNil) + n, err = suite.client.GetStore(context.Background(), physicallyDestroyedStoreID) + suite.NoError(err) if n != nil { // store is still offline and physically destroyed - c.Assert(n.GetNodeState(), Equals, metapb.NodeState_Removing) - c.Assert(n.PhysicallyDestroyed, IsTrue) + suite.Equal(metapb.NodeState_Removing, n.GetNodeState()) + suite.True(n.PhysicallyDestroyed) } // Should return tombstone stores. contains = false - stores, err = s.client.GetAllStores(context.Background()) - c.Assert(err, IsNil) + stores, err = suite.client.GetAllStores(context.Background()) + suite.NoError(err) for _, store := range stores { if store.GetId() == physicallyDestroyedStoreID { contains = true - c.Assert(store.GetState(), Not(Equals), metapb.StoreState_Up) - c.Assert(store.PhysicallyDestroyed, IsTrue) + suite.NotEqual(metapb.StoreState_Up, store.GetState()) + suite.True(store.PhysicallyDestroyed) } } - c.Assert(contains, IsTrue) + suite.True(contains) // Should not return tombstone stores. 
- stores, err = s.client.GetAllStores(context.Background(), pd.WithExcludeTombstone()) - c.Assert(err, IsNil) + stores, err = suite.client.GetAllStores(context.Background(), pd.WithExcludeTombstone()) + suite.NoError(err) for _, store := range stores { if store.GetId() == physicallyDestroyedStoreID { - c.Assert(store.GetState(), Equals, metapb.StoreState_Offline) - c.Assert(store.PhysicallyDestroyed, IsTrue) + suite.Equal(metapb.StoreState_Offline, store.GetState()) + suite.True(store.PhysicallyDestroyed) } } } -func (s *testClientSuite) checkGCSafePoint(c *C, expectedSafePoint uint64) { +func (suite *clientTestSuite) checkGCSafePoint(expectedSafePoint uint64) { req := &pdpb.GetGCSafePointRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), } - resp, err := s.grpcSvr.GetGCSafePoint(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.SafePoint, Equals, expectedSafePoint) + resp, err := suite.grpcSvr.GetGCSafePoint(context.Background(), req) + suite.NoError(err) + suite.Equal(expectedSafePoint, resp.SafePoint) } -func (s *testClientSuite) TestUpdateGCSafePoint(c *C) { - s.checkGCSafePoint(c, 0) +func (suite *clientTestSuite) TestUpdateGCSafePoint() { + suite.checkGCSafePoint(0) for _, safePoint := range []uint64{0, 1, 2, 3, 233, 23333, 233333333333, math.MaxUint64} { - newSafePoint, err := s.client.UpdateGCSafePoint(context.Background(), safePoint) - c.Assert(err, IsNil) - c.Assert(newSafePoint, Equals, safePoint) - s.checkGCSafePoint(c, safePoint) + newSafePoint, err := suite.client.UpdateGCSafePoint(context.Background(), safePoint) + suite.NoError(err) + suite.Equal(safePoint, newSafePoint) + suite.checkGCSafePoint(safePoint) } // If the new safe point is less than the old one, it should not be updated. - newSafePoint, err := s.client.UpdateGCSafePoint(context.Background(), 1) - c.Assert(newSafePoint, Equals, uint64(math.MaxUint64)) - c.Assert(err, IsNil) - s.checkGCSafePoint(c, math.MaxUint64) + newSafePoint, err := suite.client.UpdateGCSafePoint(context.Background(), 1) + suite.Equal(uint64(math.MaxUint64), newSafePoint) + suite.NoError(err) + suite.checkGCSafePoint(math.MaxUint64) } -func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) { +func (suite *clientTestSuite) TestUpdateServiceGCSafePoint() { serviceSafePoints := []struct { ServiceID string TTL int64 @@ -1067,105 +1139,105 @@ func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) { {"c", 1000, 3}, } for _, ssp := range serviceSafePoints { - min, err := s.client.UpdateServiceGCSafePoint(context.Background(), + min, err := suite.client.UpdateServiceGCSafePoint(context.Background(), ssp.ServiceID, 1000, ssp.SafePoint) - c.Assert(err, IsNil) + suite.NoError(err) // An service safepoint of ID "gc_worker" is automatically initialized as 0 - c.Assert(min, Equals, uint64(0)) + suite.Equal(uint64(0), min) } - min, err := s.client.UpdateServiceGCSafePoint(context.Background(), + min, err := suite.client.UpdateServiceGCSafePoint(context.Background(), "gc_worker", math.MaxInt64, 10) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(1)) + suite.NoError(err) + suite.Equal(uint64(1), min) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", 1000, 4) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(2)) + suite.NoError(err) + suite.Equal(uint64(2), min) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "b", 
-100, 2) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) + suite.NoError(err) + suite.Equal(uint64(3), min) // Minimum safepoint does not regress - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "b", 1000, 2) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) + suite.NoError(err) + suite.Equal(uint64(3), min) // Update only the TTL of the minimum safepoint - oldMinSsp, err := s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(oldMinSsp.ServiceID, Equals, "c") - c.Assert(oldMinSsp.SafePoint, Equals, uint64(3)) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + oldMinSsp, err := suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", oldMinSsp.ServiceID) + suite.Equal(uint64(3), oldMinSsp.SafePoint) + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", 2000, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) - minSsp, err := s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "c") - c.Assert(oldMinSsp.SafePoint, Equals, uint64(3)) - c.Assert(minSsp.ExpiredAt-oldMinSsp.ExpiredAt, GreaterEqual, int64(1000)) + suite.NoError(err) + suite.Equal(uint64(3), min) + minSsp, err := suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", minSsp.ServiceID) + suite.Equal(uint64(3), oldMinSsp.SafePoint) + suite.GreaterOrEqual(minSsp.ExpiredAt-oldMinSsp.ExpiredAt, int64(1000)) // Shrinking TTL is also allowed - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", 1, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "c") - c.Assert(minSsp.ExpiredAt, Less, oldMinSsp.ExpiredAt) + suite.NoError(err) + suite.Equal(uint64(3), min) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", minSsp.ServiceID) + suite.Less(minSsp.ExpiredAt, oldMinSsp.ExpiredAt) // TTL can be infinite (represented by math.MaxInt64) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", math.MaxInt64, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "c") - c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64)) + suite.NoError(err) + suite.Equal(uint64(3), min) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", minSsp.ServiceID) + suite.Equal(minSsp.ExpiredAt, int64(math.MaxInt64)) // Delete "a" and "c" - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", -1, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(4)) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + suite.NoError(err) + suite.Equal(uint64(4), min) + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", -1, 4) - c.Assert(err, IsNil) + suite.NoError(err) // Now gc_worker is the only remaining 
service safe point. - c.Assert(min, Equals, uint64(10)) + suite.Equal(uint64(10), min) // gc_worker cannot be deleted. - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "gc_worker", -1, 10) - c.Assert(err, NotNil) + suite.Error(err) // Cannot set non-infinity TTL for gc_worker - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "gc_worker", 10000000, 10) - c.Assert(err, NotNil) + suite.Error(err) // Service safepoint must have a non-empty ID - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "", 1000, 15) - c.Assert(err, NotNil) + suite.Error(err) // Put some other safepoints to test fixing gc_worker's safepoint when there exists other safepoints. - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", 1000, 11) - c.Assert(err, IsNil) - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + suite.NoError(err) + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "b", 1000, 12) - c.Assert(err, IsNil) - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + suite.NoError(err) + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", 1000, 13) - c.Assert(err, IsNil) + suite.NoError(err) // Force set invalid ttl to gc_worker gcWorkerKey := path.Join("gc", "safe_point", "service", "gc_worker") @@ -1176,38 +1248,38 @@ func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) { SafePoint: 10, } value, err := json.Marshal(gcWorkerSsp) - c.Assert(err, IsNil) - err = s.srv.GetStorage().Save(gcWorkerKey, string(value)) - c.Assert(err, IsNil) + suite.NoError(err) + err = suite.srv.GetStorage().Save(gcWorkerKey, string(value)) + suite.NoError(err) } - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "gc_worker") - c.Assert(minSsp.SafePoint, Equals, uint64(10)) - c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64)) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("gc_worker", minSsp.ServiceID) + suite.Equal(uint64(10), minSsp.SafePoint) + suite.Equal(int64(math.MaxInt64), minSsp.ExpiredAt) // Force delete gc_worker, then the min service safepoint is 11 of "a". - err = s.srv.GetStorage().Remove(gcWorkerKey) - c.Assert(err, IsNil) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.SafePoint, Equals, uint64(11)) + err = suite.srv.GetStorage().Remove(gcWorkerKey) + suite.NoError(err) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal(uint64(11), minSsp.SafePoint) // After calling LoadMinServiceGCS when "gc_worker"'s service safepoint is missing, "gc_worker"'s service safepoint // will be newly created. // Increase "a" so that "gc_worker" is the only minimum that will be returned by LoadMinServiceGCSafePoint. 
- _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", 1000, 14) - c.Assert(err, IsNil) + suite.NoError(err) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "gc_worker") - c.Assert(minSsp.SafePoint, Equals, uint64(11)) - c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64)) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("gc_worker", minSsp.ServiceID) + suite.Equal(uint64(11), minSsp.SafePoint) + suite.Equal(int64(math.MaxInt64), minSsp.ExpiredAt) } -func (s *testClientSuite) TestScatterRegion(c *C) { +func (suite *clientTestSuite) TestScatterRegion() { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -1220,106 +1292,46 @@ func (s *testClientSuite) TestScatterRegion(c *C) { EndKey: []byte("ggg"), } req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: region, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) + err := suite.regionHeartbeat.Send(req) regionsID := []uint64{regionID} - c.Assert(err, IsNil) + suite.NoError(err) // Test interface `ScatterRegions`. - testutil.WaitUntil(c, func() bool { - scatterResp, err := s.client.ScatterRegions(context.Background(), regionsID, pd.WithGroup("test"), pd.WithRetry(1)) - if c.Check(err, NotNil) { + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + scatterResp, err := suite.client.ScatterRegions(context.Background(), regionsID, pd.WithGroup("test"), pd.WithRetry(1)) + if err != nil { return false } - if c.Check(scatterResp.FinishedPercentage, Not(Equals), uint64(100)) { + if scatterResp.FinishedPercentage != uint64(100) { return false } - resp, err := s.client.GetOperator(context.Background(), regionID) - if c.Check(err, NotNil) { + resp, err := suite.client.GetOperator(context.Background(), regionID) + if err != nil { return false } - return c.Check(resp.GetRegionId(), Equals, regionID) && c.Check(string(resp.GetDesc()), Equals, "scatter-region") && c.Check(resp.GetStatus(), Equals, pdpb.OperatorStatus_RUNNING) + return resp.GetRegionId() == regionID && + string(resp.GetDesc()) == "scatter-region" && + resp.GetStatus() == pdpb.OperatorStatus_RUNNING }, testutil.WithSleepInterval(1*time.Second)) // Test interface `ScatterRegion`. // TODO: Deprecate interface `ScatterRegion`. 
- testutil.WaitUntil(c, func() bool { - err := s.client.ScatterRegion(context.Background(), regionID) - if c.Check(err, NotNil) { + testutil.WaitUntilWithTestingT(t, func() bool { + err := suite.client.ScatterRegion(context.Background(), regionID) + if err != nil { fmt.Println(err) return false } - resp, err := s.client.GetOperator(context.Background(), regionID) - if c.Check(err, NotNil) { + resp, err := suite.client.GetOperator(context.Background(), regionID) + if err != nil { return false } - return c.Check(resp.GetRegionId(), Equals, regionID) && c.Check(string(resp.GetDesc()), Equals, "scatter-region") && c.Check(resp.GetStatus(), Equals, pdpb.OperatorStatus_RUNNING) + return resp.GetRegionId() == regionID && + string(resp.GetDesc()) == "scatter-region" && + resp.GetStatus() == pdpb.OperatorStatus_RUNNING }, testutil.WithSleepInterval(1*time.Second)) - - c.Succeed() -} - -type testConfigTTLSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testConfigTTLSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *testConfigTTLSuite) TearDownSuite(c *C) { - s.cancel() -} - -var _ = SerialSuites(&testConfigTTLSuite{}) - -var ttlConfig = map[string]interface{}{ - "schedule.max-snapshot-count": 999, - "schedule.enable-location-replacement": false, - "schedule.max-merge-region-size": 999, - "schedule.max-merge-region-keys": 999, - "schedule.scheduler-max-waiting-operator": 999, - "schedule.leader-schedule-limit": 999, - "schedule.region-schedule-limit": 999, - "schedule.hot-region-schedule-limit": 999, - "schedule.replica-schedule-limit": 999, - "schedule.merge-schedule-limit": 999, -} - -func assertTTLConfig(c *C, options *config.PersistOptions, checker Checker) { - c.Assert(options.GetMaxSnapshotCount(), checker, uint64(999)) - c.Assert(options.IsLocationReplacementEnabled(), checker, false) - c.Assert(options.GetMaxMergeRegionSize(), checker, uint64(999)) - c.Assert(options.GetMaxMergeRegionKeys(), checker, uint64(999)) - c.Assert(options.GetSchedulerMaxWaitingOperator(), checker, uint64(999)) - c.Assert(options.GetLeaderScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetRegionScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetHotRegionScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetReplicaScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetMergeScheduleLimit(), checker, uint64(999)) -} - -func (s *testConfigTTLSuite) TestConfigTTLAfterTransferLeader(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) - defer cluster.Destroy() - err = cluster.RunInitialServers() - c.Assert(err, IsNil) - leader := cluster.GetServer(cluster.WaitLeader()) - c.Assert(leader.BootstrapCluster(), IsNil) - addr := fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=5", leader.GetAddr()) - postData, err := json.Marshal(ttlConfig) - c.Assert(err, IsNil) - resp, err := leader.GetHTTPClient().Post(addr, "application/json", bytes.NewBuffer(postData)) - resp.Body.Close() - c.Assert(err, IsNil) - time.Sleep(2 * time.Second) - _ = leader.Destroy() - time.Sleep(2 * time.Second) - leader = cluster.GetServer(cluster.WaitLeader()) - assertTTLConfig(c, leader.GetPersistOptions(), Equals) } diff --git a/tests/client/client_tls_test.go b/tests/client/client_tls_test.go index 3fbca17c835..48a6fec3d2d 100644 --- a/tests/client/client_tls_test.go +++ b/tests/client/client_tls_test.go @@ -22,21 +22,19 @@ import ( "os" "path/filepath" "strings" + "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" pd "github.com/tikv/pd/client" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/netutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" "go.etcd.io/etcd/pkg/transport" "google.golang.org/grpc" ) -var _ = Suite(&clientTLSTestSuite{}) - var ( testTLSInfo = transport.TLSInfo{ KeyFile: "./cert/pd-server-key.pem", @@ -57,49 +55,38 @@ var ( } ) -type clientTLSTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clientTLSTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *clientTLSTestSuite) TearDownSuite(c *C) { - s.cancel() -} - // TestTLSReloadAtomicReplace ensures server reloads expired/valid certs // when all certs are atomically replaced by directory renaming. // And expects server to reject client requests, and vice versa. -func (s *clientTLSTestSuite) TestTLSReloadAtomicReplace(c *C) { +func TestTLSReloadAtomicReplace(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() tmpDir, err := os.MkdirTemp(os.TempDir(), "cert-tmp") - c.Assert(err, IsNil) + re.NoError(err) os.RemoveAll(tmpDir) defer os.RemoveAll(tmpDir) certsDir, err := os.MkdirTemp(os.TempDir(), "cert-to-load") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(certsDir) certsDirExp, err := os.MkdirTemp(os.TempDir(), "cert-expired") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(certsDirExp) cloneFunc := func() transport.TLSInfo { tlsInfo, terr := copyTLSFiles(testTLSInfo, certsDir) - c.Assert(terr, IsNil) + re.NoError(terr) _, err = copyTLSFiles(testTLSInfoExpired, certsDirExp) - c.Assert(err, IsNil) + re.NoError(err) return tlsInfo } replaceFunc := func() { err = os.Rename(certsDir, tmpDir) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDirExp, certsDir) - c.Assert(err, IsNil) + re.NoError(err) // after rename, // 'certsDir' contains expired certs // 'tmpDir' contains valid certs @@ -107,25 +94,26 @@ func (s *clientTLSTestSuite) TestTLSReloadAtomicReplace(c *C) { } revertFunc := func() { err = os.Rename(tmpDir, certsDirExp) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDir, tmpDir) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDirExp, certsDir) - c.Assert(err, IsNil) + re.NoError(err) } - s.testTLSReload(c, cloneFunc, replaceFunc, revertFunc) + testTLSReload(re, ctx, cloneFunc, replaceFunc, revertFunc) } -func (s *clientTLSTestSuite) testTLSReload( - c *C, +func testTLSReload( + re *require.Assertions, + ctx context.Context, cloneFunc func() transport.TLSInfo, replaceFunc func(), revertFunc func()) { tlsInfo := cloneFunc() // 1. 
start cluster with valid certs - clus, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { + clus, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.Security.TLSConfig = grpcutil.TLSConfig{ KeyPath: tlsInfo.KeyFile, CertPath: tlsInfo.CertFile, @@ -137,10 +125,10 @@ func (s *clientTLSTestSuite) testTLSReload( conf.PeerUrls = strings.ReplaceAll(conf.PeerUrls, "http", "https") conf.InitialCluster = strings.ReplaceAll(conf.InitialCluster, "http", "https") }) - c.Assert(err, IsNil) + re.NoError(err) defer clus.Destroy() err = clus.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) clus.WaitLeader() testServers := clus.GetServers() @@ -148,20 +136,20 @@ func (s *clientTLSTestSuite) testTLSReload( for _, s := range testServers { endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls) tlsConfig, err := s.GetConfig().Security.ToTLSConfig() - c.Assert(err, IsNil) + re.NoError(err) httpClient := &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, TLSClientConfig: tlsConfig, }, } - c.Assert(netutil.IsEnableHTTPS(httpClient), IsTrue) + re.True(netutil.IsEnableHTTPS(httpClient)) } // 2. concurrent client dialing while certs become expired errc := make(chan error, 1) go func() { for { - dctx, dcancel := context.WithTimeout(s.ctx, time.Second) + dctx, dcancel := context.WithTimeout(ctx, time.Second) cli, err := pd.NewClientWithContext(dctx, endpoints, pd.SecurityOption{ CAPath: testClientTLSInfo.TrustedCAFile, CertPath: testClientTLSInfo.CertFile, @@ -183,46 +171,46 @@ func (s *clientTLSTestSuite) testTLSReload( // 4. expect dial time-out when loading expired certs select { case cerr := <-errc: - c.Assert(strings.Contains(cerr.Error(), "failed to get cluster id"), IsTrue) + re.Contains(cerr.Error(), "failed to get cluster id") case <-time.After(5 * time.Second): - c.Fatal("failed to receive dial timeout error") + re.FailNow("failed to receive dial timeout error") } // 5. replace expired certs back with valid ones revertFunc() // 6. new requests should trigger listener to reload valid certs - dctx, dcancel := context.WithTimeout(s.ctx, 5*time.Second) + dctx, dcancel := context.WithTimeout(ctx, 5*time.Second) cli, err := pd.NewClientWithContext(dctx, endpoints, pd.SecurityOption{ CAPath: testClientTLSInfo.TrustedCAFile, CertPath: testClientTLSInfo.CertFile, KeyPath: testClientTLSInfo.KeyFile, }, pd.WithGRPCDialOptions(grpc.WithBlock())) - c.Assert(err, IsNil) + re.NoError(err) dcancel() cli.Close() // 7. 
test use raw bytes to init tls config - caData, certData, keyData := loadTLSContent(c, + caData, certData, keyData := loadTLSContent(re, testClientTLSInfo.TrustedCAFile, testClientTLSInfo.CertFile, testClientTLSInfo.KeyFile) - ctx1, cancel1 := context.WithTimeout(s.ctx, 2*time.Second) + ctx1, cancel1 := context.WithTimeout(ctx, 2*time.Second) _, err = pd.NewClientWithContext(ctx1, endpoints, pd.SecurityOption{ SSLCABytes: caData, SSLCertBytes: certData, SSLKEYBytes: keyData, }, pd.WithGRPCDialOptions(grpc.WithBlock())) - c.Assert(err, IsNil) + re.NoError(err) cancel1() } -func loadTLSContent(c *C, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { +func loadTLSContent(re *require.Assertions, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { var err error caData, err = os.ReadFile(caPath) - c.Assert(err, IsNil) + re.NoError(err) certData, err = os.ReadFile(certPath) - c.Assert(err, IsNil) + re.NoError(err) keyData, err = os.ReadFile(keyPath) - c.Assert(err, IsNil) + re.NoError(err) return } @@ -245,6 +233,7 @@ func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) { } return ci, nil } + func copyFile(src, dst string) error { f, err := os.Open(src) if err != nil { diff --git a/tests/client/go.mod b/tests/client/go.mod index 93fb9d96eaa..9d539193d52 100644 --- a/tests/client/go.mod +++ b/tests/client/go.mod @@ -5,9 +5,9 @@ go 1.16 require ( github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.5.2 // indirect - github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a + github.com/stretchr/testify v1.7.0 github.com/tikv/pd v0.0.0-00010101000000-000000000000 github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738 diff --git a/tests/cluster.go b/tests/cluster.go index 2061668f393..3b0e10a02e6 100644 --- a/tests/cluster.go +++ b/tests/cluster.go @@ -19,6 +19,7 @@ import ( "net/http" "os" "sync" + "testing" "time" "github.com/coreos/go-semver/semver" @@ -622,6 +623,26 @@ func (c *TestCluster) WaitAllLeaders(testC *check.C, dcLocations map[string]stri wg.Wait() } +// WaitAllLeadersWithTestingT will block and wait for the election of PD leader and all Local TSO Allocator leaders. +// NOTICE: this is a temporary function that we will be used to replace `WaitAllLeaders` later. +func (c *TestCluster) WaitAllLeadersWithTestingT(t *testing.T, dcLocations map[string]string) { + c.WaitLeader() + c.CheckClusterDCLocation() + // Wait for each DC's Local TSO Allocator leader + wg := sync.WaitGroup{} + for _, dcLocation := range dcLocations { + wg.Add(1) + go func(dc string) { + testutil.WaitUntilWithTestingT(t, func() bool { + leaderName := c.WaitAllocatorLeader(dc) + return leaderName != "" + }) + wg.Done() + }(dcLocation) + } + wg.Wait() +} + // GetCluster returns PD cluster. func (c *TestCluster) GetCluster() *metapb.Cluster { leader := c.GetLeader() From 422acfb5e6a02fcef61b07c6cfb9004849d21e93 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 1 Jun 2022 17:42:29 +0800 Subject: [PATCH 10/82] mockhbstream, movingaverage, netutil, progress: testify the tests (#5086) ref tikv/pd#4813 Testify the pkg/mockhbstream, pkg/movingaverage, pkg/netutil, pkg/progress tests. 
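Every test file touched by this patch changes the same way: the gocheck Suite registration and TestingT bootstrap are removed, each (s *testXxxSuite) method becomes a plain TestXxx(t *testing.T), and c.Assert calls become testify require assertions (testify takes the expected value before the actual one, the reverse of gocheck's argument order). A minimal sketch of the pattern, using illustrative names that are not taken from the patch:

    package example

    import (
    	"testing"

    	"github.com/stretchr/testify/require"
    )

    // gocheck style being removed:
    //   func (s *testSumSuite) TestSum(c *C) { c.Assert(sum(1, 1), Equals, 2) }
    //
    // testify style being introduced:
    func TestSum(t *testing.T) {
    	re := require.New(t)
    	re.Equal(2, sum(1, 1)) // expected first, then actual
    	re.NoError(nil)        // replaces c.Assert(err, IsNil)
    }

    func sum(a, b int) int { return a + b }
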
Signed-off-by: JmPotato --- pkg/mock/mockhbstream/mockhbstream_test.go | 36 +++------- pkg/movingaverage/avg_over_time_test.go | 51 +++++++------- pkg/movingaverage/max_filter_test.go | 19 +++-- pkg/movingaverage/moving_average_test.go | 60 ++++++++-------- pkg/movingaverage/queue_test.go | 18 +++-- pkg/netutil/address_test.go | 24 +++---- pkg/progress/progress_test.go | 80 ++++++++++------------ 7 files changed, 128 insertions(+), 160 deletions(-) diff --git a/pkg/mock/mockhbstream/mockhbstream_test.go b/pkg/mock/mockhbstream/mockhbstream_test.go index e6d05f19d1b..5f9d814835b 100644 --- a/pkg/mock/mockhbstream/mockhbstream_test.go +++ b/pkg/mock/mockhbstream/mockhbstream_test.go @@ -19,40 +19,22 @@ import ( "testing" "github.com/gogo/protobuf/proto" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/eraftpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/schedule/hbstream" ) -func TestHeaertbeatStreams(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testHeartbeatStreamSuite{}) - -type testHeartbeatStreamSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testHeartbeatStreamSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testHeartbeatStreamSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testHeartbeatStreamSuite) TestActivity(c *C) { +func TestActivity(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddRegionStore(1, 1) cluster.AddRegionStore(2, 0) cluster.AddLeaderRegion(1, 1) @@ -66,24 +48,24 @@ func (s *testHeartbeatStreamSuite) TestActivity(c *C) { // Active stream is stream1. hbs.BindStream(1, stream1) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() != nil && stream2.Recv() == nil }) // Rebind to stream2. hbs.BindStream(1, stream2) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() == nil && stream2.Recv() != nil }) // SendErr to stream2. hbs.SendErr(pdpb.ErrorType_UNKNOWN, "test error", &metapb.Peer{Id: 1, StoreId: 1}) res := stream2.Recv() - c.Assert(res, NotNil) - c.Assert(res.GetHeader().GetError(), NotNil) + re.NotNil(res) + re.NotNil(res.GetHeader().GetError()) // Switch back to 1 again. hbs.BindStream(1, stream1) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() != nil && stream2.Recv() == nil }) diff --git a/pkg/movingaverage/avg_over_time_test.go b/pkg/movingaverage/avg_over_time_test.go index 74e54974656..9006fea5d5d 100644 --- a/pkg/movingaverage/avg_over_time_test.go +++ b/pkg/movingaverage/avg_over_time_test.go @@ -16,16 +16,14 @@ package movingaverage import ( "math/rand" + "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testAvgOverTimeSuite{}) - -type testAvgOverTimeSuite struct{} - -func (t *testAvgOverTimeSuite) TestPulse(c *C) { +func TestPulse(t *testing.T) { + re := require.New(t) aot := NewAvgOverTime(5 * time.Second) // warm up for i := 0; i < 5; i++ { @@ -38,27 +36,28 @@ func (t *testAvgOverTimeSuite) TestPulse(c *C) { } else { aot.Add(0, time.Second) } - c.Assert(aot.Get(), LessEqual, 600.) - c.Assert(aot.Get(), GreaterEqual, 400.) + re.LessOrEqual(aot.Get(), 600.) + re.GreaterOrEqual(aot.Get(), 400.) } } -func (t *testAvgOverTimeSuite) TestChange(c *C) { +func TestChange(t *testing.T) { + re := require.New(t) aot := NewAvgOverTime(5 * time.Second) // phase 1: 1000 for i := 0; i < 20; i++ { aot.Add(1000, time.Second) } - c.Assert(aot.Get(), LessEqual, 1010.) - c.Assert(aot.Get(), GreaterEqual, 990.) + re.LessOrEqual(aot.Get(), 1010.) + re.GreaterOrEqual(aot.Get(), 990.) // phase 2: 500 for i := 0; i < 5; i++ { aot.Add(500, time.Second) } - c.Assert(aot.Get(), LessEqual, 900.) - c.Assert(aot.Get(), GreaterEqual, 495.) + re.LessOrEqual(aot.Get(), 900.) + re.GreaterOrEqual(aot.Get(), 495.) for i := 0; i < 15; i++ { aot.Add(500, time.Second) } @@ -67,32 +66,34 @@ func (t *testAvgOverTimeSuite) TestChange(c *C) { for i := 0; i < 5; i++ { aot.Add(100, time.Second) } - c.Assert(aot.Get(), LessEqual, 678.) - c.Assert(aot.Get(), GreaterEqual, 99.) + re.LessOrEqual(aot.Get(), 678.) + re.GreaterOrEqual(aot.Get(), 99.) // clear aot.Set(10) - c.Assert(aot.Get(), Equals, 10.) + re.Equal(10., aot.Get()) } -func (t *testAvgOverTimeSuite) TestMinFilled(c *C) { +func TestMinFilled(t *testing.T) { + re := require.New(t) interval := 10 * time.Second rate := 1.0 for aotSize := 2; aotSize < 10; aotSize++ { for mfSize := 2; mfSize < 10; mfSize++ { tm := NewTimeMedian(aotSize, mfSize, interval) for i := 0; i < tm.GetFilledPeriod(); i++ { - c.Assert(tm.Get(), Equals, 0.0) + re.Equal(0.0, tm.Get()) tm.Add(rate*interval.Seconds(), interval) } - c.Assert(tm.Get(), Equals, rate) + re.Equal(rate, tm.Get()) } } } -func (t *testAvgOverTimeSuite) TestUnstableInterval(c *C) { +func TestUnstableInterval(t *testing.T) { + re := require.New(t) aot := NewAvgOverTime(5 * time.Second) - c.Assert(aot.Get(), Equals, 0.) + re.Equal(0., aot.Get()) // warm up for i := 0; i < 5; i++ { aot.Add(1000, time.Second) @@ -101,8 +102,8 @@ func (t *testAvgOverTimeSuite) TestUnstableInterval(c *C) { for i := 0; i < 1000; i++ { r := float64(rand.Intn(5)) aot.Add(1000*r, time.Second*time.Duration(r)) - c.Assert(aot.Get(), LessEqual, 1010.) - c.Assert(aot.Get(), GreaterEqual, 990.) + re.LessOrEqual(aot.Get(), 1010.) + re.GreaterOrEqual(aot.Get(), 990.) } // warm up for i := 0; i < 5; i++ { @@ -112,7 +113,7 @@ func (t *testAvgOverTimeSuite) TestUnstableInterval(c *C) { for i := 0; i < 1000; i++ { rate := float64(i%5*100) + 500 aot.Add(rate*3, time.Second*3) - c.Assert(aot.Get(), LessEqual, 910.) - c.Assert(aot.Get(), GreaterEqual, 490.) + re.LessOrEqual(aot.Get(), 910.) + re.GreaterOrEqual(aot.Get(), 490.) } } diff --git a/pkg/movingaverage/max_filter_test.go b/pkg/movingaverage/max_filter_test.go index 5651bbb4b8d..7d3906ec93c 100644 --- a/pkg/movingaverage/max_filter_test.go +++ b/pkg/movingaverage/max_filter_test.go @@ -15,22 +15,21 @@ package movingaverage import ( - . 
"github.com/pingcap/check" -) - -var _ = Suite(&testMaxFilter{}) + "testing" -type testMaxFilter struct{} + "github.com/stretchr/testify/require" +) -func (t *testMaxFilter) TestMaxFilter(c *C) { +func TestMaxFilter(t *testing.T) { + re := require.New(t) var empty float64 = 0 data := []float64{2, 1, 3, 4, 1, 1, 3, 3, 2, 0, 5} expected := []float64{2, 2, 3, 4, 4, 4, 4, 4, 3, 3, 5} mf := NewMaxFilter(5) - c.Assert(mf.Get(), Equals, empty) + re.Equal(empty, mf.Get()) - checkReset(c, mf, empty) - checkAdd(c, mf, data, expected) - checkSet(c, mf, data, expected) + checkReset(re, mf, empty) + checkAdd(re, mf, data, expected) + checkSet(re, mf, data, expected) } diff --git a/pkg/movingaverage/moving_average_test.go b/pkg/movingaverage/moving_average_test.go index 8ef6d89a670..e54aa70b64a 100644 --- a/pkg/movingaverage/moving_average_test.go +++ b/pkg/movingaverage/moving_average_test.go @@ -20,17 +20,9 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testMovingAvg{}) - -type testMovingAvg struct{} - func addRandData(ma MovingAvg, n int, mx float64) { rand.Seed(time.Now().UnixNano()) for i := 0; i < n; i++ { @@ -40,55 +32,56 @@ func addRandData(ma MovingAvg, n int, mx float64) { // checkReset checks the Reset works properly. // emptyValue is the moving average of empty data set. -func checkReset(c *C, ma MovingAvg, emptyValue float64) { +func checkReset(re *require.Assertions, ma MovingAvg, emptyValue float64) { addRandData(ma, 100, 1000) ma.Reset() - c.Assert(ma.Get(), Equals, emptyValue) + re.Equal(emptyValue, ma.Get()) } // checkAddGet checks Add works properly. -func checkAdd(c *C, ma MovingAvg, data []float64, expected []float64) { - c.Assert(len(data), Equals, len(expected)) +func checkAdd(re *require.Assertions, ma MovingAvg, data []float64, expected []float64) { + re.Equal(len(expected), len(data)) for i, x := range data { ma.Add(x) - c.Assert(math.Abs(ma.Get()-expected[i]), LessEqual, 1e-7) + re.LessOrEqual(math.Abs(ma.Get()-expected[i]), 1e-7) } } // checkSet checks Set = Reset + Add -func checkSet(c *C, ma MovingAvg, data []float64, expected []float64) { - c.Assert(len(data), Equals, len(expected)) +func checkSet(re *require.Assertions, ma MovingAvg, data []float64, expected []float64) { + re.Equal(len(expected), len(data)) // Reset + Add addRandData(ma, 100, 1000) ma.Reset() - checkAdd(c, ma, data, expected) + checkAdd(re, ma, data, expected) // Set addRandData(ma, 100, 1000) ma.Set(data[0]) - c.Assert(ma.Get(), Equals, expected[0]) - checkAdd(c, ma, data[1:], expected[1:]) + re.Equal(expected[0], ma.Get()) + checkAdd(re, ma, data[1:], expected[1:]) } // checkInstantaneous checks GetInstantaneous -func checkInstantaneous(c *C, ma MovingAvg) { +func checkInstantaneous(re *require.Assertions, ma MovingAvg) { value := 100.000000 ma.Add(value) - c.Assert(ma.GetInstantaneous(), Equals, value) + re.Equal(value, ma.GetInstantaneous()) } -func (t *testMovingAvg) TestMedianFilter(c *C) { +func TestMedianFilter(t *testing.T) { + re := require.New(t) var empty float64 = 0 data := []float64{2, 4, 2, 800, 600, 6, 3} expected := []float64{2, 3, 2, 3, 4, 6, 6} mf := NewMedianFilter(5) - c.Assert(mf.Get(), Equals, empty) + re.Equal(empty, mf.Get()) - checkReset(c, mf, empty) - checkAdd(c, mf, data, expected) - checkSet(c, mf, data, expected) + checkReset(re, mf, empty) + checkAdd(re, mf, data, expected) + checkSet(re, mf, data, expected) } type testCase struct { @@ -96,7 +89,8 @@ type 
testCase struct { expected []float64 } -func (t *testMovingAvg) TestMovingAvg(c *C) { +func TestMovingAvg(t *testing.T) { + re := require.New(t) var empty float64 = 0 data := []float64{1, 1, 1, 1, 5, 1, 1, 1} testCases := []testCase{{ @@ -116,11 +110,11 @@ func (t *testMovingAvg) TestMovingAvg(c *C) { expected: []float64{1.000000, 1.000000, 1.000000, 1.000000, 5.000000, 5.000000, 5.000000, 5.000000}, }, } - for _, test := range testCases { - c.Assert(test.ma.Get(), Equals, empty) - checkReset(c, test.ma, empty) - checkAdd(c, test.ma, data, test.expected) - checkSet(c, test.ma, data, test.expected) - checkInstantaneous(c, test.ma) + for _, testCase := range testCases { + re.Equal(empty, testCase.ma.Get()) + checkReset(re, testCase.ma, empty) + checkAdd(re, testCase.ma, data, testCase.expected) + checkSet(re, testCase.ma, data, testCase.expected) + checkInstantaneous(re, testCase.ma) } } diff --git a/pkg/movingaverage/queue_test.go b/pkg/movingaverage/queue_test.go index 90769bb1249..56c2337c9a1 100644 --- a/pkg/movingaverage/queue_test.go +++ b/pkg/movingaverage/queue_test.go @@ -15,26 +15,30 @@ package movingaverage import ( - . "github.com/pingcap/check" + "testing" + + "github.com/stretchr/testify/require" ) -func (t *testMovingAvg) TestQueue(c *C) { +func TestQueue(t *testing.T) { + re := require.New(t) sq := NewSafeQueue() sq.PushBack(1) sq.PushBack(2) v1 := sq.PopFront() v2 := sq.PopFront() - c.Assert(1, Equals, v1.(int)) - c.Assert(2, Equals, v2.(int)) + re.Equal(1, v1.(int)) + re.Equal(2, v2.(int)) } -func (t *testMovingAvg) TestClone(c *C) { +func TestClone(t *testing.T) { + re := require.New(t) s1 := NewSafeQueue() s1.PushBack(1) s1.PushBack(2) s2 := s1.Clone() s2.PopFront() s2.PopFront() - c.Assert(s1.que.Len(), Equals, 2) - c.Assert(s2.que.Len(), Equals, 0) + re.Equal(2, s1.que.Len()) + re.Equal(0, s2.que.Len()) } diff --git a/pkg/netutil/address_test.go b/pkg/netutil/address_test.go index 8c93be2a124..477f794c243 100644 --- a/pkg/netutil/address_test.go +++ b/pkg/netutil/address_test.go @@ -18,18 +18,11 @@ import ( "net/http" "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testNetSuite{}) - -type testNetSuite struct{} - -func (s *testNetSuite) TestResolveLoopBackAddr(c *C) { +func TestResolveLoopBackAddr(t *testing.T) { + re := require.New(t) nodes := []struct { address string backAddress string @@ -41,24 +34,25 @@ func (s *testNetSuite) TestResolveLoopBackAddr(c *C) { } for _, n := range nodes { - c.Assert(ResolveLoopBackAddr(n.address, n.backAddress), Equals, "192.168.130.22:2379") + re.Equal("192.168.130.22:2379", ResolveLoopBackAddr(n.address, n.backAddress)) } } -func (s *testNetSuite) TestIsEnableHttps(c *C) { - c.Assert(IsEnableHTTPS(http.DefaultClient), IsFalse) +func TestIsEnableHttps(t *testing.T) { + re := require.New(t) + re.False(IsEnableHTTPS(http.DefaultClient)) httpClient := &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, TLSClientConfig: nil, }, } - c.Assert(IsEnableHTTPS(httpClient), IsFalse) + re.False(IsEnableHTTPS(httpClient)) httpClient = &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, TLSClientConfig: &tls.Config{}, }, } - c.Assert(IsEnableHTTPS(httpClient), IsFalse) + re.False(IsEnableHTTPS(httpClient)) } diff --git a/pkg/progress/progress_test.go b/pkg/progress/progress_test.go index c4b030941f8..c6fb89bbc8b 100644 --- a/pkg/progress/progress_test.go +++ b/pkg/progress/progress_test.go @@ -20,82 +20,76 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testProgressSuite{}) - -type testProgressSuite struct{} - -func (s *testProgressSuite) Test(c *C) { + re := require.New(t) n := "test" m := NewManager() - c.Assert(m.AddProgress(n, 100, 100, 10*time.Second), IsFalse) + re.False(m.AddProgress(n, 100, 100, 10*time.Second)) p, ls, cs, err := m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.0, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) time.Sleep(time.Second) - c.Assert(m.AddProgress(n, 100, 100, 10*time.Second), IsTrue) + re.True(m.AddProgress(n, 100, 100, 10*time.Second)) m.UpdateProgress(n, 30, 30, false) p, ls, cs, err = m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.7) + re.NoError(err) + re.Equal(0.7, p) // 30/(70/1s+) > 30/70 - c.Assert(ls, Greater, 30.0/70.0) + re.Greater(ls, 30.0/70.0) // 70/1s+ > 70 - c.Assert(cs, Less, 70.0) + re.Less(cs, 70.0) // there is no scheduling for i := 0; i < 100; i++ { m.UpdateProgress(n, 30, 30, false) } - c.Assert(m.progesses[n].history.Len(), Equals, 61) + re.Equal(61, m.progesses[n].history.Len()) p, ls, cs, err = m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.7) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.7, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) ps := m.GetProgresses(func(p string) bool { return strings.Contains(p, n) }) - c.Assert(ps, HasLen, 1) - c.Assert(ps[0], Equals, n) + re.Len(ps, 1) + re.Equal(n, ps[0]) ps = m.GetProgresses(func(p string) bool { return strings.Contains(p, "a") }) - c.Assert(ps, HasLen, 0) - c.Assert(m.RemoveProgress(n), IsTrue) - c.Assert(m.RemoveProgress(n), IsFalse) + re.Len(ps, 0) + re.True(m.RemoveProgress(n)) + re.False(m.RemoveProgress(n)) } -func (s *testProgressSuite) TestAbnormal(c *C) { +func TestAbnormal(t *testing.T) { + re := require.New(t) n := "test" m := 
NewManager() - c.Assert(m.AddProgress(n, 100, 100, 10*time.Second), IsFalse) + re.False(m.AddProgress(n, 100, 100, 10*time.Second)) p, ls, cs, err := m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.0, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) // When offline a store, but there are still many write operations m.UpdateProgress(n, 110, 110, false) p, ls, cs, err = m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.0, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) // It usually won't happens m.UpdateProgressTotal(n, 10) p, ls, cs, err = m.Status(n) - c.Assert(err, NotNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, 0.0) - c.Assert(cs, Equals, 0.0) + re.Error(err) + re.Equal(0.0, p) + re.Equal(0.0, ls) + re.Equal(0.0, cs) } From 0d05fba64372408c5af9a78cd9b47b58c4925ca9 Mon Sep 17 00:00:00 2001 From: Connor Date: Wed, 1 Jun 2022 19:06:27 +0800 Subject: [PATCH 11/82] unsafe recovery: Fix force leader stage's infinite retry (#5088) close tikv/pd#5085 Fix unsafe recovery infinite retry on force leader stage Signed-off-by: Connor1996 --- server/cluster/unsafe_recovery_controller.go | 12 ++++ .../unsafe_recovery_controller_test.go | 57 ++++++++++++++----- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/server/cluster/unsafe_recovery_controller.go b/server/cluster/unsafe_recovery_controller.go index 2a84535624c..9782d1a20d8 100644 --- a/server/cluster/unsafe_recovery_controller.go +++ b/server/cluster/unsafe_recovery_controller.go @@ -454,6 +454,7 @@ func (u *unsafeRecoveryController) changeStage(stage unsafeRecoveryStage) { stores += ", " } } + // TODO: clean up existing operators output.Info = fmt.Sprintf("Unsafe recovery enters collect report stage: failed stores %s", stores) case tombstoneTiFlashLearner: output.Info = "Unsafe recovery enters tombstone TiFlash learner stage" @@ -967,6 +968,17 @@ func (u *unsafeRecoveryController) generateForceLeaderPlan(newestRegionTree *reg return true }) + if hasPlan { + for storeID := range u.storeReports { + plan := u.getRecoveryPlan(storeID) + if plan.ForceLeader == nil { + // Fill an empty force leader plan to the stores that doesn't have any force leader plan + // to avoid exiting existing force leaders. 
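+				// Without this, on a retry round the stores that already hold force leaders
+				// would receive a plan with no force-leader entry and exit force leader, which
+				// is what caused the force-leader stage to retry forever.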
+ plan.ForceLeader = &pdpb.ForceLeader{} + } + } + } + return hasPlan } diff --git a/server/cluster/unsafe_recovery_controller_test.go b/server/cluster/unsafe_recovery_controller_test.go index 8b70b4cd0b6..edd6bf9c187 100644 --- a/server/cluster/unsafe_recovery_controller_test.go +++ b/server/cluster/unsafe_recovery_controller_test.go @@ -328,14 +328,14 @@ func (s *testUnsafeRecoverySuite) TestForceLeaderFail(c *C) { cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() - for _, store := range newTestStores(3, "6.0.0") { + for _, store := range newTestStores(4, "6.0.0") { c.Assert(cluster.PutStore(store.GetMeta()), IsNil) } recoveryController := newUnsafeRecoveryController(cluster) c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ - 2: {}, 3: {}, - }, 1), IsNil) + 4: {}, + }, 60), IsNil) reports := map[uint64]*pdpb.StoreReport{ 1: { @@ -345,28 +345,57 @@ func (s *testUnsafeRecoverySuite) TestForceLeaderFail(c *C) { RegionState: &raft_serverpb.RegionLocalState{ Region: &metapb.Region{ Id: 1001, + StartKey: []byte(""), + EndKey: []byte("x"), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, Peers: []*metapb.Peer{ - {Id: 11, StoreId: 1}, {Id: 21, StoreId: 2}, {Id: 31, StoreId: 3}}}}}, + {Id: 11, StoreId: 1}, {Id: 21, StoreId: 3}, {Id: 31, StoreId: 4}}}}}, + }, + }, + 2: { + PeerReports: []*pdpb.PeerReport{ + { + RaftState: &raft_serverpb.RaftLocalState{LastIndex: 10, HardState: &eraftpb.HardState{Term: 1, Commit: 10}}, + RegionState: &raft_serverpb.RegionLocalState{ + Region: &metapb.Region{ + Id: 1002, + StartKey: []byte("x"), + EndKey: []byte(""), + RegionEpoch: &metapb.RegionEpoch{ConfVer: 10, Version: 1}, + Peers: []*metapb.Peer{ + {Id: 12, StoreId: 2}, {Id: 22, StoreId: 3}, {Id: 32, StoreId: 4}}}}}, }, }, } - req := newStoreHeartbeat(1, reports[1]) - resp := &pdpb.StoreHeartbeatResponse{} - req.StoreReport.Step = 1 - recoveryController.HandleStoreHeartbeat(req, resp) + req1 := newStoreHeartbeat(1, reports[1]) + resp1 := &pdpb.StoreHeartbeatResponse{} + req1.StoreReport.Step = 1 + recoveryController.HandleStoreHeartbeat(req1, resp1) + req2 := newStoreHeartbeat(2, reports[2]) + resp2 := &pdpb.StoreHeartbeatResponse{} + req2.StoreReport.Step = 1 + recoveryController.HandleStoreHeartbeat(req2, resp2) c.Assert(recoveryController.GetStage(), Equals, forceLeader) + recoveryController.HandleStoreHeartbeat(req1, resp1) - applyRecoveryPlan(c, 1, reports, resp) - // force leader doesn't succeed - reports[1].PeerReports[0].IsForceLeader = false - recoveryController.HandleStoreHeartbeat(req, resp) + // force leader on store 1 succeed + applyRecoveryPlan(c, 1, reports, resp1) + applyRecoveryPlan(c, 2, reports, resp2) + // force leader on store 2 doesn't succeed + reports[2].PeerReports[0].IsForceLeader = false + + // force leader should retry on store 2 + recoveryController.HandleStoreHeartbeat(req1, resp1) + recoveryController.HandleStoreHeartbeat(req2, resp2) c.Assert(recoveryController.GetStage(), Equals, forceLeader) + recoveryController.HandleStoreHeartbeat(req1, resp1) // force leader succeed this time - applyRecoveryPlan(c, 1, reports, resp) - recoveryController.HandleStoreHeartbeat(req, resp) + applyRecoveryPlan(c, 1, reports, resp1) + applyRecoveryPlan(c, 2, reports, resp2) + recoveryController.HandleStoreHeartbeat(req1, resp1) + 
recoveryController.HandleStoreHeartbeat(req2, resp2) c.Assert(recoveryController.GetStage(), Equals, demoteFailedVoter) } From 3362ce2de17945b5ee5466d0b7a3ce00f2420682 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Thu, 2 Jun 2022 10:50:27 +0800 Subject: [PATCH 12/82] ratelimiter: remove anonymous attribute to keep threadsafe (#5041) close tikv/pd#5037 remove anonymous attribute to keep threadsafe Signed-off-by: Cabinfever_B Co-authored-by: Ryan Leung Co-authored-by: Ti Chi Robot --- pkg/ratelimit/ratelimiter.go | 48 ++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/pkg/ratelimit/ratelimiter.go b/pkg/ratelimit/ratelimiter.go index e15c858009e..b2b6e3a036a 100644 --- a/pkg/ratelimit/ratelimiter.go +++ b/pkg/ratelimit/ratelimiter.go @@ -15,6 +15,7 @@ package ratelimit import ( + "context" "time" "github.com/tikv/pd/pkg/syncutil" @@ -25,14 +26,14 @@ import ( // It implements `Available` function which is not included in `golang.org/x/time/rate`. // Note: AvailableN will increase the wait time of WaitN. type RateLimiter struct { - mu syncutil.Mutex - *rate.Limiter + mu syncutil.Mutex + limiter *rate.Limiter } // NewRateLimiter returns a new Limiter that allows events up to rate r (it means limiter refill r token per second) // and permits bursts of at most b tokens. func NewRateLimiter(r float64, b int) *RateLimiter { - return &RateLimiter{Limiter: rate.NewLimiter(rate.Limit(r), b)} + return &RateLimiter{limiter: rate.NewLimiter(rate.Limit(r), b)} } // Available returns whether limiter has enough tokens. @@ -41,7 +42,7 @@ func (l *RateLimiter) Available(n int) bool { l.mu.Lock() defer l.mu.Unlock() now := time.Now() - r := l.Limiter.ReserveN(now, n) + r := l.limiter.ReserveN(now, n) delay := r.DelayFrom(now) r.CancelAt(now) return delay == 0 @@ -57,5 +58,42 @@ func (l *RateLimiter) AllowN(n int) bool { l.mu.Lock() defer l.mu.Unlock() now := time.Now() - return l.Limiter.AllowN(now, n) + return l.limiter.AllowN(now, n) +} + +// SetBurst is shorthand for SetBurstAt(time.Now(), newBurst). +func (l *RateLimiter) SetBurst(burst int) { + l.mu.Lock() + defer l.mu.Unlock() + l.limiter.SetBurst(burst) +} + +// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit). +func (l *RateLimiter) SetLimit(limit rate.Limit) { + l.mu.Lock() + defer l.mu.Unlock() + l.limiter.SetLimit(limit) +} + +// Limit returns the maximum overall event rate. +func (l *RateLimiter) Limit() rate.Limit { + return l.limiter.Limit() +} + +// Burst returns the maximum burst size. Burst is the maximum number of tokens +// that can be consumed in a single call to Allow, Reserve, or Wait, so higher +// Burst values allow more events to happen at once. +// A zero Burst allows no events, unless limit == Inf. +func (l *RateLimiter) Burst() int { + return l.limiter.Burst() +} + +// WaitN blocks until lim permits n events to happen. +// It returns an error if n exceeds the Limiter's burst size, the Context is +// canceled, or the expected wait time exceeds the Context's Deadline. +// The burst limit is ignored if the rate limit is Inf. 
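+// Note that the RateLimiter's mutex is held for the entire wait, so any other call
+// that takes the mutex blocks until WaitN returns or the context is canceled.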
+func (l *RateLimiter) WaitN(ctx context.Context, n int) error { + l.mu.Lock() + defer l.mu.Unlock() + return l.limiter.WaitN(ctx, n) } From d7435da069ab96b91b9e6a6686a394cce6f42390 Mon Sep 17 00:00:00 2001 From: matchge <74505524+matchge-ca@users.noreply.github.com> Date: Wed, 1 Jun 2022 23:02:27 -0400 Subject: [PATCH 13/82] Add the paused and resume timestamp details for paused scheduler(s) (#5071) close tikv/pd#4487 Signed-off-by: Hua Lu Co-authored-by: Hua Lu Co-authored-by: Ti Chi Robot --- server/api/scheduler.go | 38 ++++++++++++++++++- server/api/scheduler_test.go | 5 +++ server/cluster/cluster.go | 10 +++++ server/cluster/coordinator.go | 49 ++++++++++++++++++++++++- server/cluster/coordinator_test.go | 5 +++ server/handler.go | 18 +++++++++ tools/pd-ctl/pdctl/command/scheduler.go | 4 ++ 7 files changed, 125 insertions(+), 4 deletions(-) diff --git a/server/api/scheduler.go b/server/api/scheduler.go index ecf0d51a2a6..5faa01c764b 100644 --- a/server/api/scheduler.go +++ b/server/api/scheduler.go @@ -18,6 +18,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/gorilla/mux" "github.com/pingcap/errors" @@ -44,6 +45,12 @@ func newSchedulerHandler(svr *server.Server, r *render.Render) *schedulerHandler } } +type schedulerPausedPeriod struct { + Name string `json:"name"` + PausedAt time.Time `json:"paused_at"` + ResumeAt time.Time `json:"resume_at"` +} + // @Tags scheduler // @Summary List all created schedulers by status. // @Produce json @@ -58,9 +65,11 @@ func (h *schedulerHandler) GetSchedulers(w http.ResponseWriter, r *http.Request) } status := r.URL.Query().Get("status") + _, tsFlag := r.URL.Query()["timestamp"] switch status { case "paused": var pausedSchedulers []string + pausedPeriods := []schedulerPausedPeriod{} for _, scheduler := range schedulers { paused, err := h.Handler.IsSchedulerPaused(scheduler) if err != nil { @@ -69,10 +78,35 @@ func (h *schedulerHandler) GetSchedulers(w http.ResponseWriter, r *http.Request) } if paused { - pausedSchedulers = append(pausedSchedulers, scheduler) + if tsFlag { + s := schedulerPausedPeriod{ + Name: scheduler, + PausedAt: time.Time{}, + ResumeAt: time.Time{}, + } + pausedAt, err := h.Handler.GetPausedSchedulerDelayAt(scheduler) + if err != nil { + h.r.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + s.PausedAt = time.Unix(pausedAt, 0) + resumeAt, err := h.Handler.GetPausedSchedulerDelayUntil(scheduler) + if err != nil { + h.r.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + s.ResumeAt = time.Unix(resumeAt, 0) + pausedPeriods = append(pausedPeriods, s) + } else { + pausedSchedulers = append(pausedSchedulers, scheduler) + } } } - h.r.JSON(w, http.StatusOK, pausedSchedulers) + if tsFlag { + h.r.JSON(w, http.StatusOK, pausedPeriods) + } else { + h.r.JSON(w, http.StatusOK, pausedSchedulers) + } return case "disabled": var disabledSchedulers []string diff --git a/server/api/scheduler_test.go b/server/api/scheduler_test.go index 04a2aee900e..8c20bdf6182 100644 --- a/server/api/scheduler_test.go +++ b/server/api/scheduler_test.go @@ -490,6 +490,11 @@ func (s *testScheduleSuite) testPauseOrResume(name, createdName string, body []b c.Assert(err, IsNil) err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(c)) c.Assert(err, IsNil) + pausedAt, err := handler.GetPausedSchedulerDelayAt(createdName) + c.Assert(err, IsNil) + resumeAt, err := handler.GetPausedSchedulerDelayUntil(createdName) + c.Assert(err, IsNil) + c.Assert(resumeAt-pausedAt, Equals, int64(1)) 
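+	// The scheduler was paused with a one-second delay above, so the recorded resume
+	// timestamp (DelayUntil) is expected to be exactly one second after the recorded
+	// pause timestamp (DelayAt).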
time.Sleep(time.Second) isPaused, err = handler.IsSchedulerPaused(createdName) c.Assert(err, IsNil) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 5a24444a1ad..4ff1232752e 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -2351,3 +2351,13 @@ func newCacheCluster(c *RaftCluster) *cacheCluster { stores: c.GetStores(), } } + +// GetPausedSchedulerDelayAt returns DelayAt of a paused scheduler +func (c *RaftCluster) GetPausedSchedulerDelayAt(name string) (int64, error) { + return c.coordinator.getPausedSchedulerDelayAt(name) +} + +// GetPausedSchedulerDelayUntil returns DelayUntil of a paused scheduler +func (c *RaftCluster) GetPausedSchedulerDelayUntil(name string) (int64, error) { + return c.coordinator.getPausedSchedulerDelayUntil(name) +} diff --git a/server/cluster/coordinator.go b/server/cluster/coordinator.go index 108a538034e..530e858877f 100644 --- a/server/cluster/coordinator.go +++ b/server/cluster/coordinator.go @@ -718,10 +718,12 @@ func (c *coordinator) pauseOrResumeScheduler(name string, t int64) error { } var err error for _, sc := range s { - var delayUntil int64 + var delayAt, delayUntil int64 if t > 0 { - delayUntil = time.Now().Unix() + t + delayAt = time.Now().Unix() + delayUntil = delayAt + t } + atomic.StoreInt64(&sc.delayAt, delayAt) atomic.StoreInt64(&sc.delayUntil, delayUntil) } return err @@ -851,6 +853,7 @@ type scheduleController struct { nextInterval time.Duration ctx context.Context cancel context.CancelFunc + delayAt int64 delayUntil int64 } @@ -909,3 +912,45 @@ func (s *scheduleController) IsPaused() bool { delayUntil := atomic.LoadInt64(&s.delayUntil) return time.Now().Unix() < delayUntil } + +// GetPausedSchedulerDelayAt returns paused timestamp of a paused scheduler +func (s *scheduleController) GetDelayAt() int64 { + if s.IsPaused() { + return atomic.LoadInt64(&s.delayAt) + } + return 0 +} + +// GetPausedSchedulerDelayUntil returns resume timestamp of a paused scheduler +func (s *scheduleController) GetDelayUntil() int64 { + if s.IsPaused() { + return atomic.LoadInt64(&s.delayUntil) + } + return 0 +} + +func (c *coordinator) getPausedSchedulerDelayAt(name string) (int64, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return -1, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return -1, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.GetDelayAt(), nil +} + +func (c *coordinator) getPausedSchedulerDelayUntil(name string) (int64, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return -1, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return -1, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.GetDelayUntil(), nil +} diff --git a/server/cluster/coordinator_test.go b/server/cluster/coordinator_test.go index 20ab1f4f8fa..b234374a765 100644 --- a/server/cluster/coordinator_test.go +++ b/server/cluster/coordinator_test.go @@ -907,6 +907,11 @@ func (s *testCoordinatorSuite) TestPauseScheduler(c *C) { co.pauseOrResumeScheduler(schedulers.BalanceLeaderName, 60) paused, _ := co.isSchedulerPaused(schedulers.BalanceLeaderName) c.Assert(paused, Equals, true) + pausedAt, err := co.getPausedSchedulerDelayAt(schedulers.BalanceLeaderName) + c.Assert(err, IsNil) + resumeAt, err := co.getPausedSchedulerDelayUntil(schedulers.BalanceLeaderName) + c.Assert(err, IsNil) + c.Assert(resumeAt-pausedAt, Equals, int64(60)) allowed, _ := co.isSchedulerAllowed(schedulers.BalanceLeaderName) c.Assert(allowed, 
Equals, false) } diff --git a/server/handler.go b/server/handler.go index 8567c6ec12b..238d0dfdcc3 100644 --- a/server/handler.go +++ b/server/handler.go @@ -1123,3 +1123,21 @@ func (h *Handler) AddEvictOrGrant(storeID float64, name string) error { } return nil } + +// GetPausedSchedulerDelayAt returns paused unix timestamp when a scheduler is paused +func (h *Handler) GetPausedSchedulerDelayAt(name string) (int64, error) { + rc, err := h.GetRaftCluster() + if err != nil { + return -1, err + } + return rc.GetPausedSchedulerDelayAt(name) +} + +// GetPausedSchedulerDelayUntil returns resume unix timestamp when a scheduler is paused +func (h *Handler) GetPausedSchedulerDelayUntil(name string) (int64, error) { + rc, err := h.GetRaftCluster() + if err != nil { + return -1, err + } + return rc.GetPausedSchedulerDelayUntil(name) +} diff --git a/tools/pd-ctl/pdctl/command/scheduler.go b/tools/pd-ctl/pdctl/command/scheduler.go index 5c334d86dae..0a0d9635021 100644 --- a/tools/pd-ctl/pdctl/command/scheduler.go +++ b/tools/pd-ctl/pdctl/command/scheduler.go @@ -104,6 +104,7 @@ func NewShowSchedulerCommand() *cobra.Command { Run: showSchedulerCommandFunc, } c.Flags().String("status", "", "the scheduler status value can be [paused | disabled]") + c.Flags().BoolP("timestamp", "t", false, "fetch the paused and resume timestamp for paused scheduler(s)") return c } @@ -116,6 +117,9 @@ func showSchedulerCommandFunc(cmd *cobra.Command, args []string) { url := schedulersPrefix if flag := cmd.Flag("status"); flag != nil && flag.Value.String() != "" { url = fmt.Sprintf("%s?status=%s", url, flag.Value.String()) + if tsFlag, _ := cmd.Flags().GetBool("timestamp"); tsFlag { + url += "×tamp=true" + } } r, err := doRequest(cmd, url, http.MethodGet, http.Header{}) if err != nil { From fc55b446f9ad0b32119e86b68e1353ba78c928aa Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 2 Jun 2022 11:56:28 +0800 Subject: [PATCH 14/82] * : testify all pkg tests (#5091) ref tikv/pd#4813 Testify all pkg tests. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- pkg/progress/progress_test.go | 2 +- pkg/rangetree/range_tree_test.go | 57 ++++----- pkg/ratelimit/concurrency_limiter_test.go | 29 ++--- pkg/ratelimit/limiter_test.go | 138 +++++++++++----------- pkg/ratelimit/ratelimiter_test.go | 38 +++--- pkg/reflectutil/tag_test.go | 37 +++--- pkg/requestutil/context_test.go | 45 +++---- pkg/slice/slice_test.go | 41 +++---- pkg/systimemon/systimemon_test.go | 6 +- pkg/typeutil/comparison_test.go | 37 +++--- pkg/typeutil/conversion_test.go | 54 ++++----- pkg/typeutil/duration_test.go | 25 ++-- pkg/typeutil/size_test.go | 30 +++-- pkg/typeutil/string_slice_test.go | 31 +++-- pkg/typeutil/time_test.go | 46 ++++---- 15 files changed, 279 insertions(+), 337 deletions(-) diff --git a/pkg/progress/progress_test.go b/pkg/progress/progress_test.go index c6fb89bbc8b..72d23c40a6a 100644 --- a/pkg/progress/progress_test.go +++ b/pkg/progress/progress_test.go @@ -23,7 +23,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { +func TestProgress(t *testing.T) { re := require.New(t) n := "test" m := NewManager() diff --git a/pkg/rangetree/range_tree_test.go b/pkg/rangetree/range_tree_test.go index d1e9cd79de5..695183f2f90 100644 --- a/pkg/rangetree/range_tree_test.go +++ b/pkg/rangetree/range_tree_test.go @@ -18,19 +18,10 @@ import ( "bytes" "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/btree" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRangeTreeSuite{}) - -type testRangeTreeSuite struct { -} - type simpleBucketItem struct { startKey []byte endKey []byte @@ -79,7 +70,7 @@ func bucketDebrisFactory(startKey, endKey []byte, item RangeItem) []RangeItem { left := maxKey(startKey, item.GetStartKey()) right := minKey(endKey, item.GetEndKey()) - // they have no intersection if they are neighbour like |010 - 100| and |100 - 200|. + // they have no intersection if they are neighbors like |010 - 100| and |100 - 200|. if bytes.Compare(left, right) >= 0 { return nil } @@ -94,52 +85,54 @@ func bucketDebrisFactory(startKey, endKey []byte, item RangeItem) []RangeItem { return res } -func (bs *testRangeTreeSuite) TestRingPutItem(c *C) { +func TestRingPutItem(t *testing.T) { + re := require.New(t) bucketTree := NewRangeTree(2, bucketDebrisFactory) bucketTree.Update(newSimpleBucketItem([]byte("002"), []byte("100"))) - c.Assert(bucketTree.Len(), Equals, 1) + re.Equal(1, bucketTree.Len()) bucketTree.Update(newSimpleBucketItem([]byte("100"), []byte("200"))) - c.Assert(bucketTree.Len(), Equals, 2) + re.Equal(2, bucketTree.Len()) // init key range: [002,100], [100,200] - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("002"))), HasLen, 0) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("009"))), HasLen, 1) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 1) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("110"))), HasLen, 2) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("200"), []byte("300"))), HasLen, 0) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("002"))), 0) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("009"))), 1) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 1) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("110"))), 2) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("200"), []byte("300"))), 0) // test1: insert one key range, the old overlaps will retain like split buckets. // key range: [002,010],[010,090],[090,100],[100,200] bucketTree.Update(newSimpleBucketItem([]byte("010"), []byte("090"))) - c.Assert(bucketTree.Len(), Equals, 4) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 1) + re.Equal(4, bucketTree.Len()) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 1) // test2: insert one key range, the old overlaps will retain like merge . // key range: [001,080], [080,090],[090,100],[100,200] bucketTree.Update(newSimpleBucketItem([]byte("001"), []byte("080"))) - c.Assert(bucketTree.Len(), Equals, 4) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 2) + re.Equal(4, bucketTree.Len()) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 2) // test2: insert one keyrange, the old overlaps will retain like merge . 
// key range: [001,120],[120,200] bucketTree.Update(newSimpleBucketItem([]byte("001"), []byte("120"))) - c.Assert(bucketTree.Len(), Equals, 2) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 1) + re.Equal(2, bucketTree.Len()) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 1) } -func (bs *testRangeTreeSuite) TestDebris(c *C) { +func TestDebris(t *testing.T) { + re := require.New(t) ringItem := newSimpleBucketItem([]byte("010"), []byte("090")) var overlaps []RangeItem overlaps = bucketDebrisFactory([]byte("000"), []byte("100"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) overlaps = bucketDebrisFactory([]byte("000"), []byte("080"), ringItem) - c.Assert(overlaps, HasLen, 1) + re.Len(overlaps, 1) overlaps = bucketDebrisFactory([]byte("020"), []byte("080"), ringItem) - c.Assert(overlaps, HasLen, 2) + re.Len(overlaps, 2) overlaps = bucketDebrisFactory([]byte("010"), []byte("090"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) overlaps = bucketDebrisFactory([]byte("010"), []byte("100"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) overlaps = bucketDebrisFactory([]byte("100"), []byte("200"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) } diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index 86dfda0eef6..6a2a5c80b9c 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -15,29 +15,24 @@ package ratelimit import ( - . "github.com/pingcap/check" -) - -var _ = Suite(&testConcurrencyLimiterSuite{}) - -type testConcurrencyLimiterSuite struct { -} + "testing" -func (s *testConcurrencyLimiterSuite) TestConcurrencyLimiter(c *C) { - c.Parallel() + "github.com/stretchr/testify/require" +) +func TestConcurrencyLimiter(t *testing.T) { + re := require.New(t) cl := newConcurrencyLimiter(10) - for i := 0; i < 10; i++ { - c.Assert(cl.allow(), Equals, true) + re.True(cl.allow()) } - c.Assert(cl.allow(), Equals, false) + re.False(cl.allow()) cl.release() - c.Assert(cl.allow(), Equals, true) - c.Assert(cl.getLimit(), Equals, uint64(10)) + re.True(cl.allow()) + re.Equal(uint64(10), cl.getLimit()) cl.setLimit(5) - c.Assert(cl.getLimit(), Equals, uint64(5)) - c.Assert(cl.getCurrent(), Equals, uint64(10)) + re.Equal(uint64(5), cl.getLimit()) + re.Equal(uint64(10), cl.getCurrent()) cl.release() - c.Assert(cl.getCurrent(), Equals, uint64(9)) + re.Equal(uint64(9), cl.getCurrent()) } diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index cf75d76152a..d1a570ccb35 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -16,155 +16,151 @@ package ratelimit import ( "sync" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "golang.org/x/time/rate" ) -var _ = Suite(&testRatelimiterSuite{}) - -type testRatelimiterSuite struct { -} - -func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { - c.Parallel() +func TestUpdateConcurrencyLimiter(t *testing.T) { + re := require.New(t) opts := []Option{UpdateConcurrencyLimiter(10)} limiter := NewLimiter() label := "test" status := limiter.Update(label, opts...) 
- c.Assert(status&ConcurrencyChanged != 0, IsTrue) + re.True(status&ConcurrencyChanged != 0) var lock sync.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup for i := 0; i < 15; i++ { wg.Add(1) go func() { - CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) }() } wg.Wait() - c.Assert(failedCount, Equals, 5) - c.Assert(successCount, Equals, 10) + re.Equal(5, failedCount) + re.Equal(10, successCount) for i := 0; i < 10; i++ { limiter.Release(label) } limit, current := limiter.GetConcurrencyLimiterStatus(label) - c.Assert(limit, Equals, uint64(10)) - c.Assert(current, Equals, uint64(0)) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) status = limiter.Update(label, UpdateConcurrencyLimiter(10)) - c.Assert(status&ConcurrencyNoChange != 0, IsTrue) + re.True(status&ConcurrencyNoChange != 0) status = limiter.Update(label, UpdateConcurrencyLimiter(5)) - c.Assert(status&ConcurrencyChanged != 0, IsTrue) + re.True(status&ConcurrencyChanged != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 10) - c.Assert(successCount, Equals, 5) + re.Equal(10, failedCount) + re.Equal(5, successCount) for i := 0; i < 5; i++ { limiter.Release(label) } status = limiter.Update(label, UpdateConcurrencyLimiter(0)) - c.Assert(status&ConcurrencyDeleted != 0, IsTrue) + re.True(status&ConcurrencyDeleted != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 0) - c.Assert(successCount, Equals, 15) + re.Equal(0, failedCount) + re.Equal(15, successCount) limit, current = limiter.GetConcurrencyLimiterStatus(label) - c.Assert(limit, Equals, uint64(0)) - c.Assert(current, Equals, uint64(0)) + re.Equal(uint64(0), limit) + re.Equal(uint64(0), current) } -func (s *testRatelimiterSuite) TestBlockList(c *C) { - c.Parallel() +func TestBlockList(t *testing.T) { + re := require.New(t) opts := []Option{AddLabelAllowList()} limiter := NewLimiter() label := "test" - c.Assert(limiter.IsInAllowList(label), Equals, false) + re.False(limiter.IsInAllowList(label)) for _, opt := range opts { opt(label, limiter) } - c.Assert(limiter.IsInAllowList(label), Equals, true) + re.True(limiter.IsInAllowList(label)) status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) - c.Assert(status&InAllowList != 0, Equals, true) + re.True(status&InAllowList != 0) for i := 0; i < 10; i++ { - c.Assert(limiter.Allow(label), Equals, true) + re.True(limiter.Allow(label)) } } -func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { - c.Parallel() +func TestUpdateQPSLimiter(t *testing.T) { + re := require.New(t) opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} limiter := NewLimiter() label := "test" status := limiter.Update(label, opts...) 
- c.Assert(status&QPSChanged != 0, IsTrue) + re.True(status&QPSChanged != 0) var lock sync.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup wg.Add(3) for i := 0; i < 3; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 2) - c.Assert(successCount, Equals, 1) + re.Equal(2, failedCount) + re.Equal(1, successCount) limit, burst := limiter.GetQPSLimiterStatus(label) - c.Assert(limit, Equals, rate.Limit(1)) - c.Assert(burst, Equals, 1) + re.Equal(rate.Limit(1), limit) + re.Equal(1, burst) status = limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)) - c.Assert(status&QPSNoChange != 0, IsTrue) + re.True(status&QPSNoChange != 0) status = limiter.Update(label, UpdateQPSLimiter(5, 5)) - c.Assert(status&QPSChanged != 0, IsTrue) + re.True(status&QPSChanged != 0) limit, burst = limiter.GetQPSLimiterStatus(label) - c.Assert(limit, Equals, rate.Limit(5)) - c.Assert(burst, Equals, 5) + re.Equal(rate.Limit(5), limit) + re.Equal(5, burst) time.Sleep(time.Second) for i := 0; i < 10; i++ { if i < 5 { - c.Assert(limiter.Allow(label), Equals, true) + re.True(limiter.Allow(label)) } else { - c.Assert(limiter.Allow(label), Equals, false) + re.False(limiter.Allow(label)) } } time.Sleep(time.Second) status = limiter.Update(label, UpdateQPSLimiter(0, 0)) - c.Assert(status&QPSDeleted != 0, IsTrue) + re.True(status&QPSDeleted != 0) for i := 0; i < 10; i++ { - c.Assert(limiter.Allow(label), Equals, true) + re.True(limiter.Allow(label)) } qLimit, qCurrent := limiter.GetQPSLimiterStatus(label) - c.Assert(qLimit, Equals, rate.Limit(0)) - c.Assert(qCurrent, Equals, 0) + re.Equal(rate.Limit(0), qLimit) + re.Equal(0, qCurrent) } -func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { - c.Parallel() +func TestQPSLimiter(t *testing.T) { + re := require.New(t) opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} limiter := NewLimiter() @@ -178,22 +174,22 @@ func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { var wg sync.WaitGroup wg.Add(200) for i := 0; i < 200; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount+successCount, Equals, 200) - c.Assert(failedCount, Equals, 100) - c.Assert(successCount, Equals, 100) + re.Equal(200, failedCount+successCount) + re.Equal(100, failedCount) + re.Equal(100, successCount) time.Sleep(4 * time.Second) // 3+1 wg.Add(1) - CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) wg.Wait() - c.Assert(successCount, Equals, 101) + re.Equal(101, successCount) } -func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { - c.Parallel() +func TestTwoLimiters(t *testing.T) { + re := require.New(t) cfg := &DimensionConfig{ QPS: 100, QPSBurst: 100, @@ -212,20 +208,20 @@ func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { var wg sync.WaitGroup wg.Add(200) for i := 0; i < 200; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 100) - c.Assert(successCount, Equals, 100) + re.Equal(100, failedCount) + 
re.Equal(100, successCount) time.Sleep(1 * time.Second) wg.Add(100) for i := 0; i < 100; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 200) - c.Assert(successCount, Equals, 100) + re.Equal(200, failedCount) + re.Equal(100, successCount) for i := 0; i < 100; i++ { limiter.Release(label) @@ -233,17 +229,17 @@ func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(10*time.Second)), 1)) wg.Add(100) for i := 0; i < 100; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(successCount, Equals, 101) - c.Assert(failedCount, Equals, 299) + re.Equal(101, successCount) + re.Equal(299, failedCount) limit, current := limiter.GetConcurrencyLimiterStatus(label) - c.Assert(limit, Equals, uint64(100)) - c.Assert(current, Equals, uint64(1)) + re.Equal(uint64(100), limit) + re.Equal(uint64(1), current) } -func CountRateLimiterHandleResult(limiter *Limiter, label string, successCount *int, +func countRateLimiterHandleResult(limiter *Limiter, label string, successCount *int, failedCount *int, lock *sync.Mutex, wg *sync.WaitGroup) { result := limiter.Allow(label) lock.Lock() diff --git a/pkg/ratelimit/ratelimiter_test.go b/pkg/ratelimit/ratelimiter_test.go index ccc8d05090a..f16bb6a83d2 100644 --- a/pkg/ratelimit/ratelimiter_test.go +++ b/pkg/ratelimit/ratelimiter_test.go @@ -18,34 +18,24 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRateLimiterSuite{}) - -type testRateLimiterSuite struct { -} - -func (s *testRateLimiterSuite) TestRateLimiter(c *C) { - c.Parallel() - +func TestRateLimiter(t *testing.T) { + re := require.New(t) limiter := NewRateLimiter(100, 100) - c.Assert(limiter.Available(1), Equals, true) + re.True(limiter.Available(1)) - c.Assert(limiter.AllowN(50), Equals, true) - c.Assert(limiter.Available(50), Equals, true) - c.Assert(limiter.Available(100), Equals, false) - c.Assert(limiter.Available(50), Equals, true) - c.Assert(limiter.AllowN(50), Equals, true) - c.Assert(limiter.Available(50), Equals, false) + re.True(limiter.AllowN(50)) + re.True(limiter.Available(50)) + re.False(limiter.Available(100)) + re.True(limiter.Available(50)) + re.True(limiter.AllowN(50)) + re.False(limiter.Available(50)) time.Sleep(time.Second) - c.Assert(limiter.Available(1), Equals, true) - c.Assert(limiter.AllowN(99), Equals, true) - c.Assert(limiter.Allow(), Equals, true) - c.Assert(limiter.Available(1), Equals, false) + re.True(limiter.Available(1)) + re.True(limiter.AllowN(99)) + re.True(limiter.Allow()) + re.False(limiter.Available(1)) } diff --git a/pkg/reflectutil/tag_test.go b/pkg/reflectutil/tag_test.go index d7ddccd22e8..8e8c4dc7754 100644 --- a/pkg/reflectutil/tag_test.go +++ b/pkg/reflectutil/tag_test.go @@ -18,7 +18,7 @@ import ( "reflect" "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) type testStruct1 struct { @@ -34,41 +34,36 @@ type testStruct3 struct { Enable bool `json:"enable,string"` } -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testTagSuite{}) - -type testTagSuite struct{} - -func (s *testTagSuite) TestFindJSONFullTagByChildTag(c *C) { +func TestFindJSONFullTagByChildTag(t *testing.T) { + re := require.New(t) key := "enable" result := FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) - c.Assert(result, Equals, "object.action.enable") + re.Equal("object.action.enable", result) key = "action" result = FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) - c.Assert(result, Equals, "object.action") + re.Equal("object.action", result) key = "disable" result = FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) - c.Assert(result, HasLen, 0) + re.Len(result, 0) } -func (s *testTagSuite) TestFindSameFieldByJSON(c *C) { +func TestFindSameFieldByJSON(t *testing.T) { + re := require.New(t) input := map[string]interface{}{ "name": "test2", } t2 := testStruct2{} - c.Assert(FindSameFieldByJSON(&t2, input), Equals, true) + re.True(FindSameFieldByJSON(&t2, input)) input = map[string]interface{}{ "enable": "test2", } - c.Assert(FindSameFieldByJSON(&t2, input), Equals, false) + re.False(FindSameFieldByJSON(&t2, input)) } -func (s *testTagSuite) TestFindFieldByJSONTag(c *C) { +func TestFindFieldByJSONTag(t *testing.T) { + re := require.New(t) t1 := testStruct1{} t2 := testStruct2{} t3 := testStruct3{} @@ -77,17 +72,17 @@ func (s *testTagSuite) TestFindFieldByJSONTag(c *C) { tags := []string{"object"} result := FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result, Equals, type2) + re.Equal(type2, result) tags = []string{"object", "action"} result = FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result, Equals, type3) + re.Equal(type3, result) tags = []string{"object", "name"} result = FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result.Kind(), Equals, reflect.String) + re.Equal(reflect.String, result.Kind()) tags = []string{"object", "action", "enable"} result = FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result.Kind(), Equals, reflect.Bool) + re.Equal(reflect.Bool, result.Kind()) } diff --git a/pkg/requestutil/context_test.go b/pkg/requestutil/context_test.go index 98560577183..fe93182d537 100644 --- a/pkg/requestutil/context_test.go +++ b/pkg/requestutil/context_test.go @@ -19,22 +19,14 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRequestContextSuite{}) - -type testRequestContextSuite struct { -} - -func (s *testRequestContextSuite) TestRequestInfo(c *C) { +func TestRequestInfo(t *testing.T) { + re := require.New(t) ctx := context.Background() _, ok := RequestInfoFrom(ctx) - c.Assert(ok, Equals, false) + re.False(ok) timeNow := time.Now().Unix() ctx = WithRequestInfo(ctx, RequestInfo{ @@ -47,25 +39,26 @@ func (s *testRequestContextSuite) TestRequestInfo(c *C) { StartTimeStamp: timeNow, }) result, ok := RequestInfoFrom(ctx) - c.Assert(result, NotNil) - c.Assert(ok, Equals, true) - c.Assert(result.ServiceLabel, Equals, "test label") - c.Assert(result.Method, Equals, "POST") - c.Assert(result.Component, Equals, "pdctl") - c.Assert(result.IP, Equals, "localhost") - c.Assert(result.URLParam, Equals, "{\"id\"=1}") - c.Assert(result.BodyParam, Equals, "{\"state\"=\"Up\"}") - c.Assert(result.StartTimeStamp, Equals, timeNow) + re.NotNil(result) + re.True(ok) + re.Equal("test label", result.ServiceLabel) + re.Equal("POST", result.Method) + re.Equal("pdctl", result.Component) + re.Equal("localhost", result.IP) + re.Equal("{\"id\"=1}", result.URLParam) + re.Equal("{\"state\"=\"Up\"}", result.BodyParam) + re.Equal(timeNow, result.StartTimeStamp) } -func (s *testRequestContextSuite) TestEndTime(c *C) { +func TestEndTime(t *testing.T) { + re := require.New(t) ctx := context.Background() _, ok := EndTimeFrom(ctx) - c.Assert(ok, Equals, false) + re.False(ok) timeNow := time.Now().Unix() ctx = WithEndTime(ctx, timeNow) result, ok := EndTimeFrom(ctx) - c.Assert(result, NotNil) - c.Assert(ok, Equals, true) - c.Assert(result, Equals, timeNow) + re.NotNil(result) + re.True(ok) + re.Equal(timeNow, result) } diff --git a/pkg/slice/slice_test.go b/pkg/slice/slice_test.go index 6c7030b977e..809dd2c54b3 100644 --- a/pkg/slice/slice_test.go +++ b/pkg/slice/slice_test.go @@ -17,21 +17,13 @@ package slice_test import ( "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/slice" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testSliceSuite{}) - -type testSliceSuite struct { -} - -func (s *testSliceSuite) Test(c *C) { - tests := []struct { +func TestSlice(t *testing.T) { + re := require.New(t) + testCases := []struct { a []int anyOf bool noneOf bool @@ -43,24 +35,25 @@ func (s *testSliceSuite) Test(c *C) { {[]int{2, 2, 4}, true, false, true}, } - for _, t := range tests { - even := func(i int) bool { return t.a[i]%2 == 0 } - c.Assert(slice.AnyOf(t.a, even), Equals, t.anyOf) - c.Assert(slice.NoneOf(t.a, even), Equals, t.noneOf) - c.Assert(slice.AllOf(t.a, even), Equals, t.allOf) + for _, testCase := range testCases { + even := func(i int) bool { return testCase.a[i]%2 == 0 } + re.Equal(testCase.anyOf, slice.AnyOf(testCase.a, even)) + re.Equal(testCase.noneOf, slice.NoneOf(testCase.a, even)) + re.Equal(testCase.allOf, slice.AllOf(testCase.a, even)) } } -func (s *testSliceSuite) TestSliceContains(c *C) { +func TestSliceContains(t *testing.T) { + re := require.New(t) ss := []string{"a", "b", "c"} - c.Assert(slice.Contains(ss, "a"), IsTrue) - c.Assert(slice.Contains(ss, "d"), IsFalse) + re.Contains(ss, "a") + re.NotContains(ss, "d") us := []uint64{1, 2, 3} - c.Assert(slice.Contains(us, uint64(1)), IsTrue) - c.Assert(slice.Contains(us, uint64(4)), IsFalse) + re.Contains(us, uint64(1)) + re.NotContains(us, uint64(4)) is := []int64{1, 2, 3} - c.Assert(slice.Contains(is, int64(1)), IsTrue) - c.Assert(slice.Contains(is, int64(4)), IsFalse) + re.Contains(is, int64(1)) + re.NotContains(is, int64(4)) } diff --git a/pkg/systimemon/systimemon_test.go b/pkg/systimemon/systimemon_test.go index 73be25e2edb..d267d15d965 100644 --- a/pkg/systimemon/systimemon_test.go +++ b/pkg/systimemon/systimemon_test.go @@ -26,11 +26,11 @@ func TestSystimeMonitor(t *testing.T) { defer cancel() var jumpForward int32 - trigged := false + triggered := false go StartMonitor(ctx, func() time.Time { - if !trigged { - trigged = true + if !triggered { + triggered = true return time.Now() } diff --git a/pkg/typeutil/comparison_test.go b/pkg/typeutil/comparison_test.go index 7f6c7348040..24934684b03 100644 --- a/pkg/typeutil/comparison_test.go +++ b/pkg/typeutil/comparison_test.go @@ -18,31 +18,26 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func TestTypeUtil(t *testing.T) { - TestingT(t) +func TestMinUint64(t *testing.T) { + re := require.New(t) + re.Equal(uint64(1), MinUint64(1, 2)) + re.Equal(uint64(1), MinUint64(2, 1)) + re.Equal(uint64(1), MinUint64(1, 1)) } -var _ = Suite(&testMinMaxSuite{}) - -type testMinMaxSuite struct{} - -func (s *testMinMaxSuite) TestMinUint64(c *C) { - c.Assert(MinUint64(1, 2), Equals, uint64(1)) - c.Assert(MinUint64(2, 1), Equals, uint64(1)) - c.Assert(MinUint64(1, 1), Equals, uint64(1)) -} - -func (s *testMinMaxSuite) TestMaxUint64(c *C) { - c.Assert(MaxUint64(1, 2), Equals, uint64(2)) - c.Assert(MaxUint64(2, 1), Equals, uint64(2)) - c.Assert(MaxUint64(1, 1), Equals, uint64(1)) +func TestMaxUint64(t *testing.T) { + re := require.New(t) + re.Equal(uint64(2), MaxUint64(1, 2)) + re.Equal(uint64(2), MaxUint64(2, 1)) + re.Equal(uint64(1), MaxUint64(1, 1)) } -func (s *testMinMaxSuite) TestMinDuration(c *C) { - c.Assert(MinDuration(time.Minute, time.Second), Equals, time.Second) - c.Assert(MinDuration(time.Second, time.Minute), Equals, time.Second) - c.Assert(MinDuration(time.Second, time.Second), Equals, time.Second) +func TestMinDuration(t *testing.T) { + re := require.New(t) + re.Equal(time.Second, MinDuration(time.Minute, time.Second)) + re.Equal(time.Second, MinDuration(time.Second, time.Minute)) + re.Equal(time.Second, MinDuration(time.Second, time.Second)) } diff --git a/pkg/typeutil/conversion_test.go b/pkg/typeutil/conversion_test.go index a2d9764ade0..4d28fa152f3 100644 --- a/pkg/typeutil/conversion_test.go +++ b/pkg/typeutil/conversion_test.go @@ -17,33 +17,29 @@ package typeutil import ( "encoding/json" "reflect" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testUint64BytesSuite{}) - -type testUint64BytesSuite struct{} - -func (s *testUint64BytesSuite) TestBytesToUint64(c *C) { +func TestBytesToUint64(t *testing.T) { + re := require.New(t) str := "\x00\x00\x00\x00\x00\x00\x03\xe8" a, err := BytesToUint64([]byte(str)) - c.Assert(err, IsNil) - c.Assert(a, Equals, uint64(1000)) + re.NoError(err) + re.Equal(uint64(1000), a) } -func (s *testUint64BytesSuite) TestUint64ToBytes(c *C) { +func TestUint64ToBytes(t *testing.T) { + re := require.New(t) var a uint64 = 1000 b := Uint64ToBytes(a) str := "\x00\x00\x00\x00\x00\x00\x03\xe8" - c.Assert(b, DeepEquals, []byte(str)) + re.True(reflect.DeepEqual([]byte(str), b)) } -var _ = Suite(&testJSONSuite{}) - -type testJSONSuite struct{} - -func (s *testJSONSuite) TestJSONToUint64Slice(c *C) { +func TestJSONToUint64Slice(t *testing.T) { + re := require.New(t) type testArray struct { Array []uint64 `json:"array"` } @@ -51,16 +47,16 @@ func (s *testJSONSuite) TestJSONToUint64Slice(c *C) { Array: []uint64{1, 2, 3}, } bytes, _ := json.Marshal(a) - var t map[string]interface{} - err := json.Unmarshal(bytes, &t) - c.Assert(err, IsNil) + var jsonStr map[string]interface{} + err := json.Unmarshal(bytes, &jsonStr) + re.NoError(err) // valid case - res, ok := JSONToUint64Slice(t["array"]) - c.Assert(ok, IsTrue) - c.Assert(reflect.TypeOf(res[0]).Kind(), Equals, reflect.Uint64) + res, ok := JSONToUint64Slice(jsonStr["array"]) + re.True(ok) + re.Equal(reflect.Uint64, reflect.TypeOf(res[0]).Kind()) // invalid case - _, ok = t["array"].([]uint64) - c.Assert(ok, IsFalse) + _, ok = jsonStr["array"].([]uint64) + re.False(ok) // invalid type type testArray1 struct { @@ -70,10 +66,10 @@ func (s *testJSONSuite) TestJSONToUint64Slice(c *C) { Array: 
[]string{"1", "2", "3"}, } bytes, _ = json.Marshal(a1) - var t1 map[string]interface{} - err = json.Unmarshal(bytes, &t1) - c.Assert(err, IsNil) - res, ok = JSONToUint64Slice(t1["array"]) - c.Assert(ok, IsFalse) - c.Assert(res, IsNil) + var jsonStr1 map[string]interface{} + err = json.Unmarshal(bytes, &jsonStr1) + re.NoError(err) + res, ok = JSONToUint64Slice(jsonStr1["array"]) + re.False(ok) + re.Nil(res) } diff --git a/pkg/typeutil/duration_test.go b/pkg/typeutil/duration_test.go index c3b6f9182da..a7db13ffd04 100644 --- a/pkg/typeutil/duration_test.go +++ b/pkg/typeutil/duration_test.go @@ -16,35 +16,34 @@ package typeutil import ( "encoding/json" + "testing" "github.com/BurntSushi/toml" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testDurationSuite{}) - -type testDurationSuite struct{} - type example struct { Interval Duration `json:"interval" toml:"interval"` } -func (s *testDurationSuite) TestJSON(c *C) { +func TestDurationJSON(t *testing.T) { + re := require.New(t) example := &example{} text := []byte(`{"interval":"1h1m1s"}`) - c.Assert(json.Unmarshal(text, example), IsNil) - c.Assert(example.Interval.Seconds(), Equals, float64(60*60+60+1)) + re.Nil(json.Unmarshal(text, example)) + re.Equal(float64(60*60+60+1), example.Interval.Seconds()) b, err := json.Marshal(example) - c.Assert(err, IsNil) - c.Assert(string(b), Equals, string(text)) + re.NoError(err) + re.Equal(string(text), string(b)) } -func (s *testDurationSuite) TestTOML(c *C) { +func TestDurationTOML(t *testing.T) { + re := require.New(t) example := &example{} text := []byte(`interval = "1h1m1s"`) - c.Assert(toml.Unmarshal(text, example), IsNil) - c.Assert(example.Interval.Seconds(), Equals, float64(60*60+60+1)) + re.Nil(toml.Unmarshal(text, example)) + re.Equal(float64(60*60+60+1), example.Interval.Seconds()) } diff --git a/pkg/typeutil/size_test.go b/pkg/typeutil/size_test.go index eae092cdb5c..4cc9e66f3de 100644 --- a/pkg/typeutil/size_test.go +++ b/pkg/typeutil/size_test.go @@ -16,32 +16,30 @@ package typeutil import ( "encoding/json" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testSizeSuite{}) - -type testSizeSuite struct { -} - -func (s *testSizeSuite) TestJSON(c *C) { +func TestSizeJSON(t *testing.T) { + re := require.New(t) b := ByteSize(265421587) o, err := json.Marshal(b) - c.Assert(err, IsNil) + re.NoError(err) var nb ByteSize err = json.Unmarshal(o, &nb) - c.Assert(err, IsNil) + re.NoError(err) b = ByteSize(1756821276000) o, err = json.Marshal(b) - c.Assert(err, IsNil) - c.Assert(string(o), Equals, `"1.598TiB"`) + re.NoError(err) + re.Equal(`"1.598TiB"`, string(o)) } -func (s *testSizeSuite) TestParseMbFromText(c *C) { - testdata := []struct { +func TestParseMbFromText(t *testing.T) { + re := require.New(t) + testCases := []struct { body []string size uint64 }{{ @@ -55,9 +53,9 @@ func (s *testSizeSuite) TestParseMbFromText(c *C) { size: uint64(1), }} - for _, t := range testdata { - for _, b := range t.body { - c.Assert(int(ParseMBFromText(b, 1)), Equals, int(t.size)) + for _, testCase := range testCases { + for _, b := range testCase.body { + re.Equal(int(testCase.size), int(ParseMBFromText(b, 1))) } } } diff --git a/pkg/typeutil/string_slice_test.go b/pkg/typeutil/string_slice_test.go index 8950dea1e00..f50ddb9218d 100644 --- a/pkg/typeutil/string_slice_test.go +++ b/pkg/typeutil/string_slice_test.go @@ -16,34 +16,33 @@ package typeutil import ( "encoding/json" + "reflect" + "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testStringSliceSuite{}) - -type testStringSliceSuite struct { -} - -func (s *testStringSliceSuite) TestJSON(c *C) { +func TestStringSliceJSON(t *testing.T) { + re := require.New(t) b := StringSlice([]string{"zone", "rack"}) o, err := json.Marshal(b) - c.Assert(err, IsNil) - c.Assert(string(o), Equals, "\"zone,rack\"") + re.NoError(err) + re.Equal("\"zone,rack\"", string(o)) var nb StringSlice err = json.Unmarshal(o, &nb) - c.Assert(err, IsNil) - c.Assert(nb, DeepEquals, b) + re.NoError(err) + re.True(reflect.DeepEqual(b, nb)) } -func (s *testStringSliceSuite) TestEmpty(c *C) { +func TestEmpty(t *testing.T) { + re := require.New(t) ss := StringSlice([]string{}) b, err := json.Marshal(ss) - c.Assert(err, IsNil) - c.Assert(string(b), Equals, "\"\"") + re.NoError(err) + re.Equal("\"\"", string(b)) var ss2 StringSlice - c.Assert(ss2.UnmarshalJSON(b), IsNil) - c.Assert(ss2, DeepEquals, ss) + re.NoError(ss2.UnmarshalJSON(b)) + re.True(reflect.DeepEqual(ss, ss2)) } diff --git a/pkg/typeutil/time_test.go b/pkg/typeutil/time_test.go index 3e728c14eb9..b8078f63fa8 100644 --- a/pkg/typeutil/time_test.go +++ b/pkg/typeutil/time_test.go @@ -16,62 +16,62 @@ package typeutil import ( "math/rand" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testTimeSuite{}) - -type testTimeSuite struct{} - -func (s *testTimeSuite) TestParseTimestamp(c *C) { +func TestParseTimestamp(t *testing.T) { + re := require.New(t) for i := 0; i < 3; i++ { t := time.Now().Add(time.Second * time.Duration(rand.Int31n(1000))) data := Uint64ToBytes(uint64(t.UnixNano())) nt, err := ParseTimestamp(data) - c.Assert(err, IsNil) - c.Assert(nt.Equal(t), IsTrue) + re.NoError(err) + re.True(nt.Equal(t)) } data := []byte("pd") nt, err := ParseTimestamp(data) - c.Assert(err, NotNil) - c.Assert(nt.Equal(ZeroTime), IsTrue) + re.Error(err) + re.True(nt.Equal(ZeroTime)) } -func (s *testTimeSuite) TestSubTimeByWallClock(c *C) { +func TestSubTimeByWallClock(t *testing.T) { + re := require.New(t) for i := 0; i < 100; i++ { r := rand.Int63n(1000) t1 := time.Now() // Add r seconds. t2 := t1.Add(time.Second * time.Duration(r)) duration := SubRealTimeByWallClock(t2, t1) - c.Assert(duration, Equals, time.Second*time.Duration(r)) + re.Equal(time.Second*time.Duration(r), duration) milliseconds := SubTSOPhysicalByWallClock(t2, t1) - c.Assert(milliseconds, Equals, r*time.Second.Milliseconds()) - // Add r millionseconds. + re.Equal(r*time.Second.Milliseconds(), milliseconds) + // Add r milliseconds. t3 := t1.Add(time.Millisecond * time.Duration(r)) milliseconds = SubTSOPhysicalByWallClock(t3, t1) - c.Assert(milliseconds, Equals, r) + re.Equal(r, milliseconds) // Add r nanoseconds. t4 := t1.Add(time.Duration(-r)) duration = SubRealTimeByWallClock(t4, t1) - c.Assert(duration, Equals, time.Duration(-r)) + re.Equal(time.Duration(-r), duration) // For the millisecond comparison, please see TestSmallTimeDifference. 
} } -func (s *testTimeSuite) TestSmallTimeDifference(c *C) { +func TestSmallTimeDifference(t *testing.T) { + re := require.New(t) t1, err := time.Parse("2006-01-02 15:04:05.999", "2021-04-26 00:44:25.682") - c.Assert(err, IsNil) + re.NoError(err) t2, err := time.Parse("2006-01-02 15:04:05.999", "2021-04-26 00:44:25.681918") - c.Assert(err, IsNil) + re.NoError(err) duration := SubRealTimeByWallClock(t1, t2) - c.Assert(duration, Equals, time.Duration(82)*time.Microsecond) + re.Equal(time.Duration(82)*time.Microsecond, duration) duration = SubRealTimeByWallClock(t2, t1) - c.Assert(duration, Equals, time.Duration(-82)*time.Microsecond) + re.Equal(time.Duration(-82)*time.Microsecond, duration) milliseconds := SubTSOPhysicalByWallClock(t1, t2) - c.Assert(milliseconds, Equals, int64(1)) + re.Equal(int64(1), milliseconds) milliseconds = SubTSOPhysicalByWallClock(t2, t1) - c.Assert(milliseconds, Equals, int64(-1)) + re.Equal(int64(-1), milliseconds) } From d114cd676ab039c3de6b801aced9bc78ab37bb70 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 2 Jun 2022 17:30:28 +0800 Subject: [PATCH 15/82] api: fix the testServiceSuite (#5096) ref tikv/pd#4813 Fix the `testServiceSuite`. Signed-off-by: JmPotato --- server/api/admin_test.go | 19 ------------------- server/api/server_test.go | 8 ++++---- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 3be3b38b484..1ece28a5239 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -170,22 +170,3 @@ func (s *testTSOSuite) TestResetTS(c *C) { tu.StringEqual(c, "\"invalid tso value\"\n")) c.Assert(err, IsNil) } - -var _ = Suite(&testServiceSuite{}) - -type testServiceSuite struct { - svr *server.Server - cleanup cleanUpFunc -} - -func (s *testServiceSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) - - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) -} - -func (s *testServiceSuite) TearDownSuite(c *C) { - s.cleanup() -} diff --git a/server/api/server_test.go b/server/api/server_test.go index 51467db3938..8d9f1b4c227 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -154,14 +154,14 @@ func mustBootstrapCluster(c *C, s *server.Server) { c.Assert(resp.GetHeader().GetError().GetType(), Equals, pdpb.ErrorType_OK) } -var _ = Suite(&testServerServiceSuite{}) +var _ = Suite(&testServiceSuite{}) -type testServerServiceSuite struct { +type testServiceSuite struct { svr *server.Server cleanup cleanUpFunc } -func (s *testServerServiceSuite) SetUpSuite(c *C) { +func (s *testServiceSuite) SetUpSuite(c *C) { s.svr, s.cleanup = mustNewServer(c) mustWaitLeader(c, []*server.Server{s.svr}) @@ -169,7 +169,7 @@ func (s *testServerServiceSuite) SetUpSuite(c *C) { mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) } -func (s *testServerServiceSuite) TearDownSuite(c *C) { +func (s *testServiceSuite) TearDownSuite(c *C) { s.cleanup() } From f964149e39f53e83b770641691ac6c3df0521e6d Mon Sep 17 00:00:00 2001 From: disksing Date: Mon, 6 Jun 2022 11:18:28 +0800 Subject: [PATCH 16/82] rangelist: migrate test framework to testify (#5102) close tikv/pd#5098 Signed-off-by: disksing --- server/schedule/rangelist/range_list_test.go | 61 ++++++++++---------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/server/schedule/rangelist/range_list_test.go b/server/schedule/rangelist/range_list_test.go index 4b737d7f1fd..0f9ba595aa0 
100644 --- a/server/schedule/rangelist/range_list_test.go +++ b/server/schedule/rangelist/range_list_test.go @@ -17,45 +17,42 @@ package rangelist import ( "testing" - "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) func TestRangeList(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&testRangeListSuite{}) - -type testRangeListSuite struct{} - -func (s *testRangeListSuite) TestRangeList(c *check.C) { + re := require.New(t) rl := NewBuilder().Build() - c.Assert(rl.Len(), check.Equals, 0) + re.Equal(0, rl.Len()) i, data := rl.GetDataByKey([]byte("a")) - c.Assert(i, check.Equals, -1) - c.Assert(data, check.IsNil) + re.Equal(-1, i) + re.Nil(data) + i, data = rl.GetData([]byte("a"), []byte("b")) - c.Assert(i, check.Equals, -1) - c.Assert(data, check.IsNil) - c.Assert(rl.GetSplitKeys(nil, []byte("foo")), check.IsNil) + re.Equal(-1, i) + re.Nil(data) + + re.Nil(rl.GetSplitKeys(nil, []byte("foo"))) b := NewBuilder() b.AddItem(nil, nil, 1) rl = b.Build() - c.Assert(rl.Len(), check.Equals, 1) + re.Equal(1, rl.Len()) key, data := rl.Get(0) - c.Assert(key, check.IsNil) - c.Assert(data, check.DeepEquals, []interface{}{1}) + re.Nil(key) + + re.Equal([]interface{}{1}, data) i, data = rl.GetDataByKey([]byte("foo")) - c.Assert(i, check.Equals, 0) - c.Assert(data, check.DeepEquals, []interface{}{1}) + re.Equal(0, i) + re.Equal([]interface{}{1}, data) i, data = rl.GetData([]byte("a"), []byte("b")) - c.Assert(i, check.Equals, 0) - c.Assert(data, check.DeepEquals, []interface{}{1}) - c.Assert(rl.GetSplitKeys(nil, []byte("foo")), check.IsNil) + re.Equal(0, i) + re.Equal([]interface{}{1}, data) + re.Nil(rl.GetSplitKeys(nil, []byte("foo"))) } -func (s *testRangeListSuite) TestRangeList2(c *check.C) { +func TestRangeList2(t *testing.T) { + re := require.New(t) b := NewBuilder() b.SetCompareFunc(func(a, b interface{}) int { if a.(int) > b.(int) { @@ -88,11 +85,11 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { } rl := b.Build() - c.Assert(rl.Len(), check.Equals, len(expectKeys)) + re.Equal(len(expectKeys), rl.Len()) for i := 0; i < rl.Len(); i++ { key, data := rl.Get(i) - c.Assert(key, check.DeepEquals, expectKeys[i]) - c.Assert(data, check.DeepEquals, expectData[i]) + re.Equal(expectKeys[i], key) + re.Equal(expectData[i], data) } getDataByKeyCases := []struct { @@ -103,8 +100,8 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { } for _, tc := range getDataByKeyCases { i, data := rl.GetDataByKey([]byte(tc.key)) - c.Assert(i, check.Equals, tc.pos) - c.Assert(data, check.DeepEquals, expectData[i]) + re.Equal(tc.pos, i) + re.Equal(expectData[i], data) } getDataCases := []struct { @@ -116,9 +113,9 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { } for _, tc := range getDataCases { i, data := rl.GetData([]byte(tc.start), []byte(tc.end)) - c.Assert(i, check.Equals, tc.pos) + re.Equal(tc.pos, i) if i >= 0 { - c.Assert(data, check.DeepEquals, expectData[i]) + re.Equal(expectData[i], data) } } @@ -131,6 +128,6 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { {"cc", "fx", 4, 7}, } for _, tc := range getSplitKeysCases { - c.Assert(rl.GetSplitKeys([]byte(tc.start), []byte(tc.end)), check.DeepEquals, expectKeys[tc.indexStart:tc.indexEnd]) + re.Equal(expectKeys[tc.indexStart:tc.indexEnd], rl.GetSplitKeys([]byte(tc.start), []byte(tc.end))) } } From 3de8d68cf65438db501276a8b2b35161704c4095 Mon Sep 17 00:00:00 2001 From: disksing Date: Mon, 6 Jun 2022 16:24:29 +0800 Subject: [PATCH 17/82] replication: migrate tests to testify (#5100) close tikv/pd#5099 
Signed-off-by: disksing Co-authored-by: Ti Chi Robot --- server/replication/replication_mode_test.go | 299 ++++++++++---------- 1 file changed, 151 insertions(+), 148 deletions(-) diff --git a/server/replication/replication_mode_test.go b/server/replication/replication_mode_test.go index 1f4afb01ca9..8162da599ff 100644 --- a/server/replication/replication_mode_test.go +++ b/server/replication/replication_mode_test.go @@ -21,9 +21,9 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" pb "github.com/pingcap/kvproto/pkg/replication_modepb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/config" @@ -31,32 +31,16 @@ import ( "github.com/tikv/pd/server/storage" ) -func TestReplicationMode(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testReplicationMode{}) - -type testReplicationMode struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testReplicationMode) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testReplicationMode) TearDownTest(c *C) { - s.cancel() -} - -func (s *testReplicationMode) TestInitial(c *C) { +func TestInitial(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeMajority} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{Mode: pb.ReplicationMode_MAJORITY}) + re.NoError(err) + re.Equal(&pb.ReplicationStatus{Mode: pb.ReplicationMode_MAJORITY}, rep.GetReplicationStatus()) conf = config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "dr-label", @@ -68,8 +52,8 @@ func (s *testReplicationMode) TestInitial(c *C) { WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} rep, err = NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -77,19 +61,22 @@ func (s *testReplicationMode) TestInitial(c *C) { StateId: 1, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) } -func (s *testReplicationMode) TestStatus(c *C) { +func TestStatus(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "dr-label", WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, 
DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -97,11 +84,11 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: 1, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) err = rep.drSwitchToAsync(nil) - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -109,12 +96,12 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: 2, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) err = rep.drSwitchToSyncRecover() - c.Assert(err, IsNil) + re.NoError(err) stateID := rep.drAutoSync.StateID - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -122,16 +109,16 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: stateID, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) // test reload rep, err = NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) - c.Assert(rep.drAutoSync.State, Equals, drStateSyncRecover) + re.NoError(err) + re.Equal(drStateSyncRecover, rep.drAutoSync.State) err = rep.drSwitchToSync() - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -139,7 +126,7 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: rep.drAutoSync.StateID, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) } type mockFileReplicator struct { @@ -172,7 +159,10 @@ func newMockReplicator(ids []uint64) *mockFileReplicator { } } -func (s *testReplicationMode) TestStateSwitch(c *C) { +func TestStateSwitch(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", @@ -183,10 +173,10 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) replicator := newMockReplicator([]uint64{1}) rep, err := NewReplicationModeManager(conf, store, cluster, replicator) - c.Assert(err, IsNil) + re.NoError(err) cluster.AddLabelsStore(1, 1, map[string]string{"zone": "zone1"}) cluster.AddLabelsStore(2, 1, map[string]string{"zone": "zone1"}) @@ -194,12 +184,12 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { cluster.AddLabelsStore(4, 1, map[string]string{"zone": "zone1"}) // initial state is sync - c.Assert(rep.drGetState(), Equals, drStateSync) + re.Equal(drStateSync, rep.drGetState()) stateID := rep.drAutoSync.StateID - c.Assert(stateID, Not(Equals), uint64(0)) - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) + re.NotEqual(uint64(0), stateID) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[1]) assertStateIDUpdate := func() { - c.Assert(rep.drAutoSync.StateID, Not(Equals), stateID) + re.NotEqual(stateID, 
rep.drAutoSync.StateID) stateID = rep.drAutoSync.StateID } syncStoreStatus := func(storeIDs ...uint64) { @@ -211,124 +201,124 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { // only one zone, sync -> async_wait -> async rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) - c.Assert(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit(), IsFalse) + re.False(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit()) conf.DRAutoSync.PauseRegionSplit = true rep.UpdateConfig(conf) - c.Assert(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit(), IsTrue) + re.True(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit()) syncStoreStatus(1, 2, 3, 4) rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) // add new store in dr zone. cluster.AddLabelsStore(5, 1, map[string]string{"zone": "zone2"}) cluster.AddLabelsStore(6, 1, map[string]string{"zone": "zone2"}) // async -> sync rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) rep.drSwitchToSync() - c.Assert(rep.drGetState(), Equals, drStateSync) + re.Equal(drStateSync, rep.drGetState()) assertStateIDUpdate() // sync -> async_wait rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) - s.setStoreState(cluster, "down", "up", "up", "up", "up", "up") + re.Equal(drStateSync, rep.drGetState()) + setStoreState(cluster, "down", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) - s.setStoreState(cluster, "down", "down", "up", "up", "up", "up") - s.setStoreState(cluster, "down", "down", "down", "up", "up", "up") + re.Equal(drStateSync, rep.drGetState()) + setStoreState(cluster, "down", "down", "up", "up", "up", "up") + setStoreState(cluster, "down", "down", "down", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) // cannot guarantee majority, keep sync. + re.Equal(drStateSync, rep.drGetState()) // cannot guarantee majority, keep sync. 
- s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() rep.drSwitchToSync() replicator.errors[2] = errors.New("fail to replicate") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() delete(replicator.errors, 1) // async_wait -> sync - s.setStoreState(cluster, "up", "up", "up", "up", "up", "up") + setStoreState(cluster, "up", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) - c.Assert(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit(), IsFalse) + re.Equal(drStateSync, rep.drGetState()) + re.False(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit()) // async_wait -> async_wait - s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) - s.setStoreState(cluster, "down", "up", "up", "up", "down", "up") + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) + setStoreState(cluster, "down", "up", "up", "up", "down", "up") rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[2,3,4]}`, stateID)) - s.setStoreState(cluster, "up", "down", "up", "up", "down", "up") + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[2,3,4]}`, stateID), replicator.lastData[1]) + setStoreState(cluster, "up", "down", "up", "up", "down", "up") rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) // async_wait -> async rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) syncStoreStatus(1, 3) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) syncStoreStatus(4) rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) // async -> async - s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() // store 2 won't be available before it syncs status. 
- c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) syncStoreStatus(1, 2, 3, 4) rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) // async -> sync_recover - s.setStoreState(cluster, "up", "up", "up", "up", "up", "up") + setStoreState(cluster, "up", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) assertStateIDUpdate() rep.drSwitchToAsync([]uint64{1, 2, 3, 4, 5}) - s.setStoreState(cluster, "down", "up", "up", "up", "up", "up") + setStoreState(cluster, "down", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) assertStateIDUpdate() // sync_recover -> async rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) - s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + re.Equal(drStateSyncRecover, rep.drGetState()) + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsync) + re.Equal(drStateAsync, rep.drGetState()) assertStateIDUpdate() // lost majority, does not switch to async. rep.drSwitchToSyncRecover() assertStateIDUpdate() - s.setStoreState(cluster, "down", "down", "up", "up", "down", "up") + setStoreState(cluster, "down", "down", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) // sync_recover -> sync rep.drSwitchToSyncRecover() assertStateIDUpdate() - s.setStoreState(cluster, "up", "up", "up", "up", "up", "up") + setStoreState(cluster, "up", "up", "up", "up", "up", "up") cluster.AddLeaderRegion(1, 1, 2, 3, 4, 5) region := cluster.GetRegion(1) @@ -337,7 +327,7 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { })) cluster.PutRegion(region) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) region = region.Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_INTEGRITY_OVER_LABEL, @@ -345,18 +335,21 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { })) cluster.PutRegion(region) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) region = region.Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_INTEGRITY_OVER_LABEL, StateId: rep.drAutoSync.StateID, })) cluster.PutRegion(region) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) + re.Equal(drStateSync, rep.drGetState()) assertStateIDUpdate() } -func (s *testReplicationMode) TestReplicateState(c *C) { +func TestReplicateState(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", @@ -367,36 +360,39 @@ func (s *testReplicationMode) TestReplicateState(c *C) { 
WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) replicator := newMockReplicator([]uint64{1}) rep, err := NewReplicationModeManager(conf, store, cluster, replicator) - c.Assert(err, IsNil) + re.NoError(err) stateID := rep.drAutoSync.StateID // replicate after initialized - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[1]) // repliate state to new member replicator.memberIDs = append(replicator.memberIDs, 2, 3) rep.checkReplicateFile() - c.Assert(replicator.lastData[2], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) - c.Assert(replicator.lastData[3], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[2]) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[3]) // inject error replicator.errors[2] = errors.New("failed to persist") rep.tickDR() // switch async_wait since there is only one zone newStateID := rep.drAutoSync.StateID - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID)) - c.Assert(replicator.lastData[2], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) - c.Assert(replicator.lastData[3], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID), replicator.lastData[1]) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[2]) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID), replicator.lastData[3]) // clear error, replicate to node 2 next time delete(replicator.errors, 2) rep.checkReplicateFile() - c.Assert(replicator.lastData[2], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID), replicator.lastData[2]) } -func (s *testReplicationMode) TestAsynctimeout(c *C) { +func TestAsynctimeout(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", @@ -408,34 +404,34 @@ func (s *testReplicationMode) TestAsynctimeout(c *C) { WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, WaitAsyncTimeout: typeutil.Duration{Duration: 2 * time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) var replicator mockFileReplicator rep, err := NewReplicationModeManager(conf, store, cluster, &replicator) - c.Assert(err, IsNil) + re.NoError(err) cluster.AddLabelsStore(1, 1, map[string]string{"zone": "zone1"}) cluster.AddLabelsStore(2, 1, map[string]string{"zone": "zone1"}) cluster.AddLabelsStore(3, 1, map[string]string{"zone": "zone2"}) - s.setStoreState(cluster, "up", "up", "down") + setStoreState(cluster, "up", "up", "down") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) // cannot switch state due to recently start + re.Equal(drStateSync, rep.drGetState()) // 
cannot switch state due to recently start rep.initTime = time.Now().Add(-3 * time.Minute) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) rep.drSwitchToSync() rep.UpdateMemberWaitAsyncTime(42) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) // cannot switch state due to member not timeout + re.Equal(drStateSync, rep.drGetState()) // cannot switch state due to member not timeout rep.drMemberWaitAsyncTime[42] = time.Now().Add(-3 * time.Minute) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) } -func (s *testReplicationMode) setStoreState(cluster *mockcluster.Cluster, states ...string) { +func setStoreState(cluster *mockcluster.Cluster, states ...string) { for i, state := range states { store := cluster.GetStore(uint64(i + 1)) if state == "down" { @@ -447,7 +443,11 @@ func (s *testReplicationMode) setStoreState(cluster *mockcluster.Cluster, states } } -func (s *testReplicationMode) TestRecoverProgress(c *C) { +func TestRecoverProgress(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + regionScanBatchSize = 10 regionMinSampleSize = 5 @@ -461,14 +461,14 @@ func (s *testReplicationMode) TestRecoverProgress(c *C) { WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddLabelsStore(1, 1, map[string]string{}) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) + re.NoError(err) prepare := func(n int, asyncRegions []int) { rep.drSwitchToSyncRecover() - regions := s.genRegions(cluster, rep.drAutoSync.StateID, n) + regions := genRegions(cluster, rep.drAutoSync.StateID, n) for _, i := range asyncRegions { regions[i] = regions[i].Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_SIMPLE_MAJORITY, @@ -482,32 +482,35 @@ func (s *testReplicationMode) TestRecoverProgress(c *C) { } prepare(20, nil) - c.Assert(rep.drRecoverCount, Equals, 20) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(20, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) prepare(10, []int{9}) - c.Assert(rep.drRecoverCount, Equals, 9) - c.Assert(rep.drTotalRegion, Equals, 10) - c.Assert(rep.drSampleTotalRegion, Equals, 1) - c.Assert(rep.drSampleRecoverCount, Equals, 0) - c.Assert(rep.estimateProgress(), Equals, float32(9)/float32(10)) + re.Equal(9, rep.drRecoverCount) + re.Equal(10, rep.drTotalRegion) + re.Equal(1, rep.drSampleTotalRegion) + re.Equal(0, rep.drSampleRecoverCount) + re.Equal(float32(9)/float32(10), rep.estimateProgress()) prepare(30, []int{3, 4, 5, 6, 7, 8, 9}) - c.Assert(rep.drRecoverCount, Equals, 3) - c.Assert(rep.drTotalRegion, Equals, 30) - c.Assert(rep.drSampleTotalRegion, Equals, 7) - c.Assert(rep.drSampleRecoverCount, Equals, 0) - c.Assert(rep.estimateProgress(), Equals, float32(3)/float32(30)) + re.Equal(3, rep.drRecoverCount) + re.Equal(30, rep.drTotalRegion) + re.Equal(7, rep.drSampleTotalRegion) + re.Equal(0, rep.drSampleRecoverCount) + re.Equal(float32(3)/float32(30), rep.estimateProgress()) prepare(30, []int{9, 13, 14}) - c.Assert(rep.drRecoverCount, Equals, 9) - c.Assert(rep.drTotalRegion, Equals, 30) - c.Assert(rep.drSampleTotalRegion, Equals, 6) // 
9 + 10,11,12,13,14 - c.Assert(rep.drSampleRecoverCount, Equals, 3) - c.Assert(rep.estimateProgress(), Equals, (float32(9)+float32(30-9)/2)/float32(30)) + re.Equal(9, rep.drRecoverCount) + re.Equal(30, rep.drTotalRegion) + re.Equal(6, rep.drSampleTotalRegion) // 9 + 10,11,12,13,14 + re.Equal(3, rep.drSampleRecoverCount) + re.Equal((float32(9)+float32(30-9)/2)/float32(30), rep.estimateProgress()) } -func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { +func TestRecoverProgressWithSplitAndMerge(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() regionScanBatchSize = 10 regionMinSampleSize = 5 @@ -521,14 +524,14 @@ func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddLabelsStore(1, 1, map[string]string{}) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) + re.NoError(err) prepare := func(n int, asyncRegions []int) { rep.drSwitchToSyncRecover() - regions := s.genRegions(cluster, rep.drAutoSync.StateID, n) + regions := genRegions(cluster, rep.drAutoSync.StateID, n) for _, i := range asyncRegions { regions[i] = regions[i].Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_SIMPLE_MAJORITY, @@ -545,8 +548,8 @@ func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { r := cluster.GetRegion(1).Clone(core.WithEndKey(cluster.GetRegion(2).GetEndKey())) cluster.PutRegion(r) rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 19) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(19, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) // merged happened during the scan prepare(20, nil) @@ -557,23 +560,23 @@ func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { rep.drRecoverCount = 1 rep.drRecoverKey = r1.GetEndKey() rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 20) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(20, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) // split, region gap happened during the scan rep.drRecoverCount, rep.drRecoverKey = 0, nil cluster.PutRegion(r1) rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 1) - c.Assert(rep.estimateProgress(), Not(Equals), float32(1.0)) + re.Equal(1, rep.drRecoverCount) + re.NotEqual(float32(1.0), rep.estimateProgress()) // region gap missing cluster.PutRegion(r2) rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 20) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(20, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) } -func (s *testReplicationMode) genRegions(cluster *mockcluster.Cluster, stateID uint64, n int) []*core.RegionInfo { +func genRegions(cluster *mockcluster.Cluster, stateID uint64, n int) []*core.RegionInfo { var regions []*core.RegionInfo for i := 1; i <= n; i++ { cluster.AddLeaderRegion(uint64(i), 1) From 8359788ae8d7c33baa76fdf4db85cde85603920f Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 6 Jun 2022 19:02:29 +0800 Subject: [PATCH 18/82] *: use require.Equal to replace reflect.DeepEqual (#5109) close tikv/pd#5104 Use require.Equal to replace reflect.DeepEqual. 
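
For illustration, a minimal sketch of the assertion pattern this commit applies across the tests; `expected` and `actual` are placeholder names, not identifiers taken from the diffs below:

    // Before: assert the boolean result of a manual deep comparison.
    re.True(reflect.DeepEqual(expected, actual))
    // After: let testify do the deep comparison and print a readable diff on failure.
    re.Equal(expected, actual)
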
Signed-off-by: JmPotato --- client/client_test.go | 7 ++-- pkg/autoscaling/calculation_test.go | 17 +++++---- pkg/cache/cache_test.go | 55 ++++++++++++++--------------- pkg/logutil/log_test.go | 7 ++-- pkg/typeutil/conversion_test.go | 2 +- pkg/typeutil/string_slice_test.go | 5 ++- tests/client/client_test.go | 20 +++++------ 7 files changed, 54 insertions(+), 59 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index c80b78bb96b..93d9e5d8de9 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,7 +16,6 @@ package pd import ( "context" - "reflect" "testing" "time" @@ -58,11 +57,11 @@ func TestUpdateURLs(t *testing.T) { cli := &baseClient{option: newOption()} cli.urls.Store([]string{}) cli.updateURLs(members[1:]) - re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) + re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs()) cli.updateURLs(members[1:]) - re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) + re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs()) cli.updateURLs(members) - re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]}), cli.GetURLs())) + re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]}), cli.GetURLs()) } const testClientURL = "tmp://test.url:5255" diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index f5ac3313ba4..5f9eaaf9767 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -19,7 +19,6 @@ import ( "encoding/json" "fmt" "math" - "reflect" "testing" "time" @@ -70,7 +69,7 @@ func TestGetScaledTiKVGroups(t *testing.T) { informer core.StoreSetInformer healthyInstances []instance expectedPlan []*Plan - noError bool + errorChecker func(err error, msgAndArgs ...interface{}) }{ { name: "no scaled tikv group", @@ -90,7 +89,7 @@ func TestGetScaledTiKVGroups(t *testing.T) { }, }, expectedPlan: nil, - noError: true, + errorChecker: re.NoError, }, { name: "exist 1 scaled tikv group", @@ -120,7 +119,7 @@ func TestGetScaledTiKVGroups(t *testing.T) { }, }, }, - noError: true, + errorChecker: re.NoError, }, { name: "exist 1 tikv scaled group with inconsistency healthy instances", @@ -140,7 +139,7 @@ func TestGetScaledTiKVGroups(t *testing.T) { }, }, expectedPlan: nil, - noError: false, + errorChecker: re.Error, }, { name: "exist 1 tikv scaled group with less healthy instances", @@ -166,7 +165,7 @@ func TestGetScaledTiKVGroups(t *testing.T) { }, }, }, - noError: true, + errorChecker: re.NoError, }, { name: "existed other tikv group", @@ -186,7 +185,7 @@ func TestGetScaledTiKVGroups(t *testing.T) { }, }, expectedPlan: nil, - noError: true, + errorChecker: re.NoError, }, } @@ -195,9 +194,9 @@ func TestGetScaledTiKVGroups(t *testing.T) { plans, err := getScaledTiKVGroups(testCase.informer, testCase.healthyInstances) if testCase.expectedPlan == nil { re.Len(plans, 0) - re.Equal(testCase.noError, err == nil) + testCase.errorChecker(err) } else { - re.True(reflect.DeepEqual(testCase.expectedPlan, plans)) + re.Equal(testCase.expectedPlan, plans) } } } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index bf1626450f7..ef257157006 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -16,7 +16,6 @@ package cache import ( "context" - "reflect" "sort" "testing" "time" @@ -76,7 +75,7 @@ func TestExpireRegionCache(t 
*testing.T) { re.Equal(3, cache.Len()) - re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{1, 2, 3})) + re.Equal(sortIDs(cache.GetAllID()), []uint64{1, 2, 3}) time.Sleep(2 * time.Second) @@ -93,7 +92,7 @@ func TestExpireRegionCache(t *testing.T) { re.Equal(3.0, value) re.Equal(2, cache.Len()) - re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{2, 3})) + re.Equal(sortIDs(cache.GetAllID()), []uint64{2, 3}) cache.Remove(2) @@ -106,7 +105,7 @@ func TestExpireRegionCache(t *testing.T) { re.Equal(3.0, value) re.Equal(1, cache.Len()) - re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{3})) + re.Equal(sortIDs(cache.GetAllID()), []uint64{3}) } func sortIDs(ids []uint64) []uint64 { @@ -125,15 +124,15 @@ func TestLRUCache(t *testing.T) { val, ok := cache.Get(3) re.True(ok) - re.True(reflect.DeepEqual(val, "3")) + re.Equal(val, "3") val, ok = cache.Get(2) re.True(ok) - re.True(reflect.DeepEqual(val, "2")) + re.Equal(val, "2") val, ok = cache.Get(1) re.True(ok) - re.True(reflect.DeepEqual(val, "1")) + re.Equal(val, "1") re.Equal(3, cache.Len()) @@ -147,27 +146,27 @@ func TestLRUCache(t *testing.T) { val, ok = cache.Get(1) re.True(ok) - re.True(reflect.DeepEqual(val, "1")) + re.Equal(val, "1") val, ok = cache.Get(2) re.True(ok) - re.True(reflect.DeepEqual(val, "2")) + re.Equal(val, "2") val, ok = cache.Get(4) re.True(ok) - re.True(reflect.DeepEqual(val, "4")) + re.Equal(val, "4") re.Equal(3, cache.Len()) val, ok = cache.Peek(1) re.True(ok) - re.True(reflect.DeepEqual(val, "1")) + re.Equal(val, "1") elems := cache.Elems() re.Len(elems, 3) - re.True(reflect.DeepEqual(elems[0].Value, "4")) - re.True(reflect.DeepEqual(elems[1].Value, "2")) - re.True(reflect.DeepEqual(elems[2].Value, "1")) + re.Equal(elems[0].Value, "4") + re.Equal(elems[1].Value, "2") + re.Equal(elems[2].Value, "1") cache.Remove(1) cache.Remove(2) @@ -205,13 +204,13 @@ func TestFifoCache(t *testing.T) { elems := cache.Elems() re.Len(elems, 3) - re.True(reflect.DeepEqual(elems[0].Value, "2")) - re.True(reflect.DeepEqual(elems[1].Value, "3")) - re.True(reflect.DeepEqual(elems[2].Value, "4")) + re.Equal(elems[0].Value, "2") + re.Equal(elems[1].Value, "3") + re.Equal(elems[2].Value, "4") elems = cache.FromElems(3) re.Len(elems, 1) - re.True(reflect.DeepEqual(elems[0].Value, "4")) + re.Equal(elems[0].Value, "4") cache.Remove() cache.Remove() @@ -228,15 +227,15 @@ func TestTwoQueueCache(t *testing.T) { val, ok := cache.Get(3) re.True(ok) - re.True(reflect.DeepEqual(val, "3")) + re.Equal(val, "3") val, ok = cache.Get(2) re.True(ok) - re.True(reflect.DeepEqual(val, "2")) + re.Equal(val, "2") val, ok = cache.Get(1) re.True(ok) - re.True(reflect.DeepEqual(val, "1")) + re.Equal(val, "1") re.Equal(3, cache.Len()) @@ -250,27 +249,27 @@ func TestTwoQueueCache(t *testing.T) { val, ok = cache.Get(1) re.True(ok) - re.True(reflect.DeepEqual(val, "1")) + re.Equal(val, "1") val, ok = cache.Get(2) re.True(ok) - re.True(reflect.DeepEqual(val, "2")) + re.Equal(val, "2") val, ok = cache.Get(4) re.True(ok) - re.True(reflect.DeepEqual(val, "4")) + re.Equal(val, "4") re.Equal(3, cache.Len()) val, ok = cache.Peek(1) re.True(ok) - re.True(reflect.DeepEqual(val, "1")) + re.Equal(val, "1") elems := cache.Elems() re.Len(elems, 3) - re.True(reflect.DeepEqual(elems[0].Value, "4")) - re.True(reflect.DeepEqual(elems[1].Value, "2")) - re.True(reflect.DeepEqual(elems[2].Value, "1")) + re.Equal(elems[0].Value, "4") + re.Equal(elems[1].Value, "2") + re.Equal(elems[2].Value, "1") cache.Remove(1) cache.Remove(2) diff --git 
a/pkg/logutil/log_test.go b/pkg/logutil/log_test.go index 270a8e5b0ba..81913905704 100644 --- a/pkg/logutil/log_test.go +++ b/pkg/logutil/log_test.go @@ -16,7 +16,6 @@ package logutil import ( "fmt" - "reflect" "testing" "github.com/stretchr/testify/require" @@ -73,11 +72,11 @@ func TestRedactLog(t *testing.T) { SetRedactLog(testCase.enableRedactLog) switch r := testCase.arg.(type) { case []byte: - re.True(reflect.DeepEqual(testCase.expect, RedactBytes(r))) + re.Equal(testCase.expect, RedactBytes(r)) case string: - re.True(reflect.DeepEqual(testCase.expect, RedactString(r))) + re.Equal(testCase.expect, RedactString(r)) case fmt.Stringer: - re.True(reflect.DeepEqual(testCase.expect, RedactStringer(r))) + re.Equal(testCase.expect, RedactStringer(r)) default: panic("unmatched case") } diff --git a/pkg/typeutil/conversion_test.go b/pkg/typeutil/conversion_test.go index 4d28fa152f3..3398a1ee618 100644 --- a/pkg/typeutil/conversion_test.go +++ b/pkg/typeutil/conversion_test.go @@ -35,7 +35,7 @@ func TestUint64ToBytes(t *testing.T) { var a uint64 = 1000 b := Uint64ToBytes(a) str := "\x00\x00\x00\x00\x00\x00\x03\xe8" - re.True(reflect.DeepEqual([]byte(str), b)) + re.Equal([]byte(str), b) } func TestJSONToUint64Slice(t *testing.T) { diff --git a/pkg/typeutil/string_slice_test.go b/pkg/typeutil/string_slice_test.go index f50ddb9218d..9177cee0eb9 100644 --- a/pkg/typeutil/string_slice_test.go +++ b/pkg/typeutil/string_slice_test.go @@ -16,7 +16,6 @@ package typeutil import ( "encoding/json" - "reflect" "testing" "github.com/stretchr/testify/require" @@ -32,7 +31,7 @@ func TestStringSliceJSON(t *testing.T) { var nb StringSlice err = json.Unmarshal(o, &nb) re.NoError(err) - re.True(reflect.DeepEqual(b, nb)) + re.Equal(b, nb) } func TestEmpty(t *testing.T) { @@ -44,5 +43,5 @@ func TestEmpty(t *testing.T) { var ss2 StringSlice re.NoError(ss2.UnmarshalJSON(b)) - re.True(reflect.DeepEqual(ss, ss2)) + re.Equal(ss, ss2) } diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 3afda979c44..9d24e5439fd 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -113,7 +113,7 @@ func TestClientLeaderChange(t *testing.T) { urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - re.True(reflect.DeepEqual(endpoints, urls)) + re.Equal(endpoints, urls) } func TestLeaderTransfer(t *testing.T) { @@ -264,7 +264,7 @@ func TestTSOAllocatorLeader(t *testing.T) { urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - re.True(reflect.DeepEqual(endpoints, urls)) + re.Equal(endpoints, urls) continue } pdName, exist := allocatorLeaderMap[dcLocation] @@ -974,20 +974,20 @@ func (suite *clientTestSuite) TestScanRegions() { t.Log("scanRegions", scanRegions) t.Log("expect", expect) for i := range expect { - suite.True(reflect.DeepEqual(expect[i], scanRegions[i].Meta)) + suite.Equal(expect[i], scanRegions[i].Meta) if scanRegions[i].Meta.GetId() == region3.GetID() { - suite.True(reflect.DeepEqual(&metapb.Peer{}, scanRegions[i].Leader)) + suite.Equal(&metapb.Peer{}, scanRegions[i].Leader) } else { - suite.True(reflect.DeepEqual(expect[i].Peers[0], scanRegions[i].Leader)) + suite.Equal(expect[i].Peers[0], scanRegions[i].Leader) } if scanRegions[i].Meta.GetId() == region4.GetID() { - suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1]}, scanRegions[i].DownPeers)) + suite.Equal([]*metapb.Peer{expect[i].Peers[1]}, scanRegions[i].DownPeers) } if scanRegions[i].Meta.GetId() == region5.GetID() { - 
suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]}, scanRegions[i].PendingPeers)) + suite.Equal([]*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]}, scanRegions[i].PendingPeers) } } } @@ -1036,7 +1036,7 @@ func (suite *clientTestSuite) TestGetStore() { // Get an up store should be OK. n, err := suite.client.GetStore(context.Background(), store.GetId()) suite.NoError(err) - suite.True(reflect.DeepEqual(store, n)) + suite.Equal(store, n) actualStores, err := suite.client.GetAllStores(context.Background()) suite.NoError(err) @@ -1053,7 +1053,7 @@ func (suite *clientTestSuite) TestGetStore() { // Get an offline store should be OK. n, err = suite.client.GetStore(context.Background(), store.GetId()) suite.NoError(err) - suite.True(reflect.DeepEqual(offlineStore, n)) + suite.Equal(offlineStore, n) // Should return offline stores. contains := false @@ -1062,7 +1062,7 @@ func (suite *clientTestSuite) TestGetStore() { for _, store := range stores { if store.GetId() == offlineStore.GetId() { contains = true - suite.True(reflect.DeepEqual(offlineStore, store)) + suite.Equal(offlineStore, store) } } suite.True(contains) From d02a98ac784c412f885fbab741e344d803b6db4d Mon Sep 17 00:00:00 2001 From: disksing Date: Mon, 6 Jun 2022 22:06:29 +0800 Subject: [PATCH 19/82] placement: migrate test framework to testify (#5101) close tikv/pd#5097 Signed-off-by: disksing Co-authored-by: Ti Chi Robot --- server/schedule/placement/config_test.go | 16 +- server/schedule/placement/fit_test.go | 53 ++- .../placement/label_constraint_test.go | 21 +- .../placement/region_rule_cache_test.go | 20 +- .../schedule/placement/rule_manager_test.go | 314 +++++++++--------- server/schedule/placement/rule_test.go | 37 +-- 6 files changed, 233 insertions(+), 228 deletions(-) diff --git a/server/schedule/placement/config_test.go b/server/schedule/placement/config_test.go index 0e51aed73d3..eeb9e7d3b0e 100644 --- a/server/schedule/placement/config_test.go +++ b/server/schedule/placement/config_test.go @@ -15,15 +15,13 @@ package placement import ( - . "github.com/pingcap/check" -) - -var _ = Suite(&testConfigSuite{}) + "testing" -type testConfigSuite struct { -} + "github.com/stretchr/testify/require" +) -func (s *testConfigSuite) TestTrim(c *C) { +func TestTrim(t *testing.T) { + re := require.New(t) rc := newRuleConfig() rc.setRule(&Rule{GroupID: "g1", ID: "id1"}) rc.setRule(&Rule{GroupID: "g1", ID: "id2"}) @@ -76,7 +74,7 @@ func (s *testConfigSuite) TestTrim(c *C) { p := rc.beginPatch() tc.ops(p) p.trim() - c.Assert(p.mut.rules, DeepEquals, tc.mutRules) - c.Assert(p.mut.groups, DeepEquals, tc.mutGroups) + re.Equal(tc.mutRules, p.mut.rules) + re.Equal(tc.mutGroups, p.mut.groups) } } diff --git a/server/schedule/placement/fit_test.go b/server/schedule/placement/fit_test.go index e3804de33ef..0a070c38f65 100644 --- a/server/schedule/placement/fit_test.go +++ b/server/schedule/placement/fit_test.go @@ -18,17 +18,15 @@ import ( "fmt" "strconv" "strings" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testFitSuite{}) - -type testFitSuite struct{} - -func (s *testFitSuite) makeStores() StoreSet { +func makeStores() StoreSet { stores := core.NewStoresInfo() for zone := 1; zone <= 5; zone++ { for rack := 1; rack <= 5; rack++ { @@ -50,7 +48,7 @@ func (s *testFitSuite) makeStores() StoreSet { } // example: "1111_leader,1234,2111_learner" -func (s *testFitSuite) makeRegion(def string) *core.RegionInfo { +func makeRegion(def string) *core.RegionInfo { var regionMeta metapb.Region var leader *metapb.Peer for _, peerDef := range strings.Split(def, ",") { @@ -71,7 +69,7 @@ func (s *testFitSuite) makeRegion(def string) *core.RegionInfo { // example: "3/voter/zone=zone1+zone2,rack=rack2/zone,rack,host" // count role constraints location_labels -func (s *testFitSuite) makeRule(def string) *Rule { +func makeRule(def string) *Rule { var rule Rule splits := strings.Split(def, "/") rule.Count, _ = strconv.Atoi(splits[0]) @@ -92,7 +90,7 @@ func (s *testFitSuite) makeRule(def string) *Rule { return &rule } -func (s *testFitSuite) checkPeerMatch(peers []*metapb.Peer, expect string) bool { +func checkPeerMatch(peers []*metapb.Peer, expect string) bool { if len(peers) == 0 && expect == "" { return true } @@ -111,8 +109,9 @@ func (s *testFitSuite) checkPeerMatch(peers []*metapb.Peer, expect string) bool return len(m) == 0 } -func (s *testFitSuite) TestFitRegion(c *C) { - stores := s.makeStores() +func TestFitRegion(t *testing.T) { + re := require.New(t) + stores := makeStores() cases := []struct { region string @@ -140,34 +139,34 @@ func (s *testFitSuite) TestFitRegion(c *C) { } for _, cc := range cases { - region := s.makeRegion(cc.region) + region := makeRegion(cc.region) var rules []*Rule for _, r := range cc.rules { - rules = append(rules, s.makeRule(r)) + rules = append(rules, makeRule(r)) } rf := fitRegion(stores.GetStores(), region, rules) expects := strings.Split(cc.fitPeers, "/") for i, f := range rf.RuleFits { - c.Assert(s.checkPeerMatch(f.Peers, expects[i]), IsTrue) + re.True(checkPeerMatch(f.Peers, expects[i])) } if len(rf.RuleFits) < len(expects) { - c.Assert(s.checkPeerMatch(rf.OrphanPeers, expects[len(rf.RuleFits)]), IsTrue) + re.True(checkPeerMatch(rf.OrphanPeers, expects[len(rf.RuleFits)])) } } } - -func (s *testFitSuite) TestIsolationScore(c *C) { - stores := s.makeStores() +func TestIsolationScore(t *testing.T) { + as := assert.New(t) + stores := makeStores() testCases := []struct { - peers1 []uint64 - Checker - peers2 []uint64 + checker func(interface{}, interface{}, ...interface{}) bool + peers1 []uint64 + peers2 []uint64 }{ - {[]uint64{1111, 1112}, Less, []uint64{1111, 1121}}, - {[]uint64{1111, 1211}, Less, []uint64{1111, 2111}}, - {[]uint64{1111, 1211, 1311, 2111, 3111}, Less, []uint64{1111, 1211, 2111, 2211, 3111}}, - {[]uint64{1111, 1211, 2111, 2211, 3111}, Equals, []uint64{1111, 2111, 2211, 3111, 3211}}, - {[]uint64{1111, 1211, 2111, 2211, 3111}, Greater, []uint64{1111, 1121, 2111, 2211, 3111}}, + {as.Less, []uint64{1111, 1112}, []uint64{1111, 1121}}, + {as.Less, []uint64{1111, 1211}, []uint64{1111, 2111}}, + {as.Less, []uint64{1111, 1211, 1311, 2111, 3111}, []uint64{1111, 1211, 2111, 2211, 3111}}, + {as.Equal, []uint64{1111, 1211, 2111, 2211, 3111}, []uint64{1111, 2111, 2211, 3111, 3211}}, + {as.Greater, []uint64{1111, 1211, 2111, 2211, 3111}, []uint64{1111, 1121, 2111, 2211, 3111}}, } 
makePeers := func(ids []uint64) []*fitPeer { @@ -185,6 +184,6 @@ func (s *testFitSuite) TestIsolationScore(c *C) { peers1, peers2 := makePeers(tc.peers1), makePeers(tc.peers2) score1 := isolationScore(peers1, []string{"zone", "rack", "host"}) score2 := isolationScore(peers2, []string{"zone", "rack", "host"}) - c.Assert(score1, tc.Checker, score2) + tc.checker(score1, score2) } } diff --git a/server/schedule/placement/label_constraint_test.go b/server/schedule/placement/label_constraint_test.go index 705ea639c8d..dd6feebe94f 100644 --- a/server/schedule/placement/label_constraint_test.go +++ b/server/schedule/placement/label_constraint_test.go @@ -17,19 +17,12 @@ package placement import ( "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" ) -func TestPlacement(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testLabelConstraintsSuite{}) - -type testLabelConstraintsSuite struct{} - -func (s *testLabelConstraintsSuite) TestLabelConstraint(c *C) { +func TestLabelConstraint(t *testing.T) { + re := require.New(t) stores := []map[string]string{ {"zone": "zone1", "rack": "rack1"}, // 1 {"zone": "zone1", "rack": "rack2"}, // 2 @@ -61,11 +54,11 @@ func (s *testLabelConstraintsSuite) TestLabelConstraint(c *C) { matched = append(matched, j+1) } } - c.Assert(matched, DeepEquals, expect[i]) + re.Equal(expect[i], matched) } } - -func (s *testLabelConstraintsSuite) TestLabelConstraints(c *C) { +func TestLabelConstraints(t *testing.T) { + re := require.New(t) stores := []map[string]string{ {}, // 1 {"k1": "v1"}, // 2 @@ -100,6 +93,6 @@ func (s *testLabelConstraintsSuite) TestLabelConstraints(c *C) { matched = append(matched, j+1) } } - c.Assert(matched, DeepEquals, expect[i]) + re.Equal(expect[i], matched) } } diff --git a/server/schedule/placement/region_rule_cache_test.go b/server/schedule/placement/region_rule_cache_test.go index 54f32ad26ac..f38d13eba87 100644 --- a/server/schedule/placement/region_rule_cache_test.go +++ b/server/schedule/placement/region_rule_cache_test.go @@ -15,15 +15,17 @@ package placement import ( + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" ) -func (s *testRuleSuite) TestRegionRuleFitCache(c *C) { +func TestRegionRuleFitCache(t *testing.T) { + re := require.New(t) originRegion := mockRegion(3, 0) originRules := addExtraRules(0) originStores := mockStores(3) @@ -174,20 +176,20 @@ func (s *testRuleSuite) TestRegionRuleFitCache(c *C) { }, } for _, testcase := range testcases { - c.Log(testcase.name) - c.Assert(cache.IsUnchanged(testcase.region, testcase.rules, mockStores(3)), Equals, testcase.unchanged) + t.Log(testcase.name) + re.Equal(testcase.unchanged, cache.IsUnchanged(testcase.region, testcase.rules, mockStores(3))) } for _, testcase := range testcases { - c.Log(testcase.name) - c.Assert(cache.IsUnchanged(testcase.region, testcase.rules, mockStoresNoHeartbeat(3)), Equals, false) + t.Log(testcase.name) + re.Equal(false, cache.IsUnchanged(testcase.region, testcase.rules, mockStoresNoHeartbeat(3))) } // Invalid Input4 - c.Assert(cache.IsUnchanged(mockRegion(3, 0), addExtraRules(0), nil), IsFalse) + re.False(cache.IsUnchanged(mockRegion(3, 0), addExtraRules(0), nil)) // Invalid Input5 - c.Assert(cache.IsUnchanged(mockRegion(3, 0), addExtraRules(0), []*core.StoreInfo{}), IsFalse) + re.False(cache.IsUnchanged(mockRegion(3, 0), addExtraRules(0), []*core.StoreInfo{})) // origin rules changed, assert whether cache is changed originRules[0].Version++ - c.Assert(cache.IsUnchanged(originRegion, originRules, originStores), IsFalse) + re.False(cache.IsUnchanged(originRegion, originRules, originStores)) } func mockRegionRuleFitCache(region *core.RegionInfo, rules []*Rule, regionStores []*core.StoreInfo) *RegionRuleFitCache { diff --git a/server/schedule/placement/rule_manager_test.go b/server/schedule/placement/rule_manager_test.go index ae750fe5f9b..0fdfd2f67a8 100644 --- a/server/schedule/placement/rule_manager_test.go +++ b/server/schedule/placement/rule_manager_test.go @@ -16,43 +16,43 @@ package placement import ( "encoding/hex" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/codec" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/storage" "github.com/tikv/pd/server/storage/endpoint" ) -var _ = Suite(&testManagerSuite{}) - -type testManagerSuite struct { - store endpoint.RuleStorage - manager *RuleManager -} - -func (s *testManagerSuite) SetUpTest(c *C) { - s.store = storage.NewStorageWithMemoryBackend() +func newTestManager(t *testing.T) (endpoint.RuleStorage, *RuleManager) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() var err error - s.manager = NewRuleManager(s.store, nil, nil) - err = s.manager.Initialize(3, []string{"zone", "rack", "host"}) - c.Assert(err, IsNil) + manager := NewRuleManager(store, nil, nil) + err = manager.Initialize(3, []string{"zone", "rack", "host"}) + re.NoError(err) + return store, manager } -func (s *testManagerSuite) TestDefault(c *C) { - rules := s.manager.GetAllRules() - c.Assert(rules, HasLen, 1) - c.Assert(rules[0].GroupID, Equals, "pd") - c.Assert(rules[0].ID, Equals, "default") - c.Assert(rules[0].Index, Equals, 0) - c.Assert(rules[0].StartKey, HasLen, 0) - c.Assert(rules[0].EndKey, HasLen, 0) - c.Assert(rules[0].Role, Equals, Voter) - c.Assert(rules[0].LocationLabels, DeepEquals, []string{"zone", "rack", "host"}) +func TestDefault(t *testing.T) { + re := require.New(t) + _, manager := newTestManager(t) + rules := manager.GetAllRules() + re.Len(rules, 1) + re.Equal("pd", rules[0].GroupID) + re.Equal("default", rules[0].ID) + re.Equal(0, rules[0].Index) + re.Len(rules[0].StartKey, 0) + re.Len(rules[0].EndKey, 0) + re.Equal(Voter, rules[0].Role) + re.Equal([]string{"zone", "rack", "host"}, rules[0].LocationLabels) } -func (s *testManagerSuite) TestAdjustRule(c *C) { +func TestAdjustRule(t *testing.T) { + re := require.New(t) + _, manager := newTestManager(t) rules := []Rule{ {GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3}, {GroupID: "", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3}, @@ -65,32 +65,38 @@ func (s *testManagerSuite) TestAdjustRule(c *C) { {GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: -1}, {GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3, LabelConstraints: []LabelConstraint{{Op: "foo"}}}, } - c.Assert(s.manager.adjustRule(&rules[0], "group"), IsNil) - c.Assert(rules[0].StartKey, DeepEquals, []byte{0x12, 0x3a, 0xbc}) - c.Assert(rules[0].EndKey, DeepEquals, []byte{0x12, 0x3a, 0xbf}) - c.Assert(s.manager.adjustRule(&rules[1], ""), NotNil) + re.NoError(manager.adjustRule(&rules[0], "group")) + + re.Equal([]byte{0x12, 0x3a, 0xbc}, rules[0].StartKey) + re.Equal([]byte{0x12, 0x3a, 0xbf}, rules[0].EndKey) + re.Error(manager.adjustRule(&rules[1], "")) + for i := 2; i < len(rules); i++ { - c.Assert(s.manager.adjustRule(&rules[i], "group"), NotNil) + re.Error(manager.adjustRule(&rules[i], "group")) } - s.manager.SetKeyType(core.Table.String()) - c.Assert(s.manager.adjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3}, "group"), NotNil) - s.manager.SetKeyType(core.Txn.String()) - c.Assert(s.manager.adjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3}, "group"), NotNil) - c.Assert(s.manager.adjustRule(&Rule{ + manager.SetKeyType(core.Table.String()) + 
re.Error(manager.adjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3}, "group")) + + manager.SetKeyType(core.Txn.String()) + re.Error(manager.adjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3}, "group")) + + re.Error(manager.adjustRule(&Rule{ GroupID: "group", ID: "id", StartKeyHex: hex.EncodeToString(codec.EncodeBytes([]byte{0})), EndKeyHex: "123abf", Role: "voter", Count: 3, - }, "group"), NotNil) + }, "group")) } -func (s *testManagerSuite) TestLeaderCheck(c *C) { - c.Assert(s.manager.SetRule(&Rule{GroupID: "pd", ID: "default", Role: "learner", Count: 3}), ErrorMatches, ".*needs at least one leader or voter.*") - c.Assert(s.manager.SetRule(&Rule{GroupID: "g2", ID: "33", Role: "leader", Count: 2}), ErrorMatches, ".*define multiple leaders by count 2.*") - c.Assert(s.manager.Batch([]RuleOp{ +func TestLeaderCheck(t *testing.T) { + re := require.New(t) + _, manager := newTestManager(t) + re.Regexp(".*needs at least one leader or voter.*", manager.SetRule(&Rule{GroupID: "pd", ID: "default", Role: "learner", Count: 3}).Error()) + re.Regexp(".*define multiple leaders by count 2.*", manager.SetRule(&Rule{GroupID: "g2", ID: "33", Role: "leader", Count: 2}).Error()) + re.Regexp(".*multiple leader replicas.*", manager.Batch([]RuleOp{ { Rule: &Rule{GroupID: "g2", ID: "foo1", Role: "leader", Count: 1}, Action: RuleOpAdd, @@ -99,49 +105,55 @@ func (s *testManagerSuite) TestLeaderCheck(c *C) { Rule: &Rule{GroupID: "g2", ID: "foo2", Role: "leader", Count: 1}, Action: RuleOpAdd, }, - }), ErrorMatches, ".*multiple leader replicas.*") + }).Error()) } -func (s *testManagerSuite) TestSaveLoad(c *C) { +func TestSaveLoad(t *testing.T) { + re := require.New(t) + store, manager := newTestManager(t) rules := []*Rule{ {GroupID: "pd", ID: "default", Role: "voter", Count: 5}, {GroupID: "foo", ID: "baz", StartKeyHex: "", EndKeyHex: "abcd", Role: "voter", Count: 1}, {GroupID: "foo", ID: "bar", Role: "learner", Count: 1}, } for _, r := range rules { - c.Assert(s.manager.SetRule(r.Clone()), IsNil) + re.NoError(manager.SetRule(r.Clone())) } - m2 := NewRuleManager(s.store, nil, nil) + m2 := NewRuleManager(store, nil, nil) err := m2.Initialize(3, []string{"no", "labels"}) - c.Assert(err, IsNil) - c.Assert(m2.GetAllRules(), HasLen, 3) - c.Assert(m2.GetRule("pd", "default").String(), Equals, rules[0].String()) - c.Assert(m2.GetRule("foo", "baz").String(), Equals, rules[1].String()) - c.Assert(m2.GetRule("foo", "bar").String(), Equals, rules[2].String()) + re.NoError(err) + re.Len(m2.GetAllRules(), 3) + re.Equal(rules[0].String(), m2.GetRule("pd", "default").String()) + re.Equal(rules[1].String(), m2.GetRule("foo", "baz").String()) + re.Equal(rules[2].String(), m2.GetRule("foo", "bar").String()) } -// https://github.com/tikv/pd/issues/3886 -func (s *testManagerSuite) TestSetAfterGet(c *C) { - rule := s.manager.GetRule("pd", "default") +func TestSetAfterGet(t *testing.T) { + re := require.New(t) + store, manager := newTestManager(t) + rule := manager.GetRule("pd", "default") rule.Count = 1 - s.manager.SetRule(rule) + manager.SetRule(rule) - m2 := NewRuleManager(s.store, nil, nil) + m2 := NewRuleManager(store, nil, nil) err := m2.Initialize(100, []string{}) - c.Assert(err, IsNil) + re.NoError(err) rule = m2.GetRule("pd", "default") - c.Assert(rule.Count, Equals, 1) + re.Equal(1, rule.Count) } -func (s *testManagerSuite) checkRules(c *C, rules []*Rule, expect [][2]string) { - c.Assert(rules, HasLen, 
len(expect)) +func checkRules(t *testing.T, rules []*Rule, expect [][2]string) { + re := require.New(t) + re.Len(rules, len(expect)) for i := range rules { - c.Assert(rules[i].Key(), DeepEquals, expect[i]) + re.Equal(expect[i], rules[i].Key()) } } -func (s *testManagerSuite) TestKeys(c *C) { +func TestKeys(t *testing.T) { + re := require.New(t) + _, manager := newTestManager(t) rules := []*Rule{ {GroupID: "1", ID: "1", Role: "voter", Count: 1, StartKeyHex: "", EndKeyHex: ""}, {GroupID: "2", ID: "2", Role: "voter", Count: 1, StartKeyHex: "11", EndKeyHex: "ff"}, @@ -150,24 +162,24 @@ func (s *testManagerSuite) TestKeys(c *C) { toDelete := []RuleOp{} for _, r := range rules { - s.manager.SetRule(r) + manager.SetRule(r) toDelete = append(toDelete, RuleOp{ Rule: r, Action: RuleOpDel, DeleteByIDPrefix: false, }) } - s.checkRules(c, s.manager.GetAllRules(), [][2]string{{"1", "1"}, {"2", "2"}, {"2", "3"}, {"pd", "default"}}) - s.manager.Batch(toDelete) - s.checkRules(c, s.manager.GetAllRules(), [][2]string{{"pd", "default"}}) + checkRules(t, manager.GetAllRules(), [][2]string{{"1", "1"}, {"2", "2"}, {"2", "3"}, {"pd", "default"}}) + manager.Batch(toDelete) + checkRules(t, manager.GetAllRules(), [][2]string{{"pd", "default"}}) rules = append(rules, &Rule{GroupID: "3", ID: "4", Role: "voter", Count: 1, StartKeyHex: "44", EndKeyHex: "ee"}, &Rule{GroupID: "3", ID: "5", Role: "voter", Count: 1, StartKeyHex: "44", EndKeyHex: "dd"}) - s.manager.SetRules(rules) - s.checkRules(c, s.manager.GetAllRules(), [][2]string{{"1", "1"}, {"2", "2"}, {"2", "3"}, {"3", "4"}, {"3", "5"}, {"pd", "default"}}) + manager.SetRules(rules) + checkRules(t, manager.GetAllRules(), [][2]string{{"1", "1"}, {"2", "2"}, {"2", "3"}, {"3", "4"}, {"3", "5"}, {"pd", "default"}}) - s.manager.DeleteRule("pd", "default") - s.checkRules(c, s.manager.GetAllRules(), [][2]string{{"1", "1"}, {"2", "2"}, {"2", "3"}, {"3", "4"}, {"3", "5"}}) + manager.DeleteRule("pd", "default") + checkRules(t, manager.GetAllRules(), [][2]string{{"1", "1"}, {"2", "2"}, {"2", "3"}, {"3", "4"}, {"3", "5"}}) splitKeys := [][]string{ {"", "", "11", "22", "44", "dd", "ee", "ff"}, @@ -176,10 +188,10 @@ func (s *testManagerSuite) TestKeys(c *C) { {"22", "ef", "44", "dd", "ee"}, } for _, keys := range splitKeys { - splits := s.manager.GetSplitKeys(s.dhex(keys[0]), s.dhex(keys[1])) - c.Assert(splits, HasLen, len(keys)-2) + splits := manager.GetSplitKeys(dhex(keys[0]), dhex(keys[1])) + re.Len(splits, len(keys)-2) for i := range splits { - c.Assert(splits[i], DeepEquals, s.dhex(keys[i+2])) + re.Equal(dhex(keys[i+2]), splits[i]) } } @@ -190,12 +202,12 @@ func (s *testManagerSuite) TestKeys(c *C) { {{"11", "33"}}, } for _, keys := range regionKeys { - region := core.NewRegionInfo(&metapb.Region{StartKey: s.dhex(keys[0][0]), EndKey: s.dhex(keys[0][1])}, nil) - rules := s.manager.GetRulesForApplyRegion(region) - c.Assert(rules, HasLen, len(keys)-1) + region := core.NewRegionInfo(&metapb.Region{StartKey: dhex(keys[0][0]), EndKey: dhex(keys[0][1])}, nil) + rules := manager.GetRulesForApplyRegion(region) + re.Len(rules, len(keys)-1) for i := range rules { - c.Assert(rules[i].StartKeyHex, Equals, keys[i+1][0]) - c.Assert(rules[i].EndKeyHex, Equals, keys[i+1][1]) + re.Equal(keys[i+1][0], rules[i].StartKeyHex) + re.Equal(keys[i+1][1], rules[i].EndKeyHex) } } @@ -205,11 +217,11 @@ func (s *testManagerSuite) TestKeys(c *C) { {"33", "", "", "11", "ff", "22", "dd"}, } for _, keys := range ruleByKeys { - rules := s.manager.GetRulesByKey(s.dhex(keys[0])) - c.Assert(rules, HasLen, 
(len(keys)-1)/2) + rules := manager.GetRulesByKey(dhex(keys[0])) + re.Len(rules, (len(keys)-1)/2) for i := range rules { - c.Assert(rules[i].StartKeyHex, Equals, keys[i*2+1]) - c.Assert(rules[i].EndKeyHex, Equals, keys[i*2+2]) + re.Equal(keys[i*2+1], rules[i].StartKeyHex) + re.Equal(keys[i*2+2], rules[i].EndKeyHex) } } @@ -220,126 +232,130 @@ func (s *testManagerSuite) TestKeys(c *C) { {"4"}, } for _, keys := range rulesByGroup { - rules := s.manager.GetRulesByGroup(keys[0]) - c.Assert(rules, HasLen, (len(keys)-1)/2) + rules := manager.GetRulesByGroup(keys[0]) + re.Len(rules, (len(keys)-1)/2) for i := range rules { - c.Assert(rules[i].StartKeyHex, Equals, keys[i*2+1]) - c.Assert(rules[i].EndKeyHex, Equals, keys[i*2+2]) + re.Equal(keys[i*2+1], rules[i].StartKeyHex) + re.Equal(keys[i*2+2], rules[i].EndKeyHex) } } } -func (s *testManagerSuite) TestDeleteByIDPrefix(c *C) { - s.manager.SetRules([]*Rule{ +func TestDeleteByIDPrefix(t *testing.T) { + _, manager := newTestManager(t) + manager.SetRules([]*Rule{ {GroupID: "g1", ID: "foo1", Role: "voter", Count: 1}, {GroupID: "g2", ID: "foo1", Role: "voter", Count: 1}, {GroupID: "g2", ID: "foobar", Role: "voter", Count: 1}, {GroupID: "g2", ID: "baz2", Role: "voter", Count: 1}, }) - s.manager.DeleteRule("pd", "default") - s.checkRules(c, s.manager.GetAllRules(), [][2]string{{"g1", "foo1"}, {"g2", "baz2"}, {"g2", "foo1"}, {"g2", "foobar"}}) + manager.DeleteRule("pd", "default") + checkRules(t, manager.GetAllRules(), [][2]string{{"g1", "foo1"}, {"g2", "baz2"}, {"g2", "foo1"}, {"g2", "foobar"}}) - s.manager.Batch([]RuleOp{{ + manager.Batch([]RuleOp{{ Rule: &Rule{GroupID: "g2", ID: "foo"}, Action: RuleOpDel, DeleteByIDPrefix: true, }}) - s.checkRules(c, s.manager.GetAllRules(), [][2]string{{"g1", "foo1"}, {"g2", "baz2"}}) + checkRules(t, manager.GetAllRules(), [][2]string{{"g1", "foo1"}, {"g2", "baz2"}}) } -func (s *testManagerSuite) TestRangeGap(c *C) { - // |-- default --| - // cannot delete the last rule - err := s.manager.DeleteRule("pd", "default") - c.Assert(err, NotNil) +func TestRangeGap(t *testing.T) { + re := require.New(t) + _, manager := newTestManager(t) + err := manager.DeleteRule("pd", "default") + re.Error(err) - err = s.manager.SetRule(&Rule{GroupID: "pd", ID: "foo", StartKeyHex: "", EndKeyHex: "abcd", Role: "voter", Count: 1}) - c.Assert(err, IsNil) + err = manager.SetRule(&Rule{GroupID: "pd", ID: "foo", StartKeyHex: "", EndKeyHex: "abcd", Role: "voter", Count: 1}) + re.NoError(err) // |-- default --| // |-- foo --| // still cannot delete default since it will cause ("abcd", "") has no rules inside. - err = s.manager.DeleteRule("pd", "default") - c.Assert(err, NotNil) - err = s.manager.SetRule(&Rule{GroupID: "pd", ID: "bar", StartKeyHex: "abcd", EndKeyHex: "", Role: "voter", Count: 1}) - c.Assert(err, IsNil) + err = manager.DeleteRule("pd", "default") + re.Error(err) + err = manager.SetRule(&Rule{GroupID: "pd", ID: "bar", StartKeyHex: "abcd", EndKeyHex: "", Role: "voter", Count: 1}) + re.NoError(err) // now default can be deleted. - err = s.manager.DeleteRule("pd", "default") - c.Assert(err, IsNil) + err = manager.DeleteRule("pd", "default") + re.NoError(err) // cannot change range since it will cause ("abaa", "abcd") has no rules inside. 
- err = s.manager.SetRule(&Rule{GroupID: "pd", ID: "foo", StartKeyHex: "", EndKeyHex: "abaa", Role: "voter", Count: 1}) - c.Assert(err, NotNil) + err = manager.SetRule(&Rule{GroupID: "pd", ID: "foo", StartKeyHex: "", EndKeyHex: "abaa", Role: "voter", Count: 1}) + re.Error(err) } -func (s *testManagerSuite) TestGroupConfig(c *C) { - // group pd +func TestGroupConfig(t *testing.T) { + re := require.New(t) + _, manager := newTestManager(t) pd1 := &RuleGroup{ID: "pd"} - c.Assert(s.manager.GetRuleGroup("pd"), DeepEquals, pd1) + re.Equal(pd1, manager.GetRuleGroup("pd")) // update group pd pd2 := &RuleGroup{ID: "pd", Index: 100, Override: true} - err := s.manager.SetRuleGroup(pd2) - c.Assert(err, IsNil) - c.Assert(s.manager.GetRuleGroup("pd"), DeepEquals, pd2) + err := manager.SetRuleGroup(pd2) + re.NoError(err) + re.Equal(pd2, manager.GetRuleGroup("pd")) // new group g without config - err = s.manager.SetRule(&Rule{GroupID: "g", ID: "1", Role: "voter", Count: 1}) - c.Assert(err, IsNil) + err = manager.SetRule(&Rule{GroupID: "g", ID: "1", Role: "voter", Count: 1}) + re.NoError(err) g1 := &RuleGroup{ID: "g"} - c.Assert(s.manager.GetRuleGroup("g"), DeepEquals, g1) - c.Assert(s.manager.GetRuleGroups(), DeepEquals, []*RuleGroup{g1, pd2}) + re.Equal(g1, manager.GetRuleGroup("g")) + re.Equal([]*RuleGroup{g1, pd2}, manager.GetRuleGroups()) // update group g g2 := &RuleGroup{ID: "g", Index: 2, Override: true} - err = s.manager.SetRuleGroup(g2) - c.Assert(err, IsNil) - c.Assert(s.manager.GetRuleGroups(), DeepEquals, []*RuleGroup{g2, pd2}) + err = manager.SetRuleGroup(g2) + re.NoError(err) + re.Equal([]*RuleGroup{g2, pd2}, manager.GetRuleGroups()) // delete pd group, restore to default config - err = s.manager.DeleteRuleGroup("pd") - c.Assert(err, IsNil) - c.Assert(s.manager.GetRuleGroups(), DeepEquals, []*RuleGroup{pd1, g2}) + err = manager.DeleteRuleGroup("pd") + re.NoError(err) + re.Equal([]*RuleGroup{pd1, g2}, manager.GetRuleGroups()) // delete rule, the group is removed too - err = s.manager.DeleteRule("pd", "default") - c.Assert(err, IsNil) - c.Assert(s.manager.GetRuleGroups(), DeepEquals, []*RuleGroup{g2}) + err = manager.DeleteRule("pd", "default") + re.NoError(err) + re.Equal([]*RuleGroup{g2}, manager.GetRuleGroups()) } -func (s *testManagerSuite) TestRuleVersion(c *C) { - // default rule - rule1 := s.manager.GetRule("pd", "default") - c.Assert(rule1.Version, Equals, uint64(0)) +func TestRuleVersion(t *testing.T) { + re := require.New(t) + _, manager := newTestManager(t) + rule1 := manager.GetRule("pd", "default") + re.Equal(uint64(0), rule1.Version) // create new rule newRule := &Rule{GroupID: "g1", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 3} - err := s.manager.SetRule(newRule) - c.Assert(err, IsNil) - newRule = s.manager.GetRule("g1", "id") - c.Assert(newRule.Version, Equals, uint64(0)) + err := manager.SetRule(newRule) + re.NoError(err) + newRule = manager.GetRule("g1", "id") + re.Equal(uint64(0), newRule.Version) // update rule newRule = &Rule{GroupID: "g1", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: "voter", Count: 2} - err = s.manager.SetRule(newRule) - c.Assert(err, IsNil) - newRule = s.manager.GetRule("g1", "id") - c.Assert(newRule.Version, Equals, uint64(1)) + err = manager.SetRule(newRule) + re.NoError(err) + newRule = manager.GetRule("g1", "id") + re.Equal(uint64(1), newRule.Version) // delete rule - err = s.manager.DeleteRule("g1", "id") - c.Assert(err, IsNil) + err = manager.DeleteRule("g1", "id") + re.NoError(err) // recreate new 
rule - err = s.manager.SetRule(newRule) - c.Assert(err, IsNil) + err = manager.SetRule(newRule) + re.NoError(err) // assert version should be 0 again - newRule = s.manager.GetRule("g1", "id") - c.Assert(newRule.Version, Equals, uint64(0)) + newRule = manager.GetRule("g1", "id") + re.Equal(uint64(0), newRule.Version) } -func (s *testManagerSuite) TestCheckApplyRules(c *C) { +func TestCheckApplyRules(t *testing.T) { + re := require.New(t) err := checkApplyRules([]*Rule{ { Role: Leader, Count: 1, }, }) - c.Assert(err, IsNil) + re.NoError(err) err = checkApplyRules([]*Rule{ { @@ -347,7 +363,7 @@ func (s *testManagerSuite) TestCheckApplyRules(c *C) { Count: 1, }, }) - c.Assert(err, IsNil) + re.NoError(err) err = checkApplyRules([]*Rule{ { @@ -359,7 +375,7 @@ func (s *testManagerSuite) TestCheckApplyRules(c *C) { Count: 1, }, }) - c.Assert(err, IsNil) + re.NoError(err) err = checkApplyRules([]*Rule{ { @@ -367,7 +383,7 @@ func (s *testManagerSuite) TestCheckApplyRules(c *C) { Count: 3, }, }) - c.Assert(err, ErrorMatches, "multiple leader replicas") + re.Regexp("multiple leader replicas", err.Error()) err = checkApplyRules([]*Rule{ { @@ -379,7 +395,7 @@ func (s *testManagerSuite) TestCheckApplyRules(c *C) { Count: 1, }, }) - c.Assert(err, ErrorMatches, "multiple leader replicas") + re.Regexp("multiple leader replicas", err.Error()) err = checkApplyRules([]*Rule{ { @@ -391,10 +407,10 @@ func (s *testManagerSuite) TestCheckApplyRules(c *C) { Count: 1, }, }) - c.Assert(err, ErrorMatches, "needs at least one leader or voter") + re.Regexp("needs at least one leader or voter", err.Error()) } -func (s *testManagerSuite) dhex(hk string) []byte { +func dhex(hk string) []byte { k, err := hex.DecodeString(hk) if err != nil { panic("decode fail") diff --git a/server/schedule/placement/rule_test.go b/server/schedule/placement/rule_test.go index 2e0ea8566c5..94f623ef93d 100644 --- a/server/schedule/placement/rule_test.go +++ b/server/schedule/placement/rule_test.go @@ -17,15 +17,13 @@ package placement import ( "encoding/hex" "math/rand" + "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testRuleSuite{}) - -type testRuleSuite struct{} - -func (s *testRuleSuite) TestPrepareRulesForApply(c *C) { +func TestPrepareRulesForApply(t *testing.T) { + re := require.New(t) rules := []*Rule{ {GroupID: "g1", Index: 0, ID: "id5"}, {GroupID: "g1", Index: 0, ID: "id6"}, @@ -56,13 +54,13 @@ func (s *testRuleSuite) TestPrepareRulesForApply(c *C) { sortRules(rules) rules = prepareRulesForApply(rules) - c.Assert(len(rules), Equals, len(expected)) + re.Equal(len(expected), len(rules)) for i := range rules { - c.Assert(rules[i].Key(), Equals, expected[i]) + re.Equal(expected[i], rules[i].Key()) } } - -func (s *testRuleSuite) TestGroupProperties(c *C) { +func TestGroupProperties(t *testing.T) { + re := require.New(t) testCases := []struct { rules []*Rule expect [][2]string @@ -103,15 +101,14 @@ func (s *testRuleSuite) TestGroupProperties(c *C) { rand.Shuffle(len(tc.rules), func(i, j int) { tc.rules[i], tc.rules[j] = tc.rules[j], tc.rules[i] }) sortRules(tc.rules) rules := prepareRulesForApply(tc.rules) - c.Assert(rules, HasLen, len(tc.expect)) + re.Len(rules, len(tc.expect)) for i := range rules { - c.Assert(rules[i].Key(), Equals, tc.expect[i]) + re.Equal(tc.expect[i], rules[i].Key()) } } } - -// TODO: fulfill unit test case to cover BuildRuleList -func (s *testRuleSuite) TestBuildRuleList(c *C) { +func TestBuildRuleList(t *testing.T) { + re := require.New(t) defaultRule := &Rule{ GroupID: "pd", ID: "default", @@ -121,9 +118,9 @@ func (s *testRuleSuite) TestBuildRuleList(c *C) { Count: 3, } byteStart, err := hex.DecodeString("a1") - c.Check(err, IsNil) + re.NoError(err) byteEnd, err := hex.DecodeString("a2") - c.Check(err, IsNil) + re.NoError(err) ruleMeta := &Rule{ GroupID: "pd", ID: "meta", @@ -182,10 +179,10 @@ func (s *testRuleSuite) TestBuildRuleList(c *C) { } for _, testcase := range testcases { - c.Log(testcase.name) + t.Log(testcase.name) config := &ruleConfig{rules: testcase.rules} result, err := buildRuleList(config) - c.Assert(err, IsNil) - c.Assert(result.ranges, DeepEquals, testcase.expect.ranges) + re.NoError(err) + re.Equal(testcase.expect.ranges, result.ranges) } } From e19dc71ac50522bd91ac72052f7c8b39bbb06b3f Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 7 Jun 2022 10:18:29 +0800 Subject: [PATCH 20/82] *: fix the wrong pending status (#5080) close tikv/pd#5095 Signed-off-by: Ryan Leung --- server/api/region.go | 9 +++---- server/cluster/coordinator.go | 4 +++ server/schedule/checker/checker_controller.go | 6 +++++ server/schedule/checker/rule_checker.go | 26 ++++++++++++++----- server/schedule/checker/rule_checker_test.go | 21 +++++++++++++++ 5 files changed, 53 insertions(+), 13 deletions(-) diff --git a/server/api/region.go b/server/api/region.go index 2fa7866d22c..fa25ca1bd17 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -280,13 +280,10 @@ func (h *regionsHandler) CheckRegionsReplicated(w http.ResponseWriter, r *http.R for _, region := range regions { if !schedule.IsRegionReplicated(rc, region) { state = "INPROGRESS" - for _, item := range rc.GetCoordinator().GetWaitingRegions() { - if item.Key == region.GetID() { - state = "PENDING" - break - } + if rc.GetCoordinator().IsPendingRegion(region.GetID()) { + state = "PENDING" + break } - break } } failpoint.Inject("mockPending", func(val failpoint.Value) { diff --git a/server/cluster/coordinator.go b/server/cluster/coordinator.go index 530e858877f..b3f72a3be5f 100644 --- a/server/cluster/coordinator.go +++ 
b/server/cluster/coordinator.go @@ -97,6 +97,10 @@ func (c *coordinator) GetWaitingRegions() []*cache.Item { return c.checkers.GetWaitingRegions() } +func (c *coordinator) IsPendingRegion(region uint64) bool { + return c.checkers.IsPendingRegion(region) +} + // patrolRegions is used to scan regions. // The checkers will check these regions to decide if they need to do some operations. func (c *coordinator) patrolRegions() { diff --git a/server/schedule/checker/checker_controller.go b/server/schedule/checker/checker_controller.go index 11d1638096a..4e7e28334d0 100644 --- a/server/schedule/checker/checker_controller.go +++ b/server/schedule/checker/checker_controller.go @@ -204,6 +204,12 @@ func (c *Controller) ClearSuspectKeyRanges() { c.suspectKeyRanges.Clear() } +// IsPendingRegion returns true if the given region is in the pending list. +func (c *Controller) IsPendingRegion(regionID uint64) bool { + _, exist := c.ruleChecker.pendingList.Get(regionID) + return exist +} + // GetPauseController returns pause controller of the checker func (c *Controller) GetPauseController(name string) (*PauseController, error) { switch name { diff --git a/server/schedule/checker/rule_checker.go b/server/schedule/checker/rule_checker.go index cffc4c9b39e..8db66a9370b 100644 --- a/server/schedule/checker/rule_checker.go +++ b/server/schedule/checker/rule_checker.go @@ -39,6 +39,8 @@ var ( errNoNewLeader = errors.New("no new leader") ) +const maxPendingListLen = 100000 + // RuleChecker fix/improve region by placement rules. type RuleChecker struct { PauseController @@ -46,6 +48,7 @@ type RuleChecker struct { ruleManager *placement.RuleManager name string regionWaitingList cache.Cache + pendingList cache.Cache record *recorder } @@ -56,6 +59,7 @@ func NewRuleChecker(cluster schedule.Cluster, ruleManager *placement.RuleManager ruleManager: ruleManager, name: "rule-checker", regionWaitingList: regionWaitingList, + pendingList: cache.NewDefaultCache(maxPendingListLen), record: newRecord(), } } @@ -107,6 +111,7 @@ func (c *RuleChecker) CheckWithFit(region *core.RegionInfo, fit *placement.Regio if err != nil { log.Debug("fail to fix orphan peer", errs.ZapError(err)) } else if op != nil { + c.pendingList.Remove(region.GetID()) return op } for _, rf := range fit.RuleFits { @@ -116,6 +121,7 @@ func (c *RuleChecker) CheckWithFit(region *core.RegionInfo, fit *placement.Regio continue } if op != nil { + c.pendingList.Remove(region.GetID()) return op } } @@ -164,9 +170,7 @@ func (c *RuleChecker) addRulePeer(region *core.RegionInfo, rf *placement.RuleFit store, filterByTempState := c.strategy(region, rf.Rule).SelectStoreToAdd(ruleStores) if store == 0 { checkerCounter.WithLabelValues("rule_checker", "no-store-add").Inc() - if filterByTempState { - c.regionWaitingList.Put(region.GetID(), nil) - } + c.handleFilterState(region, filterByTempState) return nil, errNoStoreToAdd } peer := &metapb.Peer{StoreId: store, Role: rf.Rule.Role.MetaPeerRole()} @@ -184,9 +188,7 @@ func (c *RuleChecker) replaceUnexpectRulePeer(region *core.RegionInfo, rf *place store, filterByTempState := c.strategy(region, rf.Rule).SelectStoreToFix(ruleStores, peer.GetStoreId()) if store == 0 { checkerCounter.WithLabelValues("rule_checker", "no-store-replace").Inc() - if filterByTempState { - c.regionWaitingList.Put(region.GetID(), nil) - } + c.handleFilterState(region, filterByTempState) return nil, errNoStoreToReplace } newPeer := &metapb.Peer{StoreId: store, Role: rf.Rule.Role.MetaPeerRole()} @@ -291,9 +293,10 @@ func (c *RuleChecker) 
fixBetterLocation(region *core.RegionInfo, rf *placement.R if oldStore == 0 { return nil, nil } - newStore, _ := strategy.SelectStoreToImprove(ruleStores, oldStore) + newStore, filterByTempState := strategy.SelectStoreToImprove(ruleStores, oldStore) if newStore == 0 { log.Debug("no replacement store", zap.Uint64("region-id", region.GetID())) + c.handleFilterState(region, filterByTempState) return nil, nil } checkerCounter.WithLabelValues("rule_checker", "move-to-better-location").Inc() @@ -382,6 +385,15 @@ func (c *RuleChecker) getRuleFitStores(rf *placement.RuleFit) []*core.StoreInfo return stores } +func (c *RuleChecker) handleFilterState(region *core.RegionInfo, filterByTempState bool) { + if filterByTempState { + c.regionWaitingList.Put(region.GetID(), nil) + c.pendingList.Remove(region.GetID()) + } else { + c.pendingList.Put(region.GetID(), nil) + } +} + type recorder struct { offlineLeaderCounter map[uint64]uint64 lastUpdateTime time.Time diff --git a/server/schedule/checker/rule_checker_test.go b/server/schedule/checker/rule_checker_test.go index e4f95411b5e..f3a908939bf 100644 --- a/server/schedule/checker/rule_checker_test.go +++ b/server/schedule/checker/rule_checker_test.go @@ -820,3 +820,24 @@ func (s *testRuleCheckerSuite) TestOfflineAndDownStore(c *C) { c.Assert(op, NotNil) c.Assert(op.Desc(), Equals, "replace-rule-down-peer") } + +func (s *testRuleCheckerSuite) TestPendingList(c *C) { + // no enough store + s.cluster.AddLeaderStore(1, 1) + s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) + op := s.rc.Check(s.cluster.GetRegion(1)) + c.Assert(op, IsNil) + _, exist := s.rc.pendingList.Get(1) + c.Assert(exist, IsTrue) + + // add more stores + s.cluster.AddLeaderStore(2, 1) + s.cluster.AddLeaderStore(3, 1) + op = s.rc.Check(s.cluster.GetRegion(1)) + c.Assert(op, NotNil) + c.Assert(op.Desc(), Equals, "add-rule-peer") + c.Assert(op.GetPriorityLevel(), Equals, core.HighPriority) + c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(3)) + _, exist = s.rc.pendingList.Get(1) + c.Assert(exist, IsFalse) +} From 1f3c30586d894c8492d56a4315dc559f7eee68ee Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 7 Jun 2022 10:56:29 +0800 Subject: [PATCH 21/82] scripts: add the inefficient assert function usage check (#5110) close tikv/pd#5104 Add the inefficient assert function usage check. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- Makefile | 21 ++++++++++--------- scripts/{check-testing-t.sh => check-test.sh} | 18 ++++++++++++++++ .../{check-errdoc.sh => generate-errdoc.sh} | 0 3 files changed, 29 insertions(+), 10 deletions(-) rename scripts/{check-testing-t.sh => check-test.sh} (55%) rename scripts/{check-errdoc.sh => generate-errdoc.sh} (100%) diff --git a/Makefile b/Makefile index 28a9a62c8d8..2afd99c0734 100644 --- a/Makefile +++ b/Makefile @@ -144,7 +144,7 @@ install-tools: #### Static checks #### -check: install-tools static tidy check-plugin errdoc check-testing-t +check: install-tools static tidy generate-errdoc check-plugin check-test static: install-tools @ echo "gofmt ..." @@ -160,21 +160,22 @@ tidy: @ go mod tidy git diff go.mod go.sum | cat git diff --quiet go.mod go.sum - + @ for mod in $(SUBMODULES); do cd $$mod && $(MAKE) tidy && cd - > /dev/null; done +generate-errdoc: install-tools + @echo "generating errors.toml..." + ./scripts/generate-errdoc.sh + check-plugin: - @echo "checking plugin" + @echo "checking plugin..." 
cd ./plugin/scheduler_example && $(MAKE) evictLeaderPlugin.so && rm evictLeaderPlugin.so -errdoc: install-tools - @echo "generator errors.toml" - ./scripts/check-errdoc.sh - -check-testing-t: - ./scripts/check-testing-t.sh +check-test: + @echo "checking test..." + ./scripts/check-test.sh -.PHONY: check static tidy check-plugin errdoc docker-build-test check-testing-t +.PHONY: check static tidy generate-errdoc check-plugin check-test #### Test utils #### diff --git a/scripts/check-testing-t.sh b/scripts/check-test.sh similarity index 55% rename from scripts/check-testing-t.sh rename to scripts/check-test.sh index 6d107b5a0d1..c8c5b72c0fe 100755 --- a/scripts/check-testing-t.sh +++ b/scripts/check-test.sh @@ -23,4 +23,22 @@ if [ "$res" ]; then exit 1 fi +# Check if there is any inefficient assert function usage in package. + +res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(True|False)\((t, )?reflect\.DeepEqual\(" . | sort -u) \ + +if [ "$res" ]; then + echo "following packages use the inefficient assert function: please replace reflect.DeepEqual with require.Equal" + echo "$res" + exit 1 +fi + +res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(True|False)\((t, )?strings\.Contains\(" . | sort -u) + +if [ "$res" ]; then + echo "following packages use the inefficient assert function: please replace strings.Contains with require.Contains" + echo "$res" + exit 1 +fi + exit 0 diff --git a/scripts/check-errdoc.sh b/scripts/generate-errdoc.sh similarity index 100% rename from scripts/check-errdoc.sh rename to scripts/generate-errdoc.sh From 726d345c8ebf5fb5f280d31bec428e154744bd4b Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 7 Jun 2022 12:12:30 +0800 Subject: [PATCH 22/82] scheduler: enlarge the search space so that the hot-scheduler can still schedule under dimensional conflicts (#4912) ref tikv/pd#4949 Signed-off-by: lhy1024 Co-authored-by: ShuNing Co-authored-by: Ti Chi Robot --- server/schedulers/hot_region.go | 66 ++++++++++------ server/schedulers/hot_region_test.go | 88 ++++++++++++++++++++-- server/statistics/store_hot_peers_infos.go | 27 ++++++- server/statistics/store_load.go | 6 ++ 4 files changed, 160 insertions(+), 27 deletions(-) diff --git a/server/schedulers/hot_region.go b/server/schedulers/hot_region.go index 2c4803b49de..3ca40a7133f 100644 --- a/server/schedulers/hot_region.go +++ b/server/schedulers/hot_region.go @@ -79,6 +79,9 @@ var ( schedulePeerPr = 0.66 // pendingAmpFactor will amplify the impact of pending influence, making scheduling slower or even serial when two stores are close together pendingAmpFactor = 2.0 + // If the distribution of a dimension is below the corresponding stddev threshold, then scheduling will no longer be based on this dimension, + // as it implies that this dimension is sufficiently uniform. 
+ stddevThreshold = 0.1 ) type hotScheduler struct { @@ -384,6 +387,8 @@ type balanceSolver struct { minorDecRatio float64 maxPeerNum int minHotDegree int + + pick func(s interface{}, p func(int) bool) bool } func (bs *balanceSolver) init() { @@ -423,6 +428,11 @@ func (bs *balanceSolver) init() { bs.greatDecRatio, bs.minorDecRatio = bs.sche.conf.GetGreatDecRatio(), bs.sche.conf.GetMinorDecRatio() bs.maxPeerNum = bs.sche.conf.GetMaxPeerNumber() bs.minHotDegree = bs.GetOpts().GetHotRegionCacheHitsThreshold() + + bs.pick = slice.AnyOf + if bs.sche.conf.IsStrictPickingStoreEnabled() { + bs.pick = slice.AllOf + } } func (bs *balanceSolver) isSelectedDim(dim int) bool { @@ -472,7 +482,14 @@ func (bs *balanceSolver) solve() []*operator.Operator { return nil } bs.cur = &solution{} - tryUpdateBestSolution := func() { + + tryUpdateBestSolution := func(isUniformFirstPriority bool) { + if bs.cur.progressiveRank == -1 && isUniformFirstPriority { + // Because region is available for src and dst, so stddev is the same for both, only need to calcurate one. + // If first priority dim is enough uniform, -1 is unnecessary and maybe lead to worse balance for second priority dim + hotSchedulerResultCounter.WithLabelValues("skip-uniform-store", strconv.FormatUint(bs.cur.dstStore.GetID(), 10)).Inc() + return + } if bs.cur.progressiveRank < 0 && bs.betterThan(bs.best) { if newOps, newInfl := bs.buildOperators(); len(newOps) > 0 { bs.ops = newOps @@ -486,7 +503,11 @@ func (bs *balanceSolver) solve() []*operator.Operator { for _, srcStore := range bs.filterSrcStores() { bs.cur.srcStore = srcStore srcStoreID := srcStore.GetID() - + isUniformFirstPriority, isUniformSecondPriority := bs.isUniformFirstPriority(bs.cur.srcStore), bs.isUniformSecondPriority(bs.cur.srcStore) + if isUniformFirstPriority && isUniformSecondPriority { + hotSchedulerResultCounter.WithLabelValues("skip-uniform-store", strconv.FormatUint(bs.cur.srcStore.GetID(), 10)).Inc() + continue + } for _, srcPeerStat := range bs.filterHotPeers(srcStore) { if bs.cur.region = bs.getRegion(srcPeerStat, srcStoreID); bs.cur.region == nil { continue @@ -499,7 +520,7 @@ func (bs *balanceSolver) solve() []*operator.Operator { for _, dstStore := range bs.filterDstStores() { bs.cur.dstStore = dstStore bs.calcProgressiveRank() - tryUpdateBestSolution() + tryUpdateBestSolution(isUniformFirstPriority) } } } @@ -564,15 +585,12 @@ func (bs *balanceSolver) filterSrcStores() map[uint64]*statistics.StoreLoadDetai } func (bs *balanceSolver) checkSrcByDimPriorityAndTolerance(minLoad, expectLoad *statistics.StoreLoad, toleranceRatio float64) bool { - if bs.sche.conf.IsStrictPickingStoreEnabled() { - return slice.AllOf(minLoad.Loads, func(i int) bool { - if bs.isSelectedDim(i) { - return minLoad.Loads[i] > toleranceRatio*expectLoad.Loads[i] - } - return true - }) - } - return minLoad.Loads[bs.firstPriority] > toleranceRatio*expectLoad.Loads[bs.firstPriority] + return bs.pick(minLoad.Loads, func(i int) bool { + if bs.isSelectedDim(i) { + return minLoad.Loads[i] > toleranceRatio*expectLoad.Loads[i] + } + return true + }) } // filterHotPeers filtered hot peers from statistics.HotPeerStat and deleted the peer if its region is in pending status. 
@@ -770,15 +788,21 @@ func (bs *balanceSolver) pickDstStores(filters []filter.Filter, candidates []*st } func (bs *balanceSolver) checkDstByPriorityAndTolerance(maxLoad, expect *statistics.StoreLoad, toleranceRatio float64) bool { - if bs.sche.conf.IsStrictPickingStoreEnabled() { - return slice.AllOf(maxLoad.Loads, func(i int) bool { - if bs.isSelectedDim(i) { - return maxLoad.Loads[i]*toleranceRatio < expect.Loads[i] - } - return true - }) - } - return maxLoad.Loads[bs.firstPriority]*toleranceRatio < expect.Loads[bs.firstPriority] + return bs.pick(maxLoad.Loads, func(i int) bool { + if bs.isSelectedDim(i) { + return maxLoad.Loads[i]*toleranceRatio < expect.Loads[i] + } + return true + }) +} + +func (bs *balanceSolver) isUniformFirstPriority(store *statistics.StoreLoadDetail) bool { + // first priority should be more uniform than second priority + return store.IsUniform(bs.firstPriority, stddevThreshold*0.5) +} + +func (bs *balanceSolver) isUniformSecondPriority(store *statistics.StoreLoadDetail) bool { + return store.IsUniform(bs.secondPriority, stddevThreshold) } // calcProgressiveRank calculates `bs.cur.progressiveRank`. diff --git a/server/schedulers/hot_region_test.go b/server/schedulers/hot_region_test.go index 5d4d0939d61..66ed9ec3c9c 100644 --- a/server/schedulers/hot_region_test.go +++ b/server/schedulers/hot_region_test.go @@ -872,6 +872,12 @@ func (s *testHotWriteRegionSchedulerSuite) TestWithRuleEnabled(c *C) { tc.SetHotRegionCacheHitsThreshold(0) key, err := hex.DecodeString("") c.Assert(err, IsNil) + // skip stddev check + origin := stddevThreshold + stddevThreshold = -1.0 + defer func() { + stddevThreshold = origin + }() tc.AddRegionStore(1, 20) tc.AddRegionStore(2, 20) @@ -1776,6 +1782,12 @@ func (s *testHotSchedulerSuite) TestHotScheduleWithPriority(c *C) { c.Assert(err, IsNil) hb.(*hotScheduler).conf.SetDstToleranceRatio(1.05) hb.(*hotScheduler).conf.SetSrcToleranceRatio(1.05) + // skip stddev check + origin := stddevThreshold + stddevThreshold = -1.0 + defer func() { + stddevThreshold = origin + }() tc := mockcluster.NewCluster(ctx, opt) tc.SetHotRegionCacheHitsThreshold(0) @@ -1831,28 +1843,94 @@ func (s *testHotSchedulerSuite) TestHotScheduleWithPriority(c *C) { hb, err = schedule.CreateScheduler(statistics.Write.String(), schedule.NewOperatorController(ctx, nil, nil), storage.NewStorageWithMemoryBackend(), nil) c.Assert(err, IsNil) - hb.(*hotScheduler).conf.StrictPickingStore = false + // assert loose store picking tc.UpdateStorageWrittenStats(1, 10*MB*statistics.StoreHeartBeatReportInterval, 1*MB*statistics.StoreHeartBeatReportInterval) - tc.UpdateStorageWrittenStats(2, 6*MB*statistics.StoreHeartBeatReportInterval, 6*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(2, 6.1*MB*statistics.StoreHeartBeatReportInterval, 6.1*MB*statistics.StoreHeartBeatReportInterval) tc.UpdateStorageWrittenStats(3, 6*MB*statistics.StoreHeartBeatReportInterval, 6*MB*statistics.StoreHeartBeatReportInterval) tc.UpdateStorageWrittenStats(4, 6*MB*statistics.StoreHeartBeatReportInterval, 6*MB*statistics.StoreHeartBeatReportInterval) tc.UpdateStorageWrittenStats(5, 1*MB*statistics.StoreHeartBeatReportInterval, 1*MB*statistics.StoreHeartBeatReportInterval) hb.(*hotScheduler).conf.WritePeerPriorities = []string{BytePriority, KeyPriority} + hb.(*hotScheduler).conf.StrictPickingStore = true + ops = hb.Schedule(tc) + c.Assert(ops, HasLen, 0) + hb.(*hotScheduler).conf.StrictPickingStore = false ops = hb.Schedule(tc) c.Assert(ops, HasLen, 1) - testutil.CheckTransferPeer(c, 
ops[0], operator.OpHotRegion, 1, 5) + testutil.CheckTransferPeer(c, ops[0], operator.OpHotRegion, 2, 5) // two dims will be better clearPendingInfluence(hb.(*hotScheduler)) tc.UpdateStorageWrittenStats(1, 6*MB*statistics.StoreHeartBeatReportInterval, 6*MB*statistics.StoreHeartBeatReportInterval) - tc.UpdateStorageWrittenStats(2, 6*MB*statistics.StoreHeartBeatReportInterval, 6*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(2, 6.1*MB*statistics.StoreHeartBeatReportInterval, 6.1*MB*statistics.StoreHeartBeatReportInterval) tc.UpdateStorageWrittenStats(3, 6*MB*statistics.StoreHeartBeatReportInterval, 6*MB*statistics.StoreHeartBeatReportInterval) tc.UpdateStorageWrittenStats(4, 1*MB*statistics.StoreHeartBeatReportInterval, 10*MB*statistics.StoreHeartBeatReportInterval) tc.UpdateStorageWrittenStats(5, 1*MB*statistics.StoreHeartBeatReportInterval, 1*MB*statistics.StoreHeartBeatReportInterval) hb.(*hotScheduler).conf.WritePeerPriorities = []string{KeyPriority, BytePriority} + hb.(*hotScheduler).conf.StrictPickingStore = true + ops = hb.Schedule(tc) + c.Assert(ops, HasLen, 0) + hb.(*hotScheduler).conf.StrictPickingStore = false ops = hb.Schedule(tc) c.Assert(ops, HasLen, 1) - testutil.CheckTransferPeer(c, ops[0], operator.OpHotRegion, 4, 5) + testutil.CheckTransferPeer(c, ops[0], operator.OpHotRegion, 2, 5) // two dims will be better + clearPendingInfluence(hb.(*hotScheduler)) +} + +func (s *testHotSchedulerSuite) TestHotScheduleWithStddev(c *C) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + statistics.Denoising = false + opt := config.NewTestOptions() + hb, err := schedule.CreateScheduler(statistics.Write.String(), schedule.NewOperatorController(ctx, nil, nil), storage.NewStorageWithMemoryBackend(), nil) + c.Assert(err, IsNil) + hb.(*hotScheduler).conf.SetDstToleranceRatio(0.0) + hb.(*hotScheduler).conf.SetSrcToleranceRatio(0.0) + tc := mockcluster.NewCluster(ctx, opt) + tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) + tc.SetHotRegionCacheHitsThreshold(0) + tc.AddRegionStore(1, 20) + tc.AddRegionStore(2, 20) + tc.AddRegionStore(3, 20) + tc.AddRegionStore(4, 20) + tc.AddRegionStore(5, 20) + hb.(*hotScheduler).conf.StrictPickingStore = false + + // skip uniform cluster + tc.UpdateStorageWrittenStats(1, 5*MB*statistics.StoreHeartBeatReportInterval, 5*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(2, 5.3*MB*statistics.StoreHeartBeatReportInterval, 5.3*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(3, 5*MB*statistics.StoreHeartBeatReportInterval, 5*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(4, 5*MB*statistics.StoreHeartBeatReportInterval, 5*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(5, 4.8*MB*statistics.StoreHeartBeatReportInterval, 4.8*MB*statistics.StoreHeartBeatReportInterval) + addRegionInfo(tc, statistics.Write, []testRegionInfo{ + {6, []uint64{3, 4, 2}, 0.1 * MB, 0.1 * MB, 0}, + }) + hb.(*hotScheduler).conf.WritePeerPriorities = []string{BytePriority, KeyPriority} + stddevThreshold = 0.1 + ops := hb.Schedule(tc) + c.Assert(ops, HasLen, 0) + stddevThreshold = -1.0 + ops = hb.Schedule(tc) + c.Assert(ops, HasLen, 1) + testutil.CheckTransferPeer(c, ops[0], operator.OpHotRegion, 2, 5) + clearPendingInfluence(hb.(*hotScheduler)) + + // skip -1 case (uniform cluster) + tc.UpdateStorageWrittenStats(1, 5*MB*statistics.StoreHeartBeatReportInterval, 100*MB*statistics.StoreHeartBeatReportInterval) // 
two dims are not uniform. + tc.UpdateStorageWrittenStats(2, 5.3*MB*statistics.StoreHeartBeatReportInterval, 4.8*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(3, 5*MB*statistics.StoreHeartBeatReportInterval, 5*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(4, 5*MB*statistics.StoreHeartBeatReportInterval, 5*MB*statistics.StoreHeartBeatReportInterval) + tc.UpdateStorageWrittenStats(5, 4.8*MB*statistics.StoreHeartBeatReportInterval, 5*MB*statistics.StoreHeartBeatReportInterval) + addRegionInfo(tc, statistics.Write, []testRegionInfo{ + {6, []uint64{3, 4, 2}, 0.1 * MB, 0.1 * MB, 0}, + }) + hb.(*hotScheduler).conf.WritePeerPriorities = []string{BytePriority, KeyPriority} + stddevThreshold = 0.1 + ops = hb.Schedule(tc) + c.Assert(ops, HasLen, 0) + stddevThreshold = -1.0 + ops = hb.Schedule(tc) + c.Assert(ops, HasLen, 1) + testutil.CheckTransferPeer(c, ops[0], operator.OpHotRegion, 2, 5) clearPendingInfluence(hb.(*hotScheduler)) } diff --git a/server/statistics/store_hot_peers_infos.go b/server/statistics/store_hot_peers_infos.go index e0cc0afaa98..6b58023d01e 100644 --- a/server/statistics/store_hot_peers_infos.go +++ b/server/statistics/store_hot_peers_infos.go @@ -16,6 +16,7 @@ package statistics import ( "fmt" + "math" "github.com/tikv/pd/server/core" ) @@ -186,6 +187,19 @@ func summaryStoresLoadByEngine( for i := range expectLoads { expectLoads[i] = allStoreLoadSum[i] / float64(allStoreCount) } + + stddevLoads := make([]float64, len(allStoreLoadSum)) + if allHotPeersCount != 0 { + for _, detail := range loadDetail { + for i := range expectLoads { + stddevLoads[i] += math.Pow(detail.LoadPred.Current.Loads[i]-expectLoads[i], 2) + } + } + for i := range stddevLoads { + stddevLoads[i] = math.Sqrt(stddevLoads[i]/float64(allStoreCount)) / expectLoads[i] + } + } + { // Metric for debug. engine := collector.Engine() @@ -197,13 +211,24 @@ func summaryStoresLoadByEngine( hotPeerSummary.WithLabelValues(ty, engine).Set(expectLoads[QueryDim]) ty = "exp-count-rate-" + rwTy.String() + "-" + kind.String() hotPeerSummary.WithLabelValues(ty, engine).Set(expectCount) + ty = "stddev-byte-rate-" + rwTy.String() + "-" + kind.String() + hotPeerSummary.WithLabelValues(ty, engine).Set(stddevLoads[ByteDim]) + ty = "stddev-key-rate-" + rwTy.String() + "-" + kind.String() + hotPeerSummary.WithLabelValues(ty, engine).Set(stddevLoads[KeyDim]) + ty = "stddev-query-rate-" + rwTy.String() + "-" + kind.String() + hotPeerSummary.WithLabelValues(ty, engine).Set(stddevLoads[QueryDim]) } expect := StoreLoad{ Loads: expectLoads, - Count: float64(allHotPeersCount) / float64(allStoreCount), + Count: expectCount, + } + stddev := StoreLoad{ + Loads: stddevLoads, + Count: expectCount, } for _, detail := range loadDetail { detail.LoadPred.Expect = expect + detail.LoadPred.Stddev = stddev } return loadDetail } diff --git a/server/statistics/store_load.go b/server/statistics/store_load.go index e907f725815..0e8ebb4fc5d 100644 --- a/server/statistics/store_load.go +++ b/server/statistics/store_load.go @@ -73,6 +73,11 @@ func (li *StoreLoadDetail) ToHotPeersStat() *HotPeersStat { } } +// IsUniform returns true if the stores are uniform. 
+func (li *StoreLoadDetail) IsUniform(dim int, threshold float64) bool { + return li.LoadPred.Stddev.Loads[dim] < threshold +} + func toHotPeerStatShow(p *HotPeerStat, kind RWType) HotPeerStatShow { b, k, q := GetRegionStatKind(kind, ByteDim), GetRegionStatKind(kind, KeyDim), GetRegionStatKind(kind, QueryDim) byteRate := p.Loads[b] @@ -206,6 +211,7 @@ type StoreLoadPred struct { Current StoreLoad Future StoreLoad Expect StoreLoad + Stddev StoreLoad } // Min returns the min load between current and future. From e244adb2760a983d6b7f55d217b2a212a33db5a2 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 7 Jun 2022 14:02:30 +0800 Subject: [PATCH 23/82] config: migrate test framework to testify (#5103) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/config/config_test.go | 332 +++++++++++++++-------------- server/config/store_config_test.go | 57 ++--- server/config/util_test.go | 22 +- 3 files changed, 214 insertions(+), 197 deletions(-) diff --git a/server/config/config_test.go b/server/config/config_test.go index fbda46a0d39..885e24d8d8b 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -20,58 +20,47 @@ import ( "math" "os" "path" - "strings" "testing" "time" "github.com/BurntSushi/toml" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/storage" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testConfigSuite{}) - -type testConfigSuite struct{} - -func (s *testConfigSuite) SetUpSuite(c *C) { - for _, d := range DefaultSchedulers { - RegisterScheduler(d.Type) - } - RegisterScheduler("random-merge") - RegisterScheduler("shuffle-leader") -} - -func (s *testConfigSuite) TestSecurity(c *C) { +func TestSecurity(t *testing.T) { + re := require.New(t) cfg := NewConfig() - c.Assert(cfg.Security.RedactInfoLog, IsFalse) + re.False(cfg.Security.RedactInfoLog) } -func (s *testConfigSuite) TestTLS(c *C) { +func TestTLS(t *testing.T) { + re := require.New(t) cfg := NewConfig() tls, err := cfg.Security.ToTLSConfig() - c.Assert(err, IsNil) - c.Assert(tls, IsNil) + re.NoError(err) + re.Nil(tls) } -func (s *testConfigSuite) TestBadFormatJoinAddr(c *C) { +func TestBadFormatJoinAddr(t *testing.T) { + re := require.New(t) cfg := NewConfig() cfg.Join = "127.0.0.1:2379" // Wrong join addr without scheme. 
- c.Assert(cfg.Adjust(nil, false), NotNil) + re.Error(cfg.Adjust(nil, false)) } -func (s *testConfigSuite) TestReloadConfig(c *C) { +func TestReloadConfig(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() + RegisterScheduler("shuffle-leader") opt, err := newTestScheduleOption() - c.Assert(err, IsNil) + re.NoError(err) storage := storage.NewStorageWithMemoryBackend() scheduleCfg := opt.GetScheduleConfig() scheduleCfg.MaxSnapshotCount = 10 opt.SetMaxReplicas(5) opt.GetPDServerConfig().UseRegionStorage = true - c.Assert(opt.Persist(storage), IsNil) + re.NoError(opt.Persist(storage)) // Add a new default enable scheduler "shuffle-leader" DefaultSchedulers = append(DefaultSchedulers, SchedulerConfig{Type: "shuffle-leader"}) @@ -80,23 +69,25 @@ func (s *testConfigSuite) TestReloadConfig(c *C) { }() newOpt, err := newTestScheduleOption() - c.Assert(err, IsNil) - c.Assert(newOpt.Reload(storage), IsNil) + re.NoError(err) + re.NoError(newOpt.Reload(storage)) schedulers := newOpt.GetSchedulers() - c.Assert(schedulers, HasLen, len(DefaultSchedulers)) - c.Assert(newOpt.IsUseRegionStorage(), IsTrue) + re.Len(schedulers, len(DefaultSchedulers)) + re.True(newOpt.IsUseRegionStorage()) for i, s := range schedulers { - c.Assert(s.Type, Equals, DefaultSchedulers[i].Type) - c.Assert(s.Disable, IsFalse) + re.Equal(DefaultSchedulers[i].Type, s.Type) + re.False(s.Disable) } - c.Assert(newOpt.GetMaxReplicas(), Equals, 5) - c.Assert(newOpt.GetMaxSnapshotCount(), Equals, uint64(10)) - c.Assert(newOpt.GetMaxMovableHotPeerSize(), Equals, int64(512)) + re.Equal(5, newOpt.GetMaxReplicas()) + re.Equal(uint64(10), newOpt.GetMaxSnapshotCount()) + re.Equal(int64(512), newOpt.GetMaxMovableHotPeerSize()) } -func (s *testConfigSuite) TestReloadUpgrade(c *C) { +func TestReloadUpgrade(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() opt, err := newTestScheduleOption() - c.Assert(err, IsNil) + re.NoError(err) // Simulate an old configuration that only contains 2 fields. type OldConfig struct { @@ -108,17 +99,19 @@ func (s *testConfigSuite) TestReloadUpgrade(c *C) { Replication: *opt.GetReplicationConfig(), } storage := storage.NewStorageWithMemoryBackend() - c.Assert(storage.SaveConfig(old), IsNil) + re.NoError(storage.SaveConfig(old)) newOpt, err := newTestScheduleOption() - c.Assert(err, IsNil) - c.Assert(newOpt.Reload(storage), IsNil) - c.Assert(newOpt.GetPDServerConfig().KeyType, Equals, defaultKeyType) // should be set to default value. + re.NoError(err) + re.NoError(newOpt.Reload(storage)) + re.Equal(defaultKeyType, newOpt.GetPDServerConfig().KeyType) // should be set to default value. } -func (s *testConfigSuite) TestReloadUpgrade2(c *C) { +func TestReloadUpgrade2(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() opt, err := newTestScheduleOption() - c.Assert(err, IsNil) + re.NoError(err) // Simulate an old configuration that does not contain ScheduleConfig. type OldConfig struct { @@ -128,43 +121,48 @@ func (s *testConfigSuite) TestReloadUpgrade2(c *C) { Replication: *opt.GetReplicationConfig(), } storage := storage.NewStorageWithMemoryBackend() - c.Assert(storage.SaveConfig(old), IsNil) + re.NoError(storage.SaveConfig(old)) newOpt, err := newTestScheduleOption() - c.Assert(err, IsNil) - c.Assert(newOpt.Reload(storage), IsNil) - c.Assert(newOpt.GetScheduleConfig().RegionScoreFormulaVersion, Equals, "") // formulaVersion keep old value when reloading. 
+ re.NoError(err) + re.NoError(newOpt.Reload(storage)) + re.Equal("", newOpt.GetScheduleConfig().RegionScoreFormulaVersion) // formulaVersion keep old value when reloading. } -func (s *testConfigSuite) TestValidation(c *C) { +func TestValidation(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() cfg := NewConfig() - c.Assert(cfg.Adjust(nil, false), IsNil) + re.NoError(cfg.Adjust(nil, false)) cfg.Log.File.Filename = path.Join(cfg.DataDir, "test") - c.Assert(cfg.Validate(), NotNil) + re.Error(cfg.Validate()) // check schedule config cfg.Schedule.HighSpaceRatio = -0.1 - c.Assert(cfg.Schedule.Validate(), NotNil) + re.Error(cfg.Schedule.Validate()) cfg.Schedule.HighSpaceRatio = 0.6 - c.Assert(cfg.Schedule.Validate(), IsNil) + re.NoError(cfg.Schedule.Validate()) cfg.Schedule.LowSpaceRatio = 1.1 - c.Assert(cfg.Schedule.Validate(), NotNil) + re.Error(cfg.Schedule.Validate()) cfg.Schedule.LowSpaceRatio = 0.4 - c.Assert(cfg.Schedule.Validate(), NotNil) + re.Error(cfg.Schedule.Validate()) cfg.Schedule.LowSpaceRatio = 0.8 - c.Assert(cfg.Schedule.Validate(), IsNil) + re.NoError(cfg.Schedule.Validate()) cfg.Schedule.TolerantSizeRatio = -0.6 - c.Assert(cfg.Schedule.Validate(), NotNil) + re.Error(cfg.Schedule.Validate()) // check quota - c.Assert(cfg.QuotaBackendBytes, Equals, defaultQuotaBackendBytes) + re.Equal(defaultQuotaBackendBytes, cfg.QuotaBackendBytes) // check request bytes - c.Assert(cfg.MaxRequestBytes, Equals, defaultMaxRequestBytes) + re.Equal(defaultMaxRequestBytes, cfg.MaxRequestBytes) - c.Assert(cfg.Log.Format, Equals, defaultLogFormat) + re.Equal(defaultLogFormat, cfg.Log.Format) } -func (s *testConfigSuite) TestAdjust(c *C) { +func TestAdjust(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() + RegisterScheduler("random-merge") cfgData := ` name = "" lease = 0 @@ -180,27 +178,27 @@ leader-schedule-limit = 0 ` cfg := NewConfig() meta, err := toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) + re.NoError(err) // When invalid, use default values. host, err := os.Hostname() - c.Assert(err, IsNil) - c.Assert(cfg.Name, Equals, fmt.Sprintf("%s-%s", defaultName, host)) - c.Assert(cfg.LeaderLease, Equals, defaultLeaderLease) - c.Assert(cfg.MaxRequestBytes, Equals, uint(20000000)) + re.NoError(err) + re.Equal(fmt.Sprintf("%s-%s", defaultName, host), cfg.Name) + re.Equal(defaultLeaderLease, cfg.LeaderLease) + re.Equal(uint(20000000), cfg.MaxRequestBytes) // When defined, use values from config file. - c.Assert(cfg.Schedule.MaxMergeRegionSize, Equals, uint64(0)) - c.Assert(cfg.Schedule.EnableOneWayMerge, IsTrue) - c.Assert(cfg.Schedule.LeaderScheduleLimit, Equals, uint64(0)) + re.Equal(uint64(0), cfg.Schedule.MaxMergeRegionSize) + re.True(cfg.Schedule.EnableOneWayMerge) + re.Equal(uint64(0), cfg.Schedule.LeaderScheduleLimit) // When undefined, use default values. 
- c.Assert(cfg.PreVote, IsTrue) - c.Assert(cfg.Log.Level, Equals, "info") - c.Assert(cfg.Schedule.MaxMergeRegionKeys, Equals, uint64(defaultMaxMergeRegionKeys)) - c.Assert(cfg.PDServerCfg.MetricStorage, Equals, "http://127.0.0.1:9090") + re.True(cfg.PreVote) + re.Equal("info", cfg.Log.Level) + re.Equal(uint64(defaultMaxMergeRegionKeys), cfg.Schedule.MaxMergeRegionKeys) + re.Equal("http://127.0.0.1:9090", cfg.PDServerCfg.MetricStorage) - c.Assert(cfg.TSOUpdatePhysicalInterval.Duration, Equals, DefaultTSOUpdatePhysicalInterval) + re.Equal(DefaultTSOUpdatePhysicalInterval, cfg.TSOUpdatePhysicalInterval.Duration) // Check undefined config fields cfgData = ` @@ -213,10 +211,10 @@ type = "random-merge" ` cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) - c.Assert(strings.Contains(cfg.WarningMsgs[0], "Config contains undefined item"), IsTrue) + re.NoError(err) + re.Contains(cfg.WarningMsgs[0], "Config contains undefined item") // Check misspelled schedulers name cfgData = ` @@ -228,9 +226,9 @@ type = "random-merge-schedulers" ` cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, NotNil) + re.Error(err) // Check correct schedulers name cfgData = ` @@ -242,9 +240,9 @@ type = "random-merge" ` cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) + re.NoError(err) cfgData = ` [metric] @@ -253,12 +251,12 @@ address = "localhost:9090" ` cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) + re.NoError(err) - c.Assert(cfg.Metric.PushInterval.Duration, Equals, 35*time.Second) - c.Assert(cfg.Metric.PushAddress, Equals, "localhost:9090") + re.Equal(35*time.Second, cfg.Metric.PushInterval.Duration) + re.Equal("localhost:9090", cfg.Metric.PushAddress) // Test clamping TSOUpdatePhysicalInterval value cfgData = ` @@ -266,29 +264,31 @@ tso-update-physical-interval = "10ms" ` cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) + re.NoError(err) - c.Assert(cfg.TSOUpdatePhysicalInterval.Duration, Equals, minTSOUpdatePhysicalInterval) + re.Equal(minTSOUpdatePhysicalInterval, cfg.TSOUpdatePhysicalInterval.Duration) cfgData = ` tso-update-physical-interval = "15s" ` cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) + re.NoError(err) - c.Assert(cfg.TSOUpdatePhysicalInterval.Duration, Equals, maxTSOUpdatePhysicalInterval) + re.Equal(maxTSOUpdatePhysicalInterval, cfg.TSOUpdatePhysicalInterval.Duration) } -func (s *testConfigSuite) TestMigrateFlags(c *C) { +func TestMigrateFlags(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() load := func(s string) (*Config, error) { cfg := NewConfig() meta, err := toml.Decode(s, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) return cfg, err } @@ -301,35 +301,28 @@ enable-make-up-replica = false disable-remove-extra-replica = true enable-remove-extra-replica = false `) - c.Assert(err, IsNil) - c.Assert(cfg.PDServerCfg.FlowRoundByDigit, Equals, math.MaxInt8) - c.Assert(cfg.Schedule.EnableReplaceOfflineReplica, IsTrue) - 
c.Assert(cfg.Schedule.EnableRemoveDownReplica, IsFalse) - c.Assert(cfg.Schedule.EnableMakeUpReplica, IsFalse) - c.Assert(cfg.Schedule.EnableRemoveExtraReplica, IsFalse) + re.NoError(err) + re.Equal(math.MaxInt8, cfg.PDServerCfg.FlowRoundByDigit) + re.True(cfg.Schedule.EnableReplaceOfflineReplica) + re.False(cfg.Schedule.EnableRemoveDownReplica) + re.False(cfg.Schedule.EnableMakeUpReplica) + re.False(cfg.Schedule.EnableRemoveExtraReplica) b, err := json.Marshal(cfg) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(b), "disable-replace-offline-replica"), IsFalse) - c.Assert(strings.Contains(string(b), "disable-remove-down-replica"), IsFalse) + re.NoError(err) + re.NotContains(string(b), "disable-replace-offline-replica") + re.NotContains(string(b), "disable-remove-down-replica") _, err = load(` [schedule] enable-make-up-replica = false disable-make-up-replica = false `) - c.Assert(err, NotNil) + re.Error(err) } -func newTestScheduleOption() (*PersistOptions, error) { - cfg := NewConfig() - if err := cfg.Adjust(nil, false); err != nil { - return nil, err - } - opt := NewPersistOptions(cfg) - return opt, nil -} - -func (s *testConfigSuite) TestPDServerConfig(c *C) { +func TestPDServerConfig(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() tests := []struct { cfgData string hasErr bool @@ -382,19 +375,21 @@ dashboard-address = "foo" }, } - for _, t := range tests { + for _, test := range tests { cfg := NewConfig() - meta, err := toml.Decode(t.cfgData, &cfg) - c.Assert(err, IsNil) + meta, err := toml.Decode(test.cfgData, &cfg) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err != nil, Equals, t.hasErr) - if !t.hasErr { - c.Assert(cfg.PDServerCfg.DashboardAddress, Equals, t.dashboardAddress) + re.Equal(test.hasErr, err != nil) + if !test.hasErr { + re.Equal(test.dashboardAddress, cfg.PDServerCfg.DashboardAddress) } } } -func (s *testConfigSuite) TestDashboardConfig(c *C) { +func TestDashboardConfig(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() cfgData := ` [dashboard] tidb-cacert-path = "/path/ca.pem" @@ -403,12 +398,12 @@ tidb-cert-path = "/path/client.pem" ` cfg := NewConfig() meta, err := toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) - c.Assert(cfg.Dashboard.TiDBCAPath, Equals, "/path/ca.pem") - c.Assert(cfg.Dashboard.TiDBKeyPath, Equals, "/path/client-key.pem") - c.Assert(cfg.Dashboard.TiDBCertPath, Equals, "/path/client.pem") + re.NoError(err) + re.Equal("/path/ca.pem", cfg.Dashboard.TiDBCAPath) + re.Equal("/path/client-key.pem", cfg.Dashboard.TiDBKeyPath) + re.Equal("/path/client.pem", cfg.Dashboard.TiDBCertPath) // Test different editions tests := []struct { @@ -424,15 +419,17 @@ tidb-cert-path = "/path/client.pem" initByLDFlags(test.Edition) cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) - c.Assert(cfg.Dashboard.EnableTelemetry, Equals, test.EnableTelemetry) + re.NoError(err) + re.Equal(test.EnableTelemetry, cfg.Dashboard.EnableTelemetry) } defaultEnableTelemetry = originalDefaultEnableTelemetry } -func (s *testConfigSuite) TestReplicationMode(c *C) { +func TestReplicationMode(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() cfgData := ` [replication-mode] replication-mode = "dr-auto-sync" @@ -446,28 +443,30 @@ wait-store-timeout = "120s" ` cfg := NewConfig() meta, err := toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + 
re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) + re.NoError(err) - c.Assert(cfg.ReplicationMode.ReplicationMode, Equals, "dr-auto-sync") - c.Assert(cfg.ReplicationMode.DRAutoSync.LabelKey, Equals, "zone") - c.Assert(cfg.ReplicationMode.DRAutoSync.Primary, Equals, "zone1") - c.Assert(cfg.ReplicationMode.DRAutoSync.DR, Equals, "zone2") - c.Assert(cfg.ReplicationMode.DRAutoSync.PrimaryReplicas, Equals, 2) - c.Assert(cfg.ReplicationMode.DRAutoSync.DRReplicas, Equals, 1) - c.Assert(cfg.ReplicationMode.DRAutoSync.WaitStoreTimeout.Duration, Equals, 2*time.Minute) - c.Assert(cfg.ReplicationMode.DRAutoSync.WaitSyncTimeout.Duration, Equals, time.Minute) + re.Equal("dr-auto-sync", cfg.ReplicationMode.ReplicationMode) + re.Equal("zone", cfg.ReplicationMode.DRAutoSync.LabelKey) + re.Equal("zone1", cfg.ReplicationMode.DRAutoSync.Primary) + re.Equal("zone2", cfg.ReplicationMode.DRAutoSync.DR) + re.Equal(2, cfg.ReplicationMode.DRAutoSync.PrimaryReplicas) + re.Equal(1, cfg.ReplicationMode.DRAutoSync.DRReplicas) + re.Equal(2*time.Minute, cfg.ReplicationMode.DRAutoSync.WaitStoreTimeout.Duration) + re.Equal(time.Minute, cfg.ReplicationMode.DRAutoSync.WaitSyncTimeout.Duration) cfg = NewConfig() meta, err = toml.Decode("", &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) - c.Assert(cfg.ReplicationMode.ReplicationMode, Equals, "majority") + re.NoError(err) + re.Equal("majority", cfg.ReplicationMode.ReplicationMode) } -func (s *testConfigSuite) TestHotHistoryRegionConfig(c *C) { +func TestHotHistoryRegionConfig(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() cfgData := ` [schedule] hot-regions-reserved-days= 30 @@ -475,39 +474,56 @@ hot-regions-write-interval= "30m" ` cfg := NewConfig() meta, err := toml.Decode(cfgData, &cfg) - c.Assert(err, IsNil) + re.NoError(err) err = cfg.Adjust(&meta, false) - c.Assert(err, IsNil) - c.Assert(cfg.Schedule.HotRegionsWriteInterval.Duration, Equals, 30*time.Minute) - c.Assert(cfg.Schedule.HotRegionsReservedDays, Equals, uint64(30)) + re.NoError(err) + re.Equal(30*time.Minute, cfg.Schedule.HotRegionsWriteInterval.Duration) + re.Equal(uint64(30), cfg.Schedule.HotRegionsReservedDays) // Verify default value cfg = NewConfig() err = cfg.Adjust(nil, false) - c.Assert(err, IsNil) - c.Assert(cfg.Schedule.HotRegionsWriteInterval.Duration, Equals, 10*time.Minute) - c.Assert(cfg.Schedule.HotRegionsReservedDays, Equals, uint64(7)) + re.NoError(err) + re.Equal(10*time.Minute, cfg.Schedule.HotRegionsWriteInterval.Duration) + re.Equal(uint64(7), cfg.Schedule.HotRegionsReservedDays) } -func (s *testConfigSuite) TestConfigClone(c *C) { +func TestConfigClone(t *testing.T) { + re := require.New(t) + registerDefaultSchedulers() cfg := &Config{} cfg.Adjust(nil, false) - c.Assert(cfg.Clone(), DeepEquals, cfg) + re.Equal(cfg, cfg.Clone()) emptyConfigMetaData := newConfigMetadata(nil) schedule := &ScheduleConfig{} schedule.adjust(emptyConfigMetaData, false) - c.Assert(schedule.Clone(), DeepEquals, schedule) + re.Equal(schedule, schedule.Clone()) replication := &ReplicationConfig{} replication.adjust(emptyConfigMetaData) - c.Assert(replication.Clone(), DeepEquals, replication) + re.Equal(replication, replication.Clone()) pdServer := &PDServerConfig{} pdServer.adjust(emptyConfigMetaData) - c.Assert(pdServer.Clone(), DeepEquals, pdServer) + re.Equal(pdServer, pdServer.Clone()) replicationMode := &ReplicationModeConfig{} replicationMode.adjust(emptyConfigMetaData) - c.Assert(replicationMode.Clone(), DeepEquals, 
replicationMode) + re.Equal(replicationMode, replicationMode.Clone()) +} + +func newTestScheduleOption() (*PersistOptions, error) { + cfg := NewConfig() + if err := cfg.Adjust(nil, false); err != nil { + return nil, err + } + opt := NewPersistOptions(cfg) + return opt, nil +} + +func registerDefaultSchedulers() { + for _, d := range DefaultSchedulers { + RegisterScheduler(d.Type) + } } diff --git a/server/config/store_config_test.go b/server/config/store_config_test.go index 478e1ebb3d7..8d06fe029f7 100644 --- a/server/config/store_config_test.go +++ b/server/config/store_config_test.go @@ -18,15 +18,13 @@ import ( "crypto/tls" "encoding/json" "net/http" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testTiKVConfigSuite{}) - -type testTiKVConfigSuite struct{} - -func (t *testTiKVConfigSuite) TestTiKVConfig(c *C) { +func TestTiKVConfig(t *testing.T) { + re := require.New(t) // case1: big region. { body := `{ "coprocessor": { @@ -40,32 +38,33 @@ func (t *testTiKVConfigSuite) TestTiKVConfig(c *C) { "perf-level": 2 }}` var config StoreConfig - c.Assert(json.Unmarshal([]byte(body), &config), IsNil) + re.NoError(json.Unmarshal([]byte(body), &config)) - c.Assert(config.GetRegionMaxKeys(), Equals, uint64(144000000)) - c.Assert(config.GetRegionSplitKeys(), Equals, uint64(96000000)) - c.Assert(int(config.GetRegionMaxSize()), Equals, 15*1024) - c.Assert(config.GetRegionSplitSize(), Equals, uint64(10*1024)) + re.Equal(uint64(144000000), config.GetRegionMaxKeys()) + re.Equal(uint64(96000000), config.GetRegionSplitKeys()) + re.Equal(15*1024, int(config.GetRegionMaxSize())) + re.Equal(uint64(10*1024), config.GetRegionSplitSize()) } //case2: empty config. { body := `{}` var config StoreConfig - c.Assert(json.Unmarshal([]byte(body), &config), IsNil) + re.NoError(json.Unmarshal([]byte(body), &config)) - c.Assert(config.GetRegionMaxKeys(), Equals, uint64(1440000)) - c.Assert(config.GetRegionSplitKeys(), Equals, uint64(960000)) - c.Assert(int(config.GetRegionMaxSize()), Equals, 144) - c.Assert(config.GetRegionSplitSize(), Equals, uint64(96)) + re.Equal(uint64(1440000), config.GetRegionMaxKeys()) + re.Equal(uint64(960000), config.GetRegionSplitKeys()) + re.Equal(144, int(config.GetRegionMaxSize())) + re.Equal(uint64(96), config.GetRegionSplitSize()) } } -func (t *testTiKVConfigSuite) TestUpdateConfig(c *C) { +func TestUpdateConfig(t *testing.T) { + re := require.New(t) manager := NewTestStoreConfigManager([]string{"tidb.com"}) manager.ObserveConfig("tikv.com") - c.Assert(manager.GetStoreConfig().GetRegionMaxSize(), Equals, uint64(144)) + re.Equal(uint64(144), manager.GetStoreConfig().GetRegionMaxSize()) manager.ObserveConfig("tidb.com") - c.Assert(manager.GetStoreConfig().GetRegionMaxSize(), Equals, uint64(10)) + re.Equal(uint64(10), manager.GetStoreConfig().GetRegionMaxSize()) client := &http.Client{ Transport: &http.Transport{ @@ -74,10 +73,11 @@ func (t *testTiKVConfigSuite) TestUpdateConfig(c *C) { }, } manager = NewStoreConfigManager(client) - c.Assert(manager.source.(*TiKVConfigSource).schema, Equals, "http") + re.Equal("http", manager.source.(*TiKVConfigSource).schema) } -func (t *testTiKVConfigSuite) TestParseConfig(c *C) { +func TestParseConfig(t *testing.T) { + re := require.New(t) body := ` { "coprocessor":{ @@ -97,11 +97,12 @@ func (t *testTiKVConfigSuite) TestParseConfig(c *C) { ` var config StoreConfig - c.Assert(json.Unmarshal([]byte(body), &config), IsNil) - c.Assert(config.GetRegionBucketSize(), Equals, uint64(96)) + 
re.NoError(json.Unmarshal([]byte(body), &config)) + re.Equal(uint64(96), config.GetRegionBucketSize()) } -func (t *testTiKVConfigSuite) TestMergeCheck(c *C) { +func TestMergeCheck(t *testing.T) { + re := require.New(t) testdata := []struct { size uint64 mergeSize uint64 @@ -140,11 +141,11 @@ func (t *testTiKVConfigSuite) TestMergeCheck(c *C) { config := &StoreConfig{} for _, v := range testdata { if v.pass { - c.Assert(config.CheckRegionSize(v.size, v.mergeSize), IsNil) - c.Assert(config.CheckRegionKeys(v.keys, v.mergeKeys), IsNil) + re.NoError(config.CheckRegionSize(v.size, v.mergeSize)) + re.NoError(config.CheckRegionKeys(v.keys, v.mergeKeys)) } else { - c.Assert(config.CheckRegionSize(v.size, v.mergeSize), NotNil) - c.Assert(config.CheckRegionKeys(v.keys, v.mergeKeys), NotNil) + re.Error(config.CheckRegionSize(v.size, v.mergeSize)) + re.Error(config.CheckRegionKeys(v.keys, v.mergeKeys)) } } } diff --git a/server/config/util_test.go b/server/config/util_test.go index 6b411d51eaa..6327d465bd6 100644 --- a/server/config/util_test.go +++ b/server/config/util_test.go @@ -15,15 +15,14 @@ package config import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testUtilSuite{}) - -type testUtilSuite struct{} - -func (s *testUtilSuite) TestValidateLabels(c *C) { +func TestValidateLabels(t *testing.T) { + re := require.New(t) tests := []struct { label string hasErr bool @@ -51,12 +50,13 @@ func (s *testUtilSuite) TestValidateLabels(c *C) { {"a$b", true}, {"$$", true}, } - for _, t := range tests { - c.Assert(ValidateLabels([]*metapb.StoreLabel{{Key: t.label}}) != nil, Equals, t.hasErr) + for _, test := range tests { + re.Equal(test.hasErr, ValidateLabels([]*metapb.StoreLabel{{Key: test.label}}) != nil) } } -func (s *testUtilSuite) TestValidateURLWithScheme(c *C) { +func TestValidateURLWithScheme(t *testing.T) { + re := require.New(t) tests := []struct { addr string hasErr bool @@ -73,7 +73,7 @@ func (s *testUtilSuite) TestValidateURLWithScheme(c *C) { {"https://foo.com/bar", false}, {"https://foo.com/bar/", false}, } - for _, t := range tests { - c.Assert(ValidateURLWithScheme(t.addr) != nil, Equals, t.hasErr) + for _, test := range tests { + re.Equal(test.hasErr, ValidateURLWithScheme(test.addr) != nil) } } From 4133b9ff199d24ee52e123b6644c4929550074b7 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 7 Jun 2022 14:14:30 +0800 Subject: [PATCH 24/82] pkg: parallelize the pkg tests (#5094) ref tikv/pd#5087 Parallelize the `pkg` tests as much as possible. 
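The pattern applied to the parallelized tests is mechanical; a minimal sketch follows, assuming a hypothetical package and test name: each test calls `t.Parallel()` before doing any work and keeps its own `require.New(t)` so assertions stay bound to the right `*testing.T`.

package example

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestSomething(t *testing.T) {
	t.Parallel() // may run concurrently with the other parallel tests in this package
	re := require.New(t)
	re.Equal(4, 2+2)
}

Because parallel tests share package-level state, the same patch also replaces the plain `schedulerMap` in `server/config/util.go` with a `sync.Map`.
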
Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- pkg/apiutil/apiutil_test.go | 2 ++ pkg/assertutil/assertutil_test.go | 1 + pkg/audit/audit_test.go | 3 +++ pkg/autoscaling/calculation_test.go | 5 +++++ pkg/autoscaling/prometheus_test.go | 6 ++++++ pkg/cache/cache_test.go | 5 +++++ pkg/codec/codec_test.go | 2 ++ pkg/encryption/config_test.go | 4 ++++ pkg/encryption/crypter_test.go | 5 +++++ pkg/encryption/master_key_test.go | 8 ++++++++ pkg/encryption/region_crypter_test.go | 8 ++++++++ pkg/errs/errs_test.go | 2 ++ pkg/etcdutil/etcdutil_test.go | 3 +++ pkg/grpcutil/grpcutil_test.go | 1 + pkg/keyutil/util_test.go | 1 + pkg/logutil/log_test.go | 2 ++ pkg/metricutil/metricutil_test.go | 1 + pkg/mock/mockhbstream/mockhbstream_test.go | 1 + pkg/movingaverage/avg_over_time_test.go | 4 ++++ pkg/movingaverage/max_filter_test.go | 1 + pkg/movingaverage/moving_average_test.go | 2 ++ pkg/movingaverage/queue_test.go | 2 ++ pkg/netutil/address_test.go | 2 ++ pkg/progress/progress_test.go | 2 ++ pkg/rangetree/range_tree_test.go | 2 ++ pkg/ratelimit/concurrency_limiter_test.go | 1 + pkg/ratelimit/limiter_test.go | 5 +++++ pkg/ratelimit/ratelimiter_test.go | 1 + pkg/reflectutil/tag_test.go | 3 +++ pkg/requestutil/context_test.go | 2 ++ pkg/slice/slice_test.go | 2 ++ pkg/typeutil/comparison_test.go | 3 +++ pkg/typeutil/conversion_test.go | 3 +++ pkg/typeutil/duration_test.go | 2 ++ pkg/typeutil/size_test.go | 2 ++ pkg/typeutil/string_slice_test.go | 2 ++ pkg/typeutil/time_test.go | 3 +++ server/config/util.go | 7 ++++--- 38 files changed, 108 insertions(+), 3 deletions(-) diff --git a/pkg/apiutil/apiutil_test.go b/pkg/apiutil/apiutil_test.go index 8a79edc9784..bbbb3b860fb 100644 --- a/pkg/apiutil/apiutil_test.go +++ b/pkg/apiutil/apiutil_test.go @@ -25,6 +25,7 @@ import ( ) func TestJsonRespondErrorOk(t *testing.T) { + t.Parallel() re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, @@ -43,6 +44,7 @@ func TestJsonRespondErrorOk(t *testing.T) { } func TestJsonRespondErrorBadInput(t *testing.T) { + t.Parallel() re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, diff --git a/pkg/assertutil/assertutil_test.go b/pkg/assertutil/assertutil_test.go index 324e403f7b6..8da2ad2b164 100644 --- a/pkg/assertutil/assertutil_test.go +++ b/pkg/assertutil/assertutil_test.go @@ -22,6 +22,7 @@ import ( ) func TestNilFail(t *testing.T) { + t.Parallel() re := require.New(t) var failErr error checker := NewChecker(func() { diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index 2b33b62ca55..86da530f81a 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -32,6 +32,7 @@ import ( ) func TestLabelMatcher(t *testing.T) { + t.Parallel() re := require.New(t) matcher := &LabelMatcher{"testSuccess"} labels1 := &BackendLabels{Labels: []string{"testFail", "testSuccess"}} @@ -41,6 +42,7 @@ func TestLabelMatcher(t *testing.T) { } func TestPrometheusHistogramBackend(t *testing.T) { + t.Parallel() re := require.New(t) serviceAuditHistogramTest := prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -87,6 +89,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { } func TestLocalLogBackendUsingFile(t *testing.T) { + t.Parallel() re := require.New(t) backend := NewLocalLogBackend(true) fname := initLog() diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index 5f9eaaf9767..f5db53fdabd 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -29,6 +29,7 @@ import ( ) func 
TestGetScaledTiKVGroups(t *testing.T) { + t.Parallel() re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -213,6 +214,7 @@ func (q *mockQuerier) Query(options *QueryOptions) (QueryResult, error) { } func TestGetTotalCPUUseTime(t *testing.T) { + t.Parallel() re := require.New(t) querier := &mockQuerier{} instances := []instance{ @@ -235,6 +237,7 @@ func TestGetTotalCPUUseTime(t *testing.T) { } func TestGetTotalCPUQuota(t *testing.T) { + t.Parallel() re := require.New(t) querier := &mockQuerier{} instances := []instance{ @@ -257,6 +260,7 @@ func TestGetTotalCPUQuota(t *testing.T) { } func TestScaleOutGroupLabel(t *testing.T) { + t.Parallel() re := require.New(t) var jsonStr = []byte(` { @@ -299,6 +303,7 @@ func TestScaleOutGroupLabel(t *testing.T) { } func TestStrategyChangeCount(t *testing.T) { + t.Parallel() re := require.New(t) var count uint64 = 2 strategy := &Strategy{ diff --git a/pkg/autoscaling/prometheus_test.go b/pkg/autoscaling/prometheus_test.go index 2906645b180..f9a38d6071d 100644 --- a/pkg/autoscaling/prometheus_test.go +++ b/pkg/autoscaling/prometheus_test.go @@ -181,6 +181,7 @@ func (c *normalClient) Do(_ context.Context, req *http.Request) (response *http. } func TestRetrieveCPUMetrics(t *testing.T) { + t.Parallel() re := require.New(t) client := &normalClient{ mockData: make(map[string]*response), @@ -225,6 +226,7 @@ func (c *emptyResponseClient) Do(_ context.Context, req *http.Request) (r *http. } func TestEmptyResponse(t *testing.T) { + t.Parallel() re := require.New(t) client := &emptyResponseClient{} querier := NewPrometheusQuerier(client) @@ -252,6 +254,7 @@ func (c *errorHTTPStatusClient) Do(_ context.Context, req *http.Request) (r *htt } func TestErrorHTTPStatus(t *testing.T) { + t.Parallel() re := require.New(t) client := &errorHTTPStatusClient{} querier := NewPrometheusQuerier(client) @@ -277,6 +280,7 @@ func (c *errorPrometheusStatusClient) Do(_ context.Context, req *http.Request) ( } func TestErrorPrometheusStatus(t *testing.T) { + t.Parallel() re := require.New(t) client := &errorPrometheusStatusClient{} querier := NewPrometheusQuerier(client) @@ -287,6 +291,7 @@ func TestErrorPrometheusStatus(t *testing.T) { } func TestGetInstanceNameFromAddress(t *testing.T) { + t.Parallel() re := require.New(t) testCases := []struct { address string @@ -324,6 +329,7 @@ func TestGetInstanceNameFromAddress(t *testing.T) { } func TestGetDurationExpression(t *testing.T) { + t.Parallel() re := require.New(t) testCases := []struct { duration time.Duration diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index ef257157006..05db409f6dc 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -24,6 +24,7 @@ import ( ) func TestExpireRegionCache(t *testing.T) { + t.Parallel() re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -115,6 +116,7 @@ func sortIDs(ids []uint64) []uint64 { } func TestLRUCache(t *testing.T) { + t.Parallel() re := require.New(t) cache := newLRU(3) @@ -192,6 +194,7 @@ func TestLRUCache(t *testing.T) { } func TestFifoCache(t *testing.T) { + t.Parallel() re := require.New(t) cache := NewFIFO(3) cache.Put(1, "1") @@ -219,6 +222,7 @@ func TestFifoCache(t *testing.T) { } func TestTwoQueueCache(t *testing.T) { + t.Parallel() re := require.New(t) cache := newTwoQueue(3) cache.Put(1, "1") @@ -303,6 +307,7 @@ func (pq PriorityQueueItemTest) ID() uint64 { } func TestPriorityQueue(t *testing.T) { + t.Parallel() re := require.New(t) testData := 
[]PriorityQueueItemTest{0, 1, 2, 3, 4, 5} pq := NewPriorityQueue(0) diff --git a/pkg/codec/codec_test.go b/pkg/codec/codec_test.go index 50bf552a60d..f734d2e528e 100644 --- a/pkg/codec/codec_test.go +++ b/pkg/codec/codec_test.go @@ -21,6 +21,7 @@ import ( ) func TestDecodeBytes(t *testing.T) { + t.Parallel() re := require.New(t) key := "abcdefghijklmnopqrstuvwxyz" for i := 0; i < len(key); i++ { @@ -31,6 +32,7 @@ func TestDecodeBytes(t *testing.T) { } func TestTableID(t *testing.T) { + t.Parallel() re := require.New(t) key := EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff")) re.Equal(int64(0xff), key.TableID()) diff --git a/pkg/encryption/config_test.go b/pkg/encryption/config_test.go index 1e3231b0903..a357c344eab 100644 --- a/pkg/encryption/config_test.go +++ b/pkg/encryption/config_test.go @@ -23,6 +23,7 @@ import ( ) func TestAdjustDefaultValue(t *testing.T) { + t.Parallel() re := require.New(t) config := &Config{} err := config.Adjust() @@ -34,18 +35,21 @@ func TestAdjustDefaultValue(t *testing.T) { } func TestAdjustInvalidDataEncryptionMethod(t *testing.T) { + t.Parallel() re := require.New(t) config := &Config{DataEncryptionMethod: "unknown"} re.NotNil(config.Adjust()) } func TestAdjustNegativeRotationDuration(t *testing.T) { + t.Parallel() re := require.New(t) config := &Config{DataKeyRotationPeriod: typeutil.NewDuration(time.Duration(int64(-1)))} re.NotNil(config.Adjust()) } func TestAdjustInvalidMasterKeyType(t *testing.T) { + t.Parallel() re := require.New(t) config := &Config{MasterKey: MasterKeyConfig{Type: "unknown"}} re.NotNil(config.Adjust()) diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index c29ed6a8725..e8b7e06bcdf 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -24,6 +24,7 @@ import ( ) func TestEncryptionMethodSupported(t *testing.T) { + t.Parallel() re := require.New(t) re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) @@ -33,6 +34,7 @@ func TestEncryptionMethodSupported(t *testing.T) { } func TestKeyLength(t *testing.T) { + t.Parallel() re := require.New(t) _, err := KeyLength(encryptionpb.EncryptionMethod_PLAINTEXT) re.NotNil(err) @@ -50,6 +52,7 @@ func TestKeyLength(t *testing.T) { } func TestNewIv(t *testing.T) { + t.Parallel() re := require.New(t) ivCtr, err := NewIvCTR() re.NoError(err) @@ -60,6 +63,7 @@ func TestNewIv(t *testing.T) { } func TestNewDataKey(t *testing.T) { + t.Parallel() re := require.New(t) for _, method := range []encryptionpb.EncryptionMethod{ encryptionpb.EncryptionMethod_AES128_CTR, @@ -78,6 +82,7 @@ func TestNewDataKey(t *testing.T) { } func TestAesGcmCrypter(t *testing.T) { + t.Parallel() re := require.New(t) key, err := hex.DecodeString("ed568fbd8c8018ed2d042a4e5d38d6341486922d401d2022fb81e47c900d3f07") re.NoError(err) diff --git a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index 990d6322c3e..b8d5657c1fc 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -24,6 +24,7 @@ import ( ) func TestPlaintextMasterKey(t *testing.T) { + t.Parallel() re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_Plaintext{ @@ -49,6 +50,7 @@ func TestPlaintextMasterKey(t *testing.T) { } func TestEncrypt(t *testing.T) { + t.Parallel() re := require.New(t) keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) @@ -64,6 
+66,7 @@ func TestEncrypt(t *testing.T) { } func TestDecrypt(t *testing.T) { + t.Parallel() re := require.New(t) keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) @@ -80,6 +83,7 @@ func TestDecrypt(t *testing.T) { } func TestNewFileMasterKeyMissingPath(t *testing.T) { + t.Parallel() re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -93,6 +97,7 @@ func TestNewFileMasterKeyMissingPath(t *testing.T) { } func TestNewFileMasterKeyMissingFile(t *testing.T) { + t.Parallel() re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") re.NoError(err) @@ -109,6 +114,7 @@ func TestNewFileMasterKeyMissingFile(t *testing.T) { } func TestNewFileMasterKeyNotHexString(t *testing.T) { + t.Parallel() re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") re.NoError(err) @@ -126,6 +132,7 @@ func TestNewFileMasterKeyNotHexString(t *testing.T) { } func TestNewFileMasterKeyLengthMismatch(t *testing.T) { + t.Parallel() re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") re.NoError(err) @@ -143,6 +150,7 @@ func TestNewFileMasterKeyLengthMismatch(t *testing.T) { } func TestNewFileMasterKey(t *testing.T) { + t.Parallel() re := require.New(t) key := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" dir, err := os.MkdirTemp("", "test_key_files") diff --git a/pkg/encryption/region_crypter_test.go b/pkg/encryption/region_crypter_test.go index b1ca558063c..5fd9778a8c0 100644 --- a/pkg/encryption/region_crypter_test.go +++ b/pkg/encryption/region_crypter_test.go @@ -70,6 +70,7 @@ func (m *testKeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { } func TestNilRegion(t *testing.T) { + t.Parallel() re := require.New(t) m := newTestKeyManager() region, err := EncryptRegion(nil, m) @@ -80,6 +81,7 @@ func TestNilRegion(t *testing.T) { } func TestEncryptRegionWithoutKeyManager(t *testing.T) { + t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -96,6 +98,7 @@ func TestEncryptRegionWithoutKeyManager(t *testing.T) { } func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { + t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -114,6 +117,7 @@ func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { } func TestEncryptRegion(t *testing.T) { + t.Parallel() re := require.New(t) startKey := []byte("abc") endKey := []byte("xyz") @@ -148,6 +152,7 @@ func TestEncryptRegion(t *testing.T) { } func TestDecryptRegionNotEncrypted(t *testing.T) { + t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -165,6 +170,7 @@ func TestDecryptRegionNotEncrypted(t *testing.T) { } func TestDecryptRegionWithoutKeyManager(t *testing.T) { + t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -180,6 +186,7 @@ func TestDecryptRegionWithoutKeyManager(t *testing.T) { } func TestDecryptRegionWhileKeyMissing(t *testing.T) { + t.Parallel() re := require.New(t) keyID := uint64(3) m := newTestKeyManager() @@ -200,6 +207,7 @@ func TestDecryptRegionWhileKeyMissing(t *testing.T) { } func TestDecryptRegion(t *testing.T) { + t.Parallel() re := require.New(t) keyID := uint64(1) startKey := []byte("abc") diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index 74e55257d70..4556898d9fd 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -87,6 +87,7 @@ func TestError(t *testing.T) { } func TestErrorEqual(t *testing.T) { + t.Parallel() re := require.New(t) err1 := 
ErrSchedulerNotFound.FastGenByArgs() err2 := ErrSchedulerNotFound.FastGenByArgs() @@ -123,6 +124,7 @@ func TestZapError(t *testing.T) { } func TestErrorWithStack(t *testing.T) { + t.Parallel() re := require.New(t) conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) diff --git a/pkg/etcdutil/etcdutil_test.go b/pkg/etcdutil/etcdutil_test.go index 7731a319a94..bbb8e595c32 100644 --- a/pkg/etcdutil/etcdutil_test.go +++ b/pkg/etcdutil/etcdutil_test.go @@ -28,6 +28,7 @@ import ( ) func TestMemberHelpers(t *testing.T) { + t.Parallel() re := require.New(t) cfg1 := NewTestSingleConfig() etcd1, err := embed.StartEtcd(cfg1) @@ -110,6 +111,7 @@ func TestMemberHelpers(t *testing.T) { } func TestEtcdKVGet(t *testing.T) { + t.Parallel() re := require.New(t) cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) @@ -161,6 +163,7 @@ func TestEtcdKVGet(t *testing.T) { } func TestEtcdKVPutWithTTL(t *testing.T) { + t.Parallel() re := require.New(t) cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) diff --git a/pkg/grpcutil/grpcutil_test.go b/pkg/grpcutil/grpcutil_test.go index 44eee64b85e..7e9396fdd0b 100644 --- a/pkg/grpcutil/grpcutil_test.go +++ b/pkg/grpcutil/grpcutil_test.go @@ -21,6 +21,7 @@ func loadTLSContent(re *require.Assertions, caPath, certPath, keyPath string) (c } func TestToTLSConfig(t *testing.T) { + t.Parallel() re := require.New(t) tlsConfig := TLSConfig{ KeyPath: "../../tests/client/cert/pd-server-key.pem", diff --git a/pkg/keyutil/util_test.go b/pkg/keyutil/util_test.go index f69463c5060..dc149d9c81e 100644 --- a/pkg/keyutil/util_test.go +++ b/pkg/keyutil/util_test.go @@ -21,6 +21,7 @@ import ( ) func TestKeyUtil(t *testing.T) { + t.Parallel() re := require.New(t) startKey := []byte("a") endKey := []byte("b") diff --git a/pkg/logutil/log_test.go b/pkg/logutil/log_test.go index 81913905704..fd46acbdda3 100644 --- a/pkg/logutil/log_test.go +++ b/pkg/logutil/log_test.go @@ -23,6 +23,7 @@ import ( ) func TestStringToZapLogLevel(t *testing.T) { + t.Parallel() re := require.New(t) re.Equal(zapcore.FatalLevel, StringToZapLogLevel("fatal")) re.Equal(zapcore.ErrorLevel, StringToZapLogLevel("ERROR")) @@ -34,6 +35,7 @@ func TestStringToZapLogLevel(t *testing.T) { } func TestRedactLog(t *testing.T) { + t.Parallel() re := require.New(t) testCases := []struct { name string diff --git a/pkg/metricutil/metricutil_test.go b/pkg/metricutil/metricutil_test.go index a72eb7ee5f5..02fee0ac5a0 100644 --- a/pkg/metricutil/metricutil_test.go +++ b/pkg/metricutil/metricutil_test.go @@ -23,6 +23,7 @@ import ( ) func TestCamelCaseToSnakeCase(t *testing.T) { + t.Parallel() re := require.New(t) inputs := []struct { name string diff --git a/pkg/mock/mockhbstream/mockhbstream_test.go b/pkg/mock/mockhbstream/mockhbstream_test.go index 5f9d814835b..056f8f251de 100644 --- a/pkg/mock/mockhbstream/mockhbstream_test.go +++ b/pkg/mock/mockhbstream/mockhbstream_test.go @@ -30,6 +30,7 @@ import ( ) func TestActivity(t *testing.T) { + t.Parallel() re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pkg/movingaverage/avg_over_time_test.go b/pkg/movingaverage/avg_over_time_test.go index 9006fea5d5d..13b7da27aef 100644 --- a/pkg/movingaverage/avg_over_time_test.go +++ b/pkg/movingaverage/avg_over_time_test.go @@ -23,6 +23,7 @@ import ( ) func TestPulse(t *testing.T) { + t.Parallel() re := require.New(t) aot := NewAvgOverTime(5 * time.Second) // warm up @@ -42,6 +43,7 @@ func TestPulse(t *testing.T) 
{ } func TestChange(t *testing.T) { + t.Parallel() re := require.New(t) aot := NewAvgOverTime(5 * time.Second) @@ -75,6 +77,7 @@ func TestChange(t *testing.T) { } func TestMinFilled(t *testing.T) { + t.Parallel() re := require.New(t) interval := 10 * time.Second rate := 1.0 @@ -91,6 +94,7 @@ func TestMinFilled(t *testing.T) { } func TestUnstableInterval(t *testing.T) { + t.Parallel() re := require.New(t) aot := NewAvgOverTime(5 * time.Second) re.Equal(0., aot.Get()) diff --git a/pkg/movingaverage/max_filter_test.go b/pkg/movingaverage/max_filter_test.go index 7d3906ec93c..bba770cecc2 100644 --- a/pkg/movingaverage/max_filter_test.go +++ b/pkg/movingaverage/max_filter_test.go @@ -21,6 +21,7 @@ import ( ) func TestMaxFilter(t *testing.T) { + t.Parallel() re := require.New(t) var empty float64 = 0 data := []float64{2, 1, 3, 4, 1, 1, 3, 3, 2, 0, 5} diff --git a/pkg/movingaverage/moving_average_test.go b/pkg/movingaverage/moving_average_test.go index e54aa70b64a..9f6864b007b 100644 --- a/pkg/movingaverage/moving_average_test.go +++ b/pkg/movingaverage/moving_average_test.go @@ -71,6 +71,7 @@ func checkInstantaneous(re *require.Assertions, ma MovingAvg) { } func TestMedianFilter(t *testing.T) { + t.Parallel() re := require.New(t) var empty float64 = 0 data := []float64{2, 4, 2, 800, 600, 6, 3} @@ -90,6 +91,7 @@ type testCase struct { } func TestMovingAvg(t *testing.T) { + t.Parallel() re := require.New(t) var empty float64 = 0 data := []float64{1, 1, 1, 1, 5, 1, 1, 1} diff --git a/pkg/movingaverage/queue_test.go b/pkg/movingaverage/queue_test.go index 56c2337c9a1..4997ef16254 100644 --- a/pkg/movingaverage/queue_test.go +++ b/pkg/movingaverage/queue_test.go @@ -21,6 +21,7 @@ import ( ) func TestQueue(t *testing.T) { + t.Parallel() re := require.New(t) sq := NewSafeQueue() sq.PushBack(1) @@ -32,6 +33,7 @@ func TestQueue(t *testing.T) { } func TestClone(t *testing.T) { + t.Parallel() re := require.New(t) s1 := NewSafeQueue() s1.PushBack(1) diff --git a/pkg/netutil/address_test.go b/pkg/netutil/address_test.go index 477f794c243..211b29083e6 100644 --- a/pkg/netutil/address_test.go +++ b/pkg/netutil/address_test.go @@ -22,6 +22,7 @@ import ( ) func TestResolveLoopBackAddr(t *testing.T) { + t.Parallel() re := require.New(t) nodes := []struct { address string @@ -39,6 +40,7 @@ func TestResolveLoopBackAddr(t *testing.T) { } func TestIsEnableHttps(t *testing.T) { + t.Parallel() re := require.New(t) re.False(IsEnableHTTPS(http.DefaultClient)) httpClient := &http.Client{ diff --git a/pkg/progress/progress_test.go b/pkg/progress/progress_test.go index 72d23c40a6a..cdb60c9573f 100644 --- a/pkg/progress/progress_test.go +++ b/pkg/progress/progress_test.go @@ -24,6 +24,7 @@ import ( ) func TestProgress(t *testing.T) { + t.Parallel() re := require.New(t) n := "test" m := NewManager() @@ -69,6 +70,7 @@ func TestProgress(t *testing.T) { } func TestAbnormal(t *testing.T) { + t.Parallel() re := require.New(t) n := "test" m := NewManager() diff --git a/pkg/rangetree/range_tree_test.go b/pkg/rangetree/range_tree_test.go index 695183f2f90..a9071ed5f1f 100644 --- a/pkg/rangetree/range_tree_test.go +++ b/pkg/rangetree/range_tree_test.go @@ -86,6 +86,7 @@ func bucketDebrisFactory(startKey, endKey []byte, item RangeItem) []RangeItem { } func TestRingPutItem(t *testing.T) { + t.Parallel() re := require.New(t) bucketTree := NewRangeTree(2, bucketDebrisFactory) bucketTree.Update(newSimpleBucketItem([]byte("002"), []byte("100"))) @@ -120,6 +121,7 @@ func TestRingPutItem(t *testing.T) { } func TestDebris(t *testing.T) { + 
t.Parallel() re := require.New(t) ringItem := newSimpleBucketItem([]byte("010"), []byte("090")) var overlaps []RangeItem diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index 6a2a5c80b9c..4722e243b03 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -21,6 +21,7 @@ import ( ) func TestConcurrencyLimiter(t *testing.T) { + t.Parallel() re := require.New(t) cl := newConcurrencyLimiter(10) for i := 0; i < 10; i++ { diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index d1a570ccb35..46e6e4b6498 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -24,6 +24,7 @@ import ( ) func TestUpdateConcurrencyLimiter(t *testing.T) { + t.Parallel() re := require.New(t) opts := []Option{UpdateConcurrencyLimiter(10)} @@ -88,6 +89,7 @@ func TestUpdateConcurrencyLimiter(t *testing.T) { } func TestBlockList(t *testing.T) { + t.Parallel() re := require.New(t) opts := []Option{AddLabelAllowList()} limiter := NewLimiter() @@ -107,6 +109,7 @@ func TestBlockList(t *testing.T) { } func TestUpdateQPSLimiter(t *testing.T) { + t.Parallel() re := require.New(t) opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} limiter := NewLimiter() @@ -160,6 +163,7 @@ func TestUpdateQPSLimiter(t *testing.T) { } func TestQPSLimiter(t *testing.T) { + t.Parallel() re := require.New(t) opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} limiter := NewLimiter() @@ -189,6 +193,7 @@ func TestQPSLimiter(t *testing.T) { } func TestTwoLimiters(t *testing.T) { + t.Parallel() re := require.New(t) cfg := &DimensionConfig{ QPS: 100, diff --git a/pkg/ratelimit/ratelimiter_test.go b/pkg/ratelimit/ratelimiter_test.go index f16bb6a83d2..35b355e7b21 100644 --- a/pkg/ratelimit/ratelimiter_test.go +++ b/pkg/ratelimit/ratelimiter_test.go @@ -22,6 +22,7 @@ import ( ) func TestRateLimiter(t *testing.T) { + t.Parallel() re := require.New(t) limiter := NewRateLimiter(100, 100) diff --git a/pkg/reflectutil/tag_test.go b/pkg/reflectutil/tag_test.go index 8e8c4dc7754..bff74b36f40 100644 --- a/pkg/reflectutil/tag_test.go +++ b/pkg/reflectutil/tag_test.go @@ -35,6 +35,7 @@ type testStruct3 struct { } func TestFindJSONFullTagByChildTag(t *testing.T) { + t.Parallel() re := require.New(t) key := "enable" result := FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) @@ -50,6 +51,7 @@ func TestFindJSONFullTagByChildTag(t *testing.T) { } func TestFindSameFieldByJSON(t *testing.T) { + t.Parallel() re := require.New(t) input := map[string]interface{}{ "name": "test2", @@ -63,6 +65,7 @@ func TestFindSameFieldByJSON(t *testing.T) { } func TestFindFieldByJSONTag(t *testing.T) { + t.Parallel() re := require.New(t) t1 := testStruct1{} t2 := testStruct2{} diff --git a/pkg/requestutil/context_test.go b/pkg/requestutil/context_test.go index fe93182d537..c4f005fff9f 100644 --- a/pkg/requestutil/context_test.go +++ b/pkg/requestutil/context_test.go @@ -23,6 +23,7 @@ import ( ) func TestRequestInfo(t *testing.T) { + t.Parallel() re := require.New(t) ctx := context.Background() _, ok := RequestInfoFrom(ctx) @@ -51,6 +52,7 @@ func TestRequestInfo(t *testing.T) { } func TestEndTime(t *testing.T) { + t.Parallel() re := require.New(t) ctx := context.Background() _, ok := EndTimeFrom(ctx) diff --git a/pkg/slice/slice_test.go b/pkg/slice/slice_test.go index 809dd2c54b3..d8ba709eb66 100644 --- a/pkg/slice/slice_test.go +++ b/pkg/slice/slice_test.go @@ -22,6 +22,7 @@ import ( ) 
func TestSlice(t *testing.T) { + t.Parallel() re := require.New(t) testCases := []struct { a []int @@ -44,6 +45,7 @@ func TestSlice(t *testing.T) { } func TestSliceContains(t *testing.T) { + t.Parallel() re := require.New(t) ss := []string{"a", "b", "c"} re.Contains(ss, "a") diff --git a/pkg/typeutil/comparison_test.go b/pkg/typeutil/comparison_test.go index 24934684b03..2a4774091be 100644 --- a/pkg/typeutil/comparison_test.go +++ b/pkg/typeutil/comparison_test.go @@ -22,6 +22,7 @@ import ( ) func TestMinUint64(t *testing.T) { + t.Parallel() re := require.New(t) re.Equal(uint64(1), MinUint64(1, 2)) re.Equal(uint64(1), MinUint64(2, 1)) @@ -29,6 +30,7 @@ func TestMinUint64(t *testing.T) { } func TestMaxUint64(t *testing.T) { + t.Parallel() re := require.New(t) re.Equal(uint64(2), MaxUint64(1, 2)) re.Equal(uint64(2), MaxUint64(2, 1)) @@ -36,6 +38,7 @@ func TestMaxUint64(t *testing.T) { } func TestMinDuration(t *testing.T) { + t.Parallel() re := require.New(t) re.Equal(time.Second, MinDuration(time.Minute, time.Second)) re.Equal(time.Second, MinDuration(time.Second, time.Minute)) diff --git a/pkg/typeutil/conversion_test.go b/pkg/typeutil/conversion_test.go index 3398a1ee618..0a209c899e8 100644 --- a/pkg/typeutil/conversion_test.go +++ b/pkg/typeutil/conversion_test.go @@ -23,6 +23,7 @@ import ( ) func TestBytesToUint64(t *testing.T) { + t.Parallel() re := require.New(t) str := "\x00\x00\x00\x00\x00\x00\x03\xe8" a, err := BytesToUint64([]byte(str)) @@ -31,6 +32,7 @@ func TestBytesToUint64(t *testing.T) { } func TestUint64ToBytes(t *testing.T) { + t.Parallel() re := require.New(t) var a uint64 = 1000 b := Uint64ToBytes(a) @@ -39,6 +41,7 @@ func TestUint64ToBytes(t *testing.T) { } func TestJSONToUint64Slice(t *testing.T) { + t.Parallel() re := require.New(t) type testArray struct { Array []uint64 `json:"array"` diff --git a/pkg/typeutil/duration_test.go b/pkg/typeutil/duration_test.go index a7db13ffd04..f815b29ab6b 100644 --- a/pkg/typeutil/duration_test.go +++ b/pkg/typeutil/duration_test.go @@ -27,6 +27,7 @@ type example struct { } func TestDurationJSON(t *testing.T) { + t.Parallel() re := require.New(t) example := &example{} @@ -40,6 +41,7 @@ func TestDurationJSON(t *testing.T) { } func TestDurationTOML(t *testing.T) { + t.Parallel() re := require.New(t) example := &example{} diff --git a/pkg/typeutil/size_test.go b/pkg/typeutil/size_test.go index 4cc9e66f3de..db18928332f 100644 --- a/pkg/typeutil/size_test.go +++ b/pkg/typeutil/size_test.go @@ -22,6 +22,7 @@ import ( ) func TestSizeJSON(t *testing.T) { + t.Parallel() re := require.New(t) b := ByteSize(265421587) o, err := json.Marshal(b) @@ -38,6 +39,7 @@ func TestSizeJSON(t *testing.T) { } func TestParseMbFromText(t *testing.T) { + t.Parallel() re := require.New(t) testCases := []struct { body []string diff --git a/pkg/typeutil/string_slice_test.go b/pkg/typeutil/string_slice_test.go index 9177cee0eb9..9a197eb68e4 100644 --- a/pkg/typeutil/string_slice_test.go +++ b/pkg/typeutil/string_slice_test.go @@ -22,6 +22,7 @@ import ( ) func TestStringSliceJSON(t *testing.T) { + t.Parallel() re := require.New(t) b := StringSlice([]string{"zone", "rack"}) o, err := json.Marshal(b) @@ -35,6 +36,7 @@ func TestStringSliceJSON(t *testing.T) { } func TestEmpty(t *testing.T) { + t.Parallel() re := require.New(t) ss := StringSlice([]string{}) b, err := json.Marshal(ss) diff --git a/pkg/typeutil/time_test.go b/pkg/typeutil/time_test.go index b8078f63fa8..7a5baf55afa 100644 --- a/pkg/typeutil/time_test.go +++ b/pkg/typeutil/time_test.go @@ -23,6 +23,7 @@ 
import ( ) func TestParseTimestamp(t *testing.T) { + t.Parallel() re := require.New(t) for i := 0; i < 3; i++ { t := time.Now().Add(time.Second * time.Duration(rand.Int31n(1000))) @@ -38,6 +39,7 @@ func TestParseTimestamp(t *testing.T) { } func TestSubTimeByWallClock(t *testing.T) { + t.Parallel() re := require.New(t) for i := 0; i < 100; i++ { r := rand.Int63n(1000) @@ -61,6 +63,7 @@ func TestSubTimeByWallClock(t *testing.T) { } func TestSmallTimeDifference(t *testing.T) { + t.Parallel() re := require.New(t) t1, err := time.Parse("2006-01-02 15:04:05.999", "2021-04-26 00:44:25.682") re.NoError(err) diff --git a/server/config/util.go b/server/config/util.go index af11a7d8fbe..b333b376b5b 100644 --- a/server/config/util.go +++ b/server/config/util.go @@ -18,6 +18,7 @@ import ( "net/url" "regexp" "strings" + "sync" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" @@ -66,16 +67,16 @@ func ValidateURLWithScheme(rawURL string) error { return nil } -var schedulerMap = make(map[string]struct{}) +var schedulerMap sync.Map // RegisterScheduler registers the scheduler type. func RegisterScheduler(typ string) { - schedulerMap[typ] = struct{}{} + schedulerMap.Store(typ, struct{}{}) } // IsSchedulerRegistered checks if the named scheduler type is registered. func IsSchedulerRegistered(name string) bool { - _, ok := schedulerMap[name] + _, ok := schedulerMap.Load(name) return ok } From 36db3c745f89c12d991886d79c57e6e677b1ba72 Mon Sep 17 00:00:00 2001 From: LLThomas Date: Tue, 7 Jun 2022 14:28:30 +0800 Subject: [PATCH 25/82] kv: migrate test framework to testify (#5111) ref tikv/pd#4813 Signed-off-by: LLThomas Co-authored-by: Ti Chi Robot --- server/storage/kv/kv_test.go | 69 +++++++++++++++++------------------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/server/storage/kv/kv_test.go b/server/storage/kv/kv_test.go index 51c90e8a1d1..88bac9b279f 100644 --- a/server/storage/kv/kv_test.go +++ b/server/storage/kv/kv_test.go @@ -23,75 +23,70 @@ import ( "strconv" "testing" - . 
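The server/config/util.go hunk above replaces the plain scheduler-type map with a sync.Map so that registration and lookup are safe to call from concurrent goroutines without an extra mutex. A self-contained sketch of that registry shape and the concurrent use it is meant to survive (standalone code, not PD's package):

```go
package main

import (
	"fmt"
	"sync"
)

// schedulerTypes mirrors the sync.Map registry introduced above: Store and
// Load may be called concurrently without additional locking.
var schedulerTypes sync.Map

func registerScheduler(typ string) { schedulerTypes.Store(typ, struct{}{}) }

func isSchedulerRegistered(name string) bool {
	_, ok := schedulerTypes.Load(name)
	return ok
}

func main() {
	var wg sync.WaitGroup
	// Concurrent registration is exactly the access pattern a plain map
	// would turn into a data race.
	for _, typ := range []string{"balance-leader", "balance-region", "evict-leader"} {
		wg.Add(1)
		go func(typ string) {
			defer wg.Done()
			registerScheduler(typ)
		}(typ)
	}
	wg.Wait()
	fmt.Println(isSchedulerRegistered("balance-leader")) // true
	fmt.Println(isSchedulerRegistered("unknown"))        // false
}
```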
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/tempurl" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/embed" ) -func TestKV(t *testing.T) { - TestingT(t) -} - -type testKVSuite struct{} - -var _ = Suite(&testKVSuite{}) - -func (s *testKVSuite) TestEtcd(c *C) { +func TestEtcd(t *testing.T) { + re := require.New(t) cfg := newTestSingleConfig() defer cleanConfig(cfg) etcd, err := embed.StartEtcd(cfg) - c.Assert(err, IsNil) + re.NoError(err) defer etcd.Close() ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - c.Assert(err, IsNil) + re.NoError(err) rootPath := path.Join("/pd", strconv.FormatUint(100, 10)) kv := NewEtcdKVBase(client, rootPath) - s.testReadWrite(c, kv) - s.testRange(c, kv) + testReadWrite(re, kv) + testRange(re, kv) } -func (s *testKVSuite) TestLevelDB(c *C) { +func TestLevelDB(t *testing.T) { + re := require.New(t) dir, err := os.MkdirTemp("/tmp", "leveldb_kv") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(dir) kv, err := NewLevelDBKV(dir) - c.Assert(err, IsNil) + re.NoError(err) - s.testReadWrite(c, kv) - s.testRange(c, kv) + testReadWrite(re, kv) + testRange(re, kv) } -func (s *testKVSuite) TestMemKV(c *C) { +func TestMemKV(t *testing.T) { + re := require.New(t) kv := NewMemoryKV() - s.testReadWrite(c, kv) - s.testRange(c, kv) + testReadWrite(re, kv) + testRange(re, kv) } -func (s *testKVSuite) testReadWrite(c *C, kv Base) { +func testReadWrite(re *require.Assertions, kv Base) { v, err := kv.Load("key") - c.Assert(err, IsNil) - c.Assert(v, Equals, "") + re.NoError(err) + re.Equal("", v) err = kv.Save("key", "value") - c.Assert(err, IsNil) + re.NoError(err) v, err = kv.Load("key") - c.Assert(err, IsNil) - c.Assert(v, Equals, "value") + re.NoError(err) + re.Equal("value", v) err = kv.Remove("key") - c.Assert(err, IsNil) + re.NoError(err) v, err = kv.Load("key") - c.Assert(err, IsNil) - c.Assert(v, Equals, "") + re.NoError(err) + re.Equal("", v) err = kv.Remove("key") - c.Assert(err, IsNil) + re.NoError(err) } -func (s *testKVSuite) testRange(c *C, kv Base) { +func testRange(re *require.Assertions, kv Base) { keys := []string{ "test-a", "test-a/a", "test-a/ab", "test", "test/a", "test/ab", @@ -99,7 +94,7 @@ func (s *testKVSuite) testRange(c *C, kv Base) { } for _, k := range keys { err := kv.Save(k, k) - c.Assert(err, IsNil) + re.NoError(err) } sortedKeys := append(keys[:0:0], keys...) 
sort.Strings(sortedKeys) @@ -120,9 +115,9 @@ func (s *testKVSuite) testRange(c *C, kv Base) { for _, tc := range testCases { ks, vs, err := kv.LoadRange(tc.start, tc.end, tc.limit) - c.Assert(err, IsNil) - c.Assert(ks, DeepEquals, tc.expect) - c.Assert(vs, DeepEquals, tc.expect) + re.NoError(err) + re.Equal(tc.expect, ks) + re.Equal(tc.expect, vs) } } From 12a9513c7392612a4bf0a338537306cbcec80797 Mon Sep 17 00:00:00 2001 From: Shirly Date: Tue, 7 Jun 2022 15:14:30 +0800 Subject: [PATCH 26/82] server/grpc_service: make update gc_safepoint concurrently safe (#5070) close tikv/pd#5018 Signed-off-by: shirly Co-authored-by: buffer <1045931706@qq.com> Co-authored-by: Ti Chi Robot --- server/gc/safepoint.go | 64 +++++++++++++++++++++++++++++ server/gc/safepoint_test.go | 80 +++++++++++++++++++++++++++++++++++++ server/grpc_service.go | 13 ++---- server/server.go | 4 ++ 4 files changed, 151 insertions(+), 10 deletions(-) create mode 100644 server/gc/safepoint.go create mode 100644 server/gc/safepoint_test.go diff --git a/server/gc/safepoint.go b/server/gc/safepoint.go new file mode 100644 index 00000000000..3cec08d8951 --- /dev/null +++ b/server/gc/safepoint.go @@ -0,0 +1,64 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gc + +import ( + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/server/storage/endpoint" +) + +// SafePointManager is the manager for safePoint of GC and services +type SafePointManager struct { + *gcSafePointManager + // TODO add ServiceSafepointManager +} + +// NewSafepointManager creates a SafePointManager of GC and services +func NewSafepointManager(store endpoint.GCSafePointStorage) *SafePointManager { + return &SafePointManager{ + newGCSafePointManager(store), + } +} + +type gcSafePointManager struct { + syncutil.Mutex + store endpoint.GCSafePointStorage +} + +func newGCSafePointManager(store endpoint.GCSafePointStorage) *gcSafePointManager { + return &gcSafePointManager{store: store} +} + +// LoadGCSafePoint loads current GC safe point from storage. +func (manager *gcSafePointManager) LoadGCSafePoint() (uint64, error) { + return manager.store.LoadGCSafePoint() +} + +// UpdateGCSafePoint updates the safepoint if it is greater than the previous one +// it returns the old safepoint in the storage. +func (manager *gcSafePointManager) UpdateGCSafePoint(newSafePoint uint64) (oldSafePoint uint64, err error) { + manager.Lock() + defer manager.Unlock() + // TODO: cache the safepoint in the storage. + oldSafePoint, err = manager.store.LoadGCSafePoint() + if err != nil { + return + } + if oldSafePoint >= newSafePoint { + return + } + err = manager.store.SaveGCSafePoint(newSafePoint) + return +} diff --git a/server/gc/safepoint_test.go b/server/gc/safepoint_test.go new file mode 100644 index 00000000000..2af82ba7145 --- /dev/null +++ b/server/gc/safepoint_test.go @@ -0,0 +1,80 @@ +// Copyright 2022 TiKV Project Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gc + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tikv/pd/server/storage/endpoint" + "github.com/tikv/pd/server/storage/kv" +) + +func newGCStorage() endpoint.GCSafePointStorage { + return endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) +} + +func TestGCSafePointUpdateSequentially(t *testing.T) { + gcSafePointManager := newGCSafePointManager(newGCStorage()) + re := require.New(t) + curSafePoint := uint64(0) + // update gc safePoint with asc value. + for id := 10; id < 20; id++ { + safePoint, err := gcSafePointManager.LoadGCSafePoint() + re.NoError(err) + re.Equal(curSafePoint, safePoint) + previousSafePoint := curSafePoint + curSafePoint = uint64(id) + oldSafePoint, err := gcSafePointManager.UpdateGCSafePoint(curSafePoint) + re.NoError(err) + re.Equal(previousSafePoint, oldSafePoint) + } + + safePoint, err := gcSafePointManager.LoadGCSafePoint() + re.NoError(err) + re.Equal(curSafePoint, safePoint) + // update with smaller value should be failed. + oldSafePoint, err := gcSafePointManager.UpdateGCSafePoint(safePoint - 5) + re.NoError(err) + re.Equal(safePoint, oldSafePoint) + curSafePoint, err = gcSafePointManager.LoadGCSafePoint() + re.NoError(err) + // current safePoint should not change since the update value was smaller + re.Equal(safePoint, curSafePoint) +} + +func TestGCSafePointUpdateCurrently(t *testing.T) { + gcSafePointManager := newGCSafePointManager(newGCStorage()) + maxSafePoint := uint64(1000) + wg := sync.WaitGroup{} + re := require.New(t) + + // update gc safePoint concurrently + for id := 0; id < 20; id++ { + wg.Add(1) + go func(step uint64) { + for safePoint := step; safePoint <= maxSafePoint; safePoint += step { + _, err := gcSafePointManager.UpdateGCSafePoint(safePoint) + re.NoError(err) + } + wg.Done() + }(uint64(id + 1)) + } + wg.Wait() + safePoint, err := gcSafePointManager.LoadGCSafePoint() + re.NoError(err) + re.Equal(maxSafePoint, safePoint) +} diff --git a/server/grpc_service.go b/server/grpc_service.go index d487f059ec5..53b74ba517d 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -1295,8 +1295,7 @@ func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafe return &pdpb.GetGCSafePointResponse{Header: s.notBootstrappedHeader()}, nil } - var storage endpoint.GCSafePointStorage = s.storage - safePoint, err := storage.LoadGCSafePoint() + safePoint, err := s.gcSafePointManager.LoadGCSafePoint() if err != nil { return nil, err } @@ -1335,19 +1334,13 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update return &pdpb.UpdateGCSafePointResponse{Header: s.notBootstrappedHeader()}, nil } - var storage endpoint.GCSafePointStorage = s.storage - oldSafePoint, err := storage.LoadGCSafePoint() + newSafePoint := request.GetSafePoint() + oldSafePoint, err := s.gcSafePointManager.UpdateGCSafePoint(newSafePoint) if err != nil { return nil, err } - newSafePoint := request.SafePoint - - // Only save the 
safe point if it's greater than the previous one if newSafePoint > oldSafePoint { - if err := storage.SaveGCSafePoint(newSafePoint); err != nil { - return nil, err - } log.Info("updated gc safe point", zap.Uint64("safe-point", newSafePoint)) } else if newSafePoint < oldSafePoint { diff --git a/server/server.go b/server/server.go index bd13193532f..c27941f7c85 100644 --- a/server/server.go +++ b/server/server.go @@ -52,6 +52,7 @@ import ( "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/encryptionkm" + "github.com/tikv/pd/server/gc" "github.com/tikv/pd/server/id" "github.com/tikv/pd/server/member" syncer "github.com/tikv/pd/server/region_syncer" @@ -130,6 +131,8 @@ type Server struct { encryptionKeyManager *encryptionkm.KeyManager // for storage operation. storage storage.Storage + // safepoint manager + gcSafePointManager *gc.SafePointManager // for basicCluster operation. basicCluster *core.BasicCluster // for tso. @@ -410,6 +413,7 @@ func (s *Server) startServer(ctx context.Context) error { } defaultStorage := storage.NewStorageWithEtcdBackend(s.client, s.rootPath) s.storage = storage.NewCoreStorage(defaultStorage, regionStorage) + s.gcSafePointManager = gc.NewSafepointManager(s.storage) s.basicCluster = core.NewBasicCluster() s.cluster = cluster.NewRaftCluster(ctx, s.clusterID, syncer.NewRegionSyncer(s), s.client, s.httpClient) s.hbStreams = hbstream.NewHeartbeatStreams(ctx, s.clusterID, s.cluster) From 147e9c041a49155e3c15ec0432cf9de1d4d38c44 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 7 Jun 2022 15:26:30 +0800 Subject: [PATCH 27/82] *: use testutil.Eventually to replace testutil.WaitUntil (#5108) ref tikv/pd#5105 Use testutil.Eventually to replace testutil.WaitUntil. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- pkg/mock/mockhbstream/mockhbstream_test.go | 6 +-- pkg/testutil/testutil.go | 20 ++++----- tests/client/client_test.go | 52 +++++++++++----------- tests/cluster.go | 8 ++-- 4 files changed, 41 insertions(+), 45 deletions(-) diff --git a/pkg/mock/mockhbstream/mockhbstream_test.go b/pkg/mock/mockhbstream/mockhbstream_test.go index 056f8f251de..f31fea5b589 100644 --- a/pkg/mock/mockhbstream/mockhbstream_test.go +++ b/pkg/mock/mockhbstream/mockhbstream_test.go @@ -49,13 +49,13 @@ func TestActivity(t *testing.T) { // Active stream is stream1. hbs.BindStream(1, stream1) - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() != nil && stream2.Recv() == nil }) // Rebind to stream2. hbs.BindStream(1, stream2) - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() == nil && stream2.Recv() != nil }) @@ -66,7 +66,7 @@ func TestActivity(t *testing.T) { re.NotNil(res.GetHeader().GetError()) // Switch back to 1 again. 
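The gc.SafePointManager introduced above serializes updates behind a mutex and persists a new GC safe point only when it moves forward, handing back the previous value so callers such as UpdateGCSafePoint in grpc_service.go can log or reject stale requests. A rough caller-side sketch; the in-memory storage wiring copies the new test file, but treat the exact constructors as an assumption rather than a stable API:

```go
package main

import (
	"fmt"

	"github.com/tikv/pd/server/gc"
	"github.com/tikv/pd/server/storage/endpoint"
	"github.com/tikv/pd/server/storage/kv"
)

func main() {
	// Same storage the new safepoint_test.go builds for its tests.
	store := endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil)
	manager := gc.NewSafepointManager(store)

	old, _ := manager.UpdateGCSafePoint(100)
	fmt.Println(old) // 0: the stored safe point advances to 100

	old, _ = manager.UpdateGCSafePoint(50)
	fmt.Println(old) // 100: a smaller value is ignored, the previous one is returned

	cur, _ := manager.LoadGCSafePoint()
	fmt.Println(cur) // 100
}
```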
hbs.BindStream(1, stream1) - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() != nil && stream2.Recv() == nil }) diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index dfb209c648d..bc54e901a63 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -17,7 +17,6 @@ package testutil import ( "os" "strings" - "testing" "time" "github.com/pingcap/check" @@ -55,6 +54,7 @@ func WithSleepInterval(sleep time.Duration) WaitOption { } // WaitUntil repeatedly evaluates f() for a period of time, util it returns true. +// NOTICE: this function will be removed soon, please use `Eventually` instead. func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { c.Log("wait start") option := &WaitOp{ @@ -73,10 +73,8 @@ func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { c.Fatal("wait timeout") } -// WaitUntilWithTestingT repeatedly evaluates f() for a period of time, util it returns true. -// NOTICE: this is a temporary function that we will be used to replace `WaitUntil` later. -func WaitUntilWithTestingT(t *testing.T, f CheckFunc, opts ...WaitOption) { - t.Log("wait start") +// Eventually asserts that given condition will be met in a period of time. +func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOption) { option := &WaitOp{ retryTimes: waitMaxRetry, sleepInterval: waitRetrySleep, @@ -84,13 +82,11 @@ func WaitUntilWithTestingT(t *testing.T, f CheckFunc, opts ...WaitOption) { for _, opt := range opts { opt(option) } - for i := 0; i < option.retryTimes; i++ { - if f() { - return - } - time.Sleep(option.sleepInterval) - } - t.Fatal("wait timeout") + re.Eventually( + condition, + option.sleepInterval*time.Duration(option.retryTimes), + option.sleepInterval, + ) } // NewRequestHeader creates a new request header. diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 9d24e5439fd..975b54d72f8 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -75,7 +75,7 @@ func TestClientLeaderChange(t *testing.T) { cli := setupCli(re, ctx, endpoints) var ts1, ts2 uint64 - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { p1, l1, err := cli.GetTS(context.TODO()) if err == nil { ts1 = tsoutil.ComposeTS(p1, l1) @@ -87,16 +87,16 @@ func TestClientLeaderChange(t *testing.T) { re.True(cluster.CheckTSOUnique(ts1)) leader := cluster.GetLeader() - waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + waitLeader(re, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) err = cluster.GetServer(leader).Stop() re.NoError(err) leader = cluster.WaitLeader() re.NotEmpty(leader) - waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + waitLeader(re, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) // Check TS won't fall back after leader changed. 
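After the pkg/testutil change above, Eventually is just a thin adapter: the retry count and sleep interval from the wait options become the overall window and the polling tick of testify's require.Eventually. A condensed illustration of what a converted call site boils down to (the polled condition here is invented):

```go
package example

import (
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestEventuallyShape(t *testing.T) {
	re := require.New(t)

	var ready int32
	go func() {
		time.Sleep(50 * time.Millisecond)
		atomic.StoreInt32(&ready, 1)
	}()

	// Equivalent of testutil.Eventually(re, cond): poll cond every tick until
	// it returns true or the overall window (retryTimes * sleepInterval) expires.
	re.Eventually(
		func() bool { return atomic.LoadInt32(&ready) == 1 },
		2*time.Second,        // waitFor: the whole retry budget
		100*time.Millisecond, // tick: one sleepInterval between polls
	)
}
```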
- testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { p2, l2, err := cli.GetTS(context.TODO()) if err == nil { ts2 = tsoutil.ComposeTS(p2, l2) @@ -128,7 +128,7 @@ func TestLeaderTransfer(t *testing.T) { cli := setupCli(re, ctx, endpoints) var lastTS uint64 - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) @@ -187,7 +187,7 @@ func TestUpdateAfterResetTSO(t *testing.T) { endpoints := runServer(re, cluster) cli := setupCli(re, ctx, endpoints) - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) @@ -200,7 +200,7 @@ func TestUpdateAfterResetTSO(t *testing.T) { newLeaderName := cluster.WaitLeader() re.NotEqual(oldLeaderName, newLeaderName) // Request a new TSO. - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) @@ -209,7 +209,7 @@ func TestUpdateAfterResetTSO(t *testing.T) { err = cluster.GetServer(newLeaderName).ResignLeader() re.NoError(err) // Should NOT panic here. - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) @@ -235,7 +235,7 @@ func TestTSOAllocatorLeader(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) - cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig) + cluster.WaitAllLeadersWithTestify(re, dcLocationConfig) var ( testServers = cluster.GetServers() @@ -249,7 +249,7 @@ func TestTSOAllocatorLeader(t *testing.T) { var allocatorLeaderMap = make(map[string]string) for _, dcLocation := range dcLocationConfig { var pdName string - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { pdName = cluster.WaitAllocatorLeader(dcLocation) return len(pdName) > 0 }) @@ -347,7 +347,7 @@ func TestGlobalAndLocalTSO(t *testing.T) { re.NoError(err) dcLocationConfig["pd4"] = "dc-4" cluster.CheckClusterDCLocation() - cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig) + cluster.WaitAllLeadersWithTestify(re, dcLocationConfig) // Test a nonexistent dc-location for Local TSO p, l, err := cli.GetLocalTS(context.TODO(), "nonexistent-dc") @@ -475,7 +475,7 @@ func TestGetTsoFromFollowerClient1(t *testing.T) { re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) @@ -506,7 +506,7 @@ func TestGetTsoFromFollowerClient2(t *testing.T) { re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) @@ -560,8 +560,8 @@ func setupCli(re *require.Assertions, ctx context.Context, endpoints []string, o return cli } -func waitLeader(t *testing.T, cli client, leader string) { - testutil.WaitUntilWithTestingT(t, func() bool { +func waitLeader(re *require.Assertions, cli client, leader string) { + testutil.Eventually(re, func() bool { cli.ScheduleCheckLeader() return cli.GetLeaderAddr() == 
leader }) @@ -835,8 +835,8 @@ func (suite *clientTestSuite) TestGetRegion() { } err := suite.regionHeartbeat.Send(req) suite.NoError(err) - t := suite.T() - testutil.WaitUntilWithTestingT(t, func() bool { + re := suite.Require() + testutil.Eventually(re, func() bool { r, err := suite.client.GetRegion(context.Background(), []byte("a")) suite.NoError(err) if r == nil { @@ -864,7 +864,7 @@ func (suite *clientTestSuite) TestGetRegion() { }, } suite.NoError(suite.reportBucket.Send(breq)) - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) suite.NoError(err) if r == nil { @@ -874,7 +874,7 @@ func (suite *clientTestSuite) TestGetRegion() { }) config := suite.srv.GetRaftCluster().GetStoreConfig() config.EnableRegionBucket = false - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) suite.NoError(err) if r == nil { @@ -911,7 +911,7 @@ func (suite *clientTestSuite) TestGetPrevRegion() { } time.Sleep(500 * time.Millisecond) for i := 0; i < 20; i++ { - testutil.WaitUntilWithTestingT(suite.T(), func() bool { + testutil.Eventually(suite.Require(), func() bool { r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) suite.NoError(err) if i > 0 && i < regionLen { @@ -949,8 +949,7 @@ func (suite *clientTestSuite) TestScanRegions() { } // Wait for region heartbeats. - t := suite.T() - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(suite.Require(), func() bool { scanRegions, err := suite.client.ScanRegions(context.Background(), []byte{0}, nil, 10) return err == nil && len(scanRegions) == 10 }) @@ -967,6 +966,7 @@ func (suite *clientTestSuite) TestScanRegions() { region5 := core.NewRegionInfo(regions[5], regions[5].Peers[0], core.WithPendingPeers([]*metapb.Peer{regions[5].Peers[1], regions[5].Peers[2]})) suite.srv.GetRaftCluster().HandleRegionHeartbeat(region5) + t := suite.T() check := func(start, end []byte, limit int, expect []*metapb.Region) { scanRegions, err := suite.client.ScanRegions(context.Background(), start, end, limit) suite.NoError(err) @@ -1017,7 +1017,7 @@ func (suite *clientTestSuite) TestGetRegionByID() { err := suite.regionHeartbeat.Send(req) suite.NoError(err) - testutil.WaitUntilWithTestingT(suite.T(), func() bool { + testutil.Eventually(suite.Require(), func() bool { r, err := suite.client.GetRegionByID(context.Background(), regionID) suite.NoError(err) if r == nil { @@ -1300,8 +1300,8 @@ func (suite *clientTestSuite) TestScatterRegion() { regionsID := []uint64{regionID} suite.NoError(err) // Test interface `ScatterRegions`. - t := suite.T() - testutil.WaitUntilWithTestingT(t, func() bool { + re := suite.Require() + testutil.Eventually(re, func() bool { scatterResp, err := suite.client.ScatterRegions(context.Background(), regionsID, pd.WithGroup("test"), pd.WithRetry(1)) if err != nil { return false @@ -1320,7 +1320,7 @@ func (suite *clientTestSuite) TestScatterRegion() { // Test interface `ScatterRegion`. // TODO: Deprecate interface `ScatterRegion`. 
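The client tests above poll GetTS through Eventually and then compare timestamps by packing the physical and logical parts into one integer with tsoutil.ComposeTS. A toy version of that composition, assuming PD's usual 18-bit logical suffix (the shift width is an assumption here, not something stated in the diff):

```go
package main

import "fmt"

// logicalBits is assumed to match PD's TSO layout: physical milliseconds in
// the high bits, an 18-bit logical counter in the low bits.
const logicalBits = 18

func composeTS(physical, logical int64) uint64 {
	return uint64(physical)<<logicalBits + uint64(logical)
}

func main() {
	ts1 := composeTS(1000, 5)
	ts2 := composeTS(1000, 6) // same millisecond, later logical tick
	ts3 := composeTS(1001, 0) // next millisecond
	// Composed values compare in issue order, which is what the tests rely on
	// when they assert ts1 < ts2 across a leader change.
	fmt.Println(ts1 < ts2, ts2 < ts3) // true true
}
```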
- testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { err := suite.client.ScatterRegion(context.Background(), regionID) if err != nil { fmt.Println(err) diff --git a/tests/cluster.go b/tests/cluster.go index 3b0e10a02e6..0d7efe90ec9 100644 --- a/tests/cluster.go +++ b/tests/cluster.go @@ -19,7 +19,6 @@ import ( "net/http" "os" "sync" - "testing" "time" "github.com/coreos/go-semver/semver" @@ -28,6 +27,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/autoscaling" "github.com/tikv/pd/pkg/dashboard" "github.com/tikv/pd/pkg/errs" @@ -623,9 +623,9 @@ func (c *TestCluster) WaitAllLeaders(testC *check.C, dcLocations map[string]stri wg.Wait() } -// WaitAllLeadersWithTestingT will block and wait for the election of PD leader and all Local TSO Allocator leaders. +// WaitAllLeadersWithTestify will block and wait for the election of PD leader and all Local TSO Allocator leaders. // NOTICE: this is a temporary function that we will be used to replace `WaitAllLeaders` later. -func (c *TestCluster) WaitAllLeadersWithTestingT(t *testing.T, dcLocations map[string]string) { +func (c *TestCluster) WaitAllLeadersWithTestify(re *require.Assertions, dcLocations map[string]string) { c.WaitLeader() c.CheckClusterDCLocation() // Wait for each DC's Local TSO Allocator leader @@ -633,7 +633,7 @@ func (c *TestCluster) WaitAllLeadersWithTestingT(t *testing.T, dcLocations map[s for _, dcLocation := range dcLocations { wg.Add(1) go func(dc string) { - testutil.WaitUntilWithTestingT(t, func() bool { + testutil.Eventually(re, func() bool { leaderName := c.WaitAllocatorLeader(dc) return leaderName != "" }) From df25b66732dd97796fe412458e4629359e680339 Mon Sep 17 00:00:00 2001 From: Shirly Date: Tue, 7 Jun 2022 17:54:29 +0800 Subject: [PATCH 28/82] server/cluster: store the min version of stores as the cluster_version (#5050) close tikv/pd#5049 Signed-off-by: shirly Co-authored-by: Ti Chi Robot --- server/cluster/cluster.go | 23 ++++++++++++----------- server/cluster/cluster_test.go | 21 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 4ff1232752e..7c6f0c3701b 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -1879,19 +1879,20 @@ func (c *RaftCluster) onStoreVersionChangeLocked() { failpoint.Inject("versionChangeConcurrency", func() { time.Sleep(500 * time.Millisecond) }) + if minVersion == nil || clusterVersion.Equal(*minVersion) { + return + } - if minVersion != nil && clusterVersion.LessThan(*minVersion) { - if !c.opt.CASClusterVersion(clusterVersion, minVersion) { - log.Error("cluster version changed by API at the same time") - } - err := c.opt.Persist(c.storage) - if err != nil { - log.Error("persist cluster version meet error", errs.ZapError(err)) - } - log.Info("cluster version changed", - zap.Stringer("old-cluster-version", clusterVersion), - zap.Stringer("new-cluster-version", minVersion)) + if !c.opt.CASClusterVersion(clusterVersion, minVersion) { + log.Error("cluster version changed by API at the same time") + } + err := c.opt.Persist(c.storage) + if err != nil { + log.Error("persist cluster version meet error", errs.ZapError(err)) } + log.Info("cluster version changed", + zap.Stringer("old-cluster-version", clusterVersion), + zap.Stringer("new-cluster-version", minVersion)) } func (c *RaftCluster) changedRegionNotifier() <-chan 
*core.RegionInfo { diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 66882d95ed5..c91899c662d 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -512,6 +512,27 @@ func (s *testClusterInfoSuite) TestDeleteStoreUpdatesClusterVersion(c *C) { c.Assert(cluster.GetClusterVersion(), Equals, "5.0.0") } +func (s *testClusterInfoSuite) TestStoreClusterVersion(c *C) { + _, opt, err := newTestScheduleConfig() + c.Assert(err, IsNil) + cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + stores := newTestStores(3, "5.0.0") + s1, s2, s3 := stores[0].GetMeta(), stores[1].GetMeta(), stores[2].GetMeta() + s1.Version = "5.0.1" + s2.Version = "5.0.3" + s3.Version = "5.0.5" + c.Assert(cluster.PutStore(s2), IsNil) + c.Assert(cluster.GetClusterVersion(), Equals, s2.Version) + + c.Assert(cluster.PutStore(s1), IsNil) + // the cluster version should be 5.0.1(the min one) + c.Assert(cluster.GetClusterVersion(), Equals, s1.Version) + + c.Assert(cluster.PutStore(s3), IsNil) + // the cluster version should be 5.0.1(the min one) + c.Assert(cluster.GetClusterVersion(), Equals, s1.Version) +} + func (s *testClusterInfoSuite) TestRegionHeartbeatHotStat(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) From 1382857bfe4f8cf102bf3370f890d6356e7b5a66 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 8 Jun 2022 13:50:30 +0800 Subject: [PATCH 29/82] pkg, scripts: refine the use of require.Error/NoError (#5124) ref tikv/pd#4813 Refine the use of require.Error/NoError. Signed-off-by: JmPotato --- pkg/encryption/crypter_test.go | 6 +++--- scripts/check-test.sh | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index e8b7e06bcdf..2f952d5b729 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -37,9 +37,9 @@ func TestKeyLength(t *testing.T) { t.Parallel() re := require.New(t) _, err := KeyLength(encryptionpb.EncryptionMethod_PLAINTEXT) - re.NotNil(err) + re.Error(err) _, err = KeyLength(encryptionpb.EncryptionMethod_UNKNOWN) - re.NotNil(err) + re.Error(err) length, err := KeyLength(encryptionpb.EncryptionMethod_AES128_CTR) re.NoError(err) re.Equal(16, length) @@ -111,5 +111,5 @@ func TestAesGcmCrypter(t *testing.T) { // ignore overflow fakeCiphertext[0] = ciphertext[0] + 1 _, err = AesGcmDecrypt(key, fakeCiphertext, iv) - re.NotNil(err) + re.Error(err) } diff --git a/scripts/check-test.sh b/scripts/check-test.sh index c8c5b72c0fe..f65d506565f 100755 --- a/scripts/check-test.sh +++ b/scripts/check-test.sh @@ -41,4 +41,12 @@ if [ "$res" ]; then exit 1 fi +res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(Nil|NotNil)\((t, )?(err|error)" . 
| sort -u) + +if [ "$res" ]; then + echo "following packages use the inefficient assert function: please replace require.Nil/NotNil with require.NoError/Error" + echo "$res" + exit 1 +fi + exit 0 From 79b0290f55beef3065eab7fb22cb84b9e62f4915 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 8 Jun 2022 16:30:31 +0800 Subject: [PATCH 30/82] election: migrate test framework to testify (#5132) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/election/leadership_test.go | 65 +++++++++++++----------------- server/election/lease_test.go | 64 ++++++++++++++--------------- 2 files changed, 60 insertions(+), 69 deletions(-) diff --git a/server/election/leadership_test.go b/server/election/leadership_test.go index a52f867288c..9a4b52f782e 100644 --- a/server/election/leadership_test.go +++ b/server/election/leadership_test.go @@ -19,36 +19,29 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/etcdutil" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/embed" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testLeadershipSuite{}) - -type testLeadershipSuite struct{} - const defaultLeaseTimeout = 1 -func (s *testLeadershipSuite) TestLeadership(c *C) { +func TestLeadership(t *testing.T) { + re := require.New(t) cfg := etcdutil.NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() etcdutil.CleanConfig(cfg) }() - c.Assert(err, IsNil) + re.NoError(err) ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - c.Assert(err, IsNil) + re.NoError(err) <-etcd.Server.ReadyNotify() @@ -58,27 +51,27 @@ func (s *testLeadershipSuite) TestLeadership(c *C) { // leadership1 starts first and get the leadership err = leadership1.Campaign(defaultLeaseTimeout, "test_leader_1") - c.Assert(err, IsNil) + re.NoError(err) // leadership2 starts then and can not get the leadership err = leadership2.Campaign(defaultLeaseTimeout, "test_leader_2") - c.Assert(err, NotNil) + re.Error(err) - c.Assert(leadership1.Check(), IsTrue) + re.True(leadership1.Check()) // leadership2 failed, so the check should return false - c.Assert(leadership2.Check(), IsFalse) + re.False(leadership2.Check()) // Sleep longer than the defaultLeaseTimeout to wait for the lease expires time.Sleep((defaultLeaseTimeout + 1) * time.Second) - c.Assert(leadership1.Check(), IsFalse) - c.Assert(leadership2.Check(), IsFalse) + re.False(leadership1.Check()) + re.False(leadership2.Check()) // Delete the leader key and campaign for leadership1 err = leadership1.DeleteLeaderKey() - c.Assert(err, IsNil) + re.NoError(err) err = leadership1.Campaign(defaultLeaseTimeout, "test_leader_1") - c.Assert(err, IsNil) - c.Assert(leadership1.Check(), IsTrue) + re.NoError(err) + re.True(leadership1.Check()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go leadership1.Keep(ctx) @@ -86,15 +79,15 @@ func (s *testLeadershipSuite) TestLeadership(c *C) { // Sleep longer than the defaultLeaseTimeout time.Sleep((defaultLeaseTimeout + 1) * time.Second) - c.Assert(leadership1.Check(), IsTrue) - c.Assert(leadership2.Check(), IsFalse) + re.True(leadership1.Check()) + re.False(leadership2.Check()) // Delete the leader key and re-campaign for leadership2 err = leadership1.DeleteLeaderKey() - c.Assert(err, IsNil) + re.NoError(err) err = leadership2.Campaign(defaultLeaseTimeout, "test_leader_2") - c.Assert(err, IsNil) - c.Assert(leadership2.Check(), IsTrue) + re.NoError(err) + 
re.True(leadership2.Check()) ctx, cancel = context.WithCancel(context.Background()) defer cancel() go leadership2.Keep(ctx) @@ -102,14 +95,14 @@ func (s *testLeadershipSuite) TestLeadership(c *C) { // Sleep longer than the defaultLeaseTimeout time.Sleep((defaultLeaseTimeout + 1) * time.Second) - c.Assert(leadership1.Check(), IsFalse) - c.Assert(leadership2.Check(), IsTrue) + re.False(leadership1.Check()) + re.True(leadership2.Check()) // Test resetting the leadership. leadership1.Reset() leadership2.Reset() - c.Assert(leadership1.Check(), IsFalse) - c.Assert(leadership2.Check(), IsFalse) + re.False(leadership1.Check()) + re.False(leadership2.Check()) // Try to keep the reset leadership. leadership1.Keep(ctx) @@ -117,12 +110,12 @@ func (s *testLeadershipSuite) TestLeadership(c *C) { // Check the lease. lease1 := leadership1.getLease() - c.Assert(lease1, NotNil) + re.NotNil(lease1) lease2 := leadership1.getLease() - c.Assert(lease2, NotNil) + re.NotNil(lease2) - c.Assert(lease1.IsExpired(), IsTrue) - c.Assert(lease2.IsExpired(), IsTrue) - c.Assert(lease1.Close(), IsNil) - c.Assert(lease2.Close(), IsNil) + re.True(lease1.IsExpired()) + re.True(lease2.IsExpired()) + re.NoError(lease1.Close()) + re.NoError(lease2.Close()) } diff --git a/server/election/lease_test.go b/server/election/lease_test.go index 0c0aa3c1687..ef8c12be2e9 100644 --- a/server/election/lease_test.go +++ b/server/election/lease_test.go @@ -16,32 +16,30 @@ package election import ( "context" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/etcdutil" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/embed" ) -var _ = Suite(&testLeaseSuite{}) - -type testLeaseSuite struct{} - -func (s *testLeaseSuite) TestLease(c *C) { +func TestLease(t *testing.T) { + re := require.New(t) cfg := etcdutil.NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() etcdutil.CleanConfig(cfg) }() - c.Assert(err, IsNil) + re.NoError(err) ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - c.Assert(err, IsNil) + re.NoError(err) <-etcd.Server.ReadyNotify() @@ -56,51 +54,51 @@ func (s *testLeaseSuite) TestLease(c *C) { client: client, lease: clientv3.NewLease(client), } - c.Check(lease1.IsExpired(), IsTrue) - c.Check(lease2.IsExpired(), IsTrue) - c.Check(lease1.Close(), IsNil) - c.Check(lease2.Close(), IsNil) + re.True(lease1.IsExpired()) + re.True(lease2.IsExpired()) + re.NoError(lease1.Close()) + re.NoError(lease2.Close()) // Grant the two leases with the same timeout. - c.Check(lease1.Grant(defaultLeaseTimeout), IsNil) - c.Check(lease2.Grant(defaultLeaseTimeout), IsNil) - c.Check(lease1.IsExpired(), IsFalse) - c.Check(lease2.IsExpired(), IsFalse) + re.NoError(lease1.Grant(defaultLeaseTimeout)) + re.NoError(lease2.Grant(defaultLeaseTimeout)) + re.False(lease1.IsExpired()) + re.False(lease2.IsExpired()) // Wait for a while to make both two leases timeout. time.Sleep((defaultLeaseTimeout + 1) * time.Second) - c.Check(lease1.IsExpired(), IsTrue) - c.Check(lease2.IsExpired(), IsTrue) + re.True(lease1.IsExpired()) + re.True(lease2.IsExpired()) // Grant the two leases with different timeouts. 
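The new check-test.sh rule above encodes the same convention the crypter and election hunks follow: error values are asserted with require.NoError/Error, which test the err != nil contract directly and print the error text on failure, instead of require.Nil/NotNil. A small before/after sketch with an invented function under test:

```go
package example

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/require"
)

// doWork stands in for any function returning an error.
func doWork(fail bool) error {
	if fail {
		return errors.New("boom")
	}
	return nil
}

func TestErrorAssertions(t *testing.T) {
	re := require.New(t)

	// Preferred form after this patch series.
	re.NoError(doWork(false))
	re.Error(doWork(true))

	// The form the new script flags for error values:
	//   re.Nil(doWork(false))
	//   re.NotNil(doWork(true))
}
```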
- c.Check(lease1.Grant(defaultLeaseTimeout), IsNil) - c.Check(lease2.Grant(defaultLeaseTimeout*4), IsNil) - c.Check(lease1.IsExpired(), IsFalse) - c.Check(lease2.IsExpired(), IsFalse) + re.NoError(lease1.Grant(defaultLeaseTimeout)) + re.NoError(lease2.Grant(defaultLeaseTimeout * 4)) + re.False(lease1.IsExpired()) + re.False(lease2.IsExpired()) // Wait for a while to make one of the lease timeout. time.Sleep((defaultLeaseTimeout + 1) * time.Second) - c.Check(lease1.IsExpired(), IsTrue) - c.Check(lease2.IsExpired(), IsFalse) + re.True(lease1.IsExpired()) + re.False(lease2.IsExpired()) // Close both of the two leases. - c.Check(lease1.Close(), IsNil) - c.Check(lease2.Close(), IsNil) - c.Check(lease1.IsExpired(), IsTrue) - c.Check(lease2.IsExpired(), IsTrue) + re.NoError(lease1.Close()) + re.NoError(lease2.Close()) + re.True(lease1.IsExpired()) + re.True(lease2.IsExpired()) // Grant the lease1 and keep it alive. - c.Check(lease1.Grant(defaultLeaseTimeout), IsNil) - c.Check(lease1.IsExpired(), IsFalse) + re.NoError(lease1.Grant(defaultLeaseTimeout)) + re.False(lease1.IsExpired()) ctx, cancel := context.WithCancel(context.Background()) go lease1.KeepAlive(ctx) defer cancel() // Wait for a timeout. time.Sleep((defaultLeaseTimeout + 1) * time.Second) - c.Check(lease1.IsExpired(), IsFalse) + re.False(lease1.IsExpired()) // Close and wait for a timeout. - c.Check(lease1.Close(), IsNil) + re.NoError(lease1.Close()) time.Sleep((defaultLeaseTimeout + 1) * time.Second) - c.Check(lease1.IsExpired(), IsTrue) + re.True(lease1.IsExpired()) } From 6c0985d91647c9711c8456ae92bce3448f100e98 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 8 Jun 2022 17:02:30 +0800 Subject: [PATCH 31/82] core: migrate test framework to testify (#5123) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/core/region_test.go | 262 ++++++++++++++--------------- server/core/region_tree_test.go | 286 ++++++++++++++++---------------- server/core/store_stats_test.go | 31 ++-- server/core/store_test.go | 43 ++--- 4 files changed, 301 insertions(+), 321 deletions(-) diff --git a/server/core/region_test.go b/server/core/region_test.go index c1ed83b7f46..edf55c8ac7b 100644 --- a/server/core/region_test.go +++ b/server/core/region_test.go @@ -19,25 +19,17 @@ import ( "math" "math/rand" "strconv" - "strings" "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/server/id" ) -func TestCore(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRegionInfoSuite{}) - -type testRegionInfoSuite struct{} - -func (s *testRegionInfoSuite) TestNeedMerge(c *C) { +func TestNeedMerge(t *testing.T) { + re := require.New(t) mererSize, mergeKeys := int64(20), int64(200000) testdata := []struct { size int64 @@ -69,11 +61,12 @@ func (s *testRegionInfoSuite) TestNeedMerge(c *C) { approximateSize: v.size, approximateKeys: v.keys, } - c.Assert(r.NeedMerge(mererSize, mergeKeys), Equals, v.expect) + re.Equal(v.expect, r.NeedMerge(mererSize, mergeKeys)) } } -func (s *testRegionInfoSuite) TestSortedEqual(c *C) { +func TestSortedEqual(t *testing.T) { + re := require.New(t) testcases := []struct { idsA []int idsB []int @@ -153,47 +146,48 @@ func (s *testRegionInfoSuite) TestSortedEqual(c *C) { return peers } // test NewRegionInfo - for _, t := range testcases { - regionA := NewRegionInfo(&metapb.Region{Id: 100, Peers: pickPeers(t.idsA)}, nil) - regionB := NewRegionInfo(&metapb.Region{Id: 100, Peers: pickPeers(t.idsB)}, nil) - c.Assert(SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters()), Equals, t.isEqual) - c.Assert(SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters()), Equals, t.isEqual) + for _, test := range testcases { + regionA := NewRegionInfo(&metapb.Region{Id: 100, Peers: pickPeers(test.idsA)}, nil) + regionB := NewRegionInfo(&metapb.Region{Id: 100, Peers: pickPeers(test.idsB)}, nil) + re.Equal(test.isEqual, SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters())) + re.Equal(test.isEqual, SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters())) } // test RegionFromHeartbeat - for _, t := range testcases { + for _, test := range testcases { regionA := RegionFromHeartbeat(&pdpb.RegionHeartbeatRequest{ - Region: &metapb.Region{Id: 100, Peers: pickPeers(t.idsA)}, - DownPeers: pickPeerStats(t.idsA), - PendingPeers: pickPeers(t.idsA), + Region: &metapb.Region{Id: 100, Peers: pickPeers(test.idsA)}, + DownPeers: pickPeerStats(test.idsA), + PendingPeers: pickPeers(test.idsA), }) regionB := RegionFromHeartbeat(&pdpb.RegionHeartbeatRequest{ - Region: &metapb.Region{Id: 100, Peers: pickPeers(t.idsB)}, - DownPeers: pickPeerStats(t.idsB), - PendingPeers: pickPeers(t.idsB), + Region: &metapb.Region{Id: 100, Peers: pickPeers(test.idsB)}, + DownPeers: pickPeerStats(test.idsB), + PendingPeers: pickPeers(test.idsB), }) - c.Assert(SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters()), Equals, t.isEqual) - c.Assert(SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters()), Equals, t.isEqual) - c.Assert(SortedPeersEqual(regionA.GetPendingPeers(), regionB.GetPendingPeers()), Equals, t.isEqual) - c.Assert(SortedPeersStatsEqual(regionA.GetDownPeers(), regionB.GetDownPeers()), Equals, t.isEqual) + re.Equal(test.isEqual, SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters())) + re.Equal(test.isEqual, SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters())) + re.Equal(test.isEqual, SortedPeersEqual(regionA.GetPendingPeers(), regionB.GetPendingPeers())) + re.Equal(test.isEqual, SortedPeersStatsEqual(regionA.GetDownPeers(), regionB.GetDownPeers())) } // test Clone region := NewRegionInfo(meta, meta.Peers[0]) - for _, t := range testcases { - downPeersA := pickPeerStats(t.idsA) - downPeersB := pickPeerStats(t.idsB) - pendingPeersA := pickPeers(t.idsA) - pendingPeersB 
:= pickPeers(t.idsB) + for _, test := range testcases { + downPeersA := pickPeerStats(test.idsA) + downPeersB := pickPeerStats(test.idsB) + pendingPeersA := pickPeers(test.idsA) + pendingPeersB := pickPeers(test.idsB) regionA := region.Clone(WithDownPeers(downPeersA), WithPendingPeers(pendingPeersA)) regionB := region.Clone(WithDownPeers(downPeersB), WithPendingPeers(pendingPeersB)) - c.Assert(SortedPeersStatsEqual(regionA.GetDownPeers(), regionB.GetDownPeers()), Equals, t.isEqual) - c.Assert(SortedPeersEqual(regionA.GetPendingPeers(), regionB.GetPendingPeers()), Equals, t.isEqual) + re.Equal(test.isEqual, SortedPeersStatsEqual(regionA.GetDownPeers(), regionB.GetDownPeers())) + re.Equal(test.isEqual, SortedPeersEqual(regionA.GetPendingPeers(), regionB.GetPendingPeers())) } } -func (s *testRegionInfoSuite) TestInherit(c *C) { +func TestInherit(t *testing.T) { + re := require.New(t) // size in MB // case for approximateSize testcases := []struct { @@ -208,16 +202,16 @@ func (s *testRegionInfoSuite) TestInherit(c *C) { {true, 1, 2, 2}, {true, 2, 0, 2}, } - for _, t := range testcases { + for _, test := range testcases { var origin *RegionInfo - if t.originExists { + if test.originExists { origin = NewRegionInfo(&metapb.Region{Id: 100}, nil) - origin.approximateSize = int64(t.originSize) + origin.approximateSize = int64(test.originSize) } r := NewRegionInfo(&metapb.Region{Id: 100}, nil) - r.approximateSize = int64(t.size) + r.approximateSize = int64(test.size) r.Inherit(origin, false) - c.Assert(r.approximateSize, Equals, int64(t.expect)) + re.Equal(int64(test.expect), r.approximateSize) } // bucket @@ -234,17 +228,18 @@ func (s *testRegionInfoSuite) TestInherit(c *C) { origin := NewRegionInfo(&metapb.Region{Id: 100}, nil, SetBuckets(d.originBuckets)) r := NewRegionInfo(&metapb.Region{Id: 100}, nil) r.Inherit(origin, true) - c.Assert(r.GetBuckets(), DeepEquals, d.originBuckets) + re.Equal(d.originBuckets, r.GetBuckets()) // region will not inherit bucket keys. 
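The size rows of the TestInherit table above pin down a simple rule: a region that reports a zero approximate size falls back to the size recorded for its predecessor, otherwise its own measurement is kept. A compressed restatement of just that rule, with a simplified stand-in type rather than RegionInfo itself:

```go
package example

// regionStat is a reduced stand-in for the size bookkeeping on RegionInfo.
type regionStat struct {
	approximateSize int64
}

// inheritSize mirrors the behaviour the table rows above exercise: keep the
// new measurement unless it is zero and an origin value exists to reuse.
func inheritSize(origin *regionStat, reported int64) int64 {
	if reported == 0 && origin != nil {
		return origin.approximateSize
	}
	return reported
}
```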
if origin.GetBuckets() != nil { newRegion := NewRegionInfo(&metapb.Region{Id: 100}, nil) newRegion.Inherit(origin, false) - c.Assert(newRegion.GetBuckets(), Not(DeepEquals), d.originBuckets) + re.NotEqual(d.originBuckets, newRegion.GetBuckets()) } } } -func (s *testRegionInfoSuite) TestRegionRoundingFlow(c *C) { +func TestRegionRoundingFlow(t *testing.T) { + re := require.New(t) testcases := []struct { flow uint64 digit int @@ -259,15 +254,16 @@ func (s *testRegionInfoSuite) TestRegionRoundingFlow(c *C) { {252623, math.MaxInt64, 0}, {252623, math.MinInt64, 252623}, } - for _, t := range testcases { - r := NewRegionInfo(&metapb.Region{Id: 100}, nil, WithFlowRoundByDigit(t.digit)) - r.readBytes = t.flow - r.writtenBytes = t.flow - c.Assert(r.GetRoundBytesRead(), Equals, t.expect) + for _, test := range testcases { + r := NewRegionInfo(&metapb.Region{Id: 100}, nil, WithFlowRoundByDigit(test.digit)) + r.readBytes = test.flow + r.writtenBytes = test.flow + re.Equal(test.expect, r.GetRoundBytesRead()) } } -func (s *testRegionInfoSuite) TestRegionWriteRate(c *C) { +func TestRegionWriteRate(t *testing.T) { + re := require.New(t) testcases := []struct { bytes uint64 keys uint64 @@ -284,25 +280,17 @@ func (s *testRegionInfoSuite) TestRegionWriteRate(c *C) { {0, 0, 500, 0, 0}, {10, 3, 500, 0, 0}, } - for _, t := range testcases { - r := NewRegionInfo(&metapb.Region{Id: 100}, nil, SetWrittenBytes(t.bytes), SetWrittenKeys(t.keys), SetReportInterval(t.interval)) + for _, test := range testcases { + r := NewRegionInfo(&metapb.Region{Id: 100}, nil, SetWrittenBytes(test.bytes), SetWrittenKeys(test.keys), SetReportInterval(test.interval)) bytesRate, keysRate := r.GetWriteRate() - c.Assert(bytesRate, Equals, t.expectBytesRate) - c.Assert(keysRate, Equals, t.expectKeysRate) + re.Equal(test.expectBytesRate, bytesRate) + re.Equal(test.expectKeysRate, keysRate) } } -var _ = Suite(&testRegionGuideSuite{}) - -type testRegionGuideSuite struct { - RegionGuide RegionGuideFunc -} - -func (s *testRegionGuideSuite) SetUpSuite(c *C) { - s.RegionGuide = GenerateRegionGuideFunc(false) -} - -func (s *testRegionGuideSuite) TestNeedSync(c *C) { +func TestNeedSync(t *testing.T) { + re := require.New(t) + RegionGuide := GenerateRegionGuideFunc(false) meta := &metapb.Region{ Id: 1000, StartKey: []byte("a"), @@ -369,41 +357,38 @@ func (s *testRegionGuideSuite) TestNeedSync(c *C) { }, } - for _, t := range testcases { - regionA := region.Clone(t.optionsA...) - regionB := region.Clone(t.optionsB...) - _, _, _, needSync := s.RegionGuide(regionA, regionB) - c.Assert(needSync, Equals, t.needSync) + for _, test := range testcases { + regionA := region.Clone(test.optionsA...) + regionB := region.Clone(test.optionsB...) 
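A recurring detail in these region test hunks is renaming the range variable of table-driven loops from t to test: once the function signature takes *testing.T as t, reusing t for the table entry would shadow it and make calls like t.Fatalf inside the loop impossible. A tiny sketch of the hazard and the fix:

```go
package example

import "testing"

func TestTableDriven(t *testing.T) {
	testCases := []struct {
		in, want int
	}{
		{1, 2},
		{2, 4},
	}
	// Naming the entry "test" keeps *testing.T visible inside the loop; if it
	// were named "t" instead, the t.Fatalf call below would not compile.
	for _, test := range testCases {
		if got := test.in * 2; got != test.want {
			t.Fatalf("double(%d) = %d, want %d", test.in, got, test.want)
		}
	}
}
```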
+ _, _, _, needSync := RegionGuide(regionA, regionB) + re.Equal(test.needSync, needSync) } } -var _ = Suite(&testRegionMapSuite{}) - -type testRegionMapSuite struct{} - -func (s *testRegionMapSuite) TestRegionMap(c *C) { +func TestRegionMap(t *testing.T) { + re := require.New(t) rm := newRegionMap() - s.check(c, rm) - rm.AddNew(s.regionInfo(1)) - s.check(c, rm, 1) + check(re, rm) + rm.AddNew(regionInfo(1)) + check(re, rm, 1) - rm.AddNew(s.regionInfo(2)) - rm.AddNew(s.regionInfo(3)) - s.check(c, rm, 1, 2, 3) + rm.AddNew(regionInfo(2)) + rm.AddNew(regionInfo(3)) + check(re, rm, 1, 2, 3) - rm.AddNew(s.regionInfo(3)) + rm.AddNew(regionInfo(3)) rm.Delete(4) - s.check(c, rm, 1, 2, 3) + check(re, rm, 1, 2, 3) rm.Delete(3) rm.Delete(1) - s.check(c, rm, 2) + check(re, rm, 2) - rm.AddNew(s.regionInfo(3)) - s.check(c, rm, 2, 3) + rm.AddNew(regionInfo(3)) + check(re, rm, 2, 3) } -func (s *testRegionMapSuite) regionInfo(id uint64) *RegionInfo { +func regionInfo(id uint64) *RegionInfo { return &RegionInfo{ meta: &metapb.Region{ Id: id, @@ -413,13 +398,13 @@ func (s *testRegionMapSuite) regionInfo(id uint64) *RegionInfo { } } -func (s *testRegionMapSuite) check(c *C, rm regionMap, ids ...uint64) { +func check(re *require.Assertions, rm regionMap, ids ...uint64) { // Check Get. for _, id := range ids { - c.Assert(rm.Get(id).region.GetID(), Equals, id) + re.Equal(id, rm.Get(id).region.GetID()) } // Check Len. - c.Assert(rm.Len(), Equals, len(ids)) + re.Equal(len(ids), rm.Len()) // Check id set. expect := make(map[uint64]struct{}) for _, id := range ids { @@ -429,14 +414,11 @@ func (s *testRegionMapSuite) check(c *C, rm regionMap, ids ...uint64) { for _, r := range rm { set1[r.region.GetID()] = struct{}{} } - c.Assert(set1, DeepEquals, expect) + re.Equal(expect, set1) } -var _ = Suite(&testRegionKey{}) - -type testRegionKey struct{} - -func (*testRegionKey) TestRegionKey(c *C) { +func TestRegionKey(t *testing.T) { + re := require.New(t) testCase := []struct { key string expect string @@ -446,29 +428,30 @@ func (*testRegionKey) TestRegionKey(c *C) { {"\"\\x80\\x00\\x00\\x00\\x00\\x00\\x00\\xff\\x05\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\xf8\"", `80000000000000FF0500000000000000F8`}, } - for _, t := range testCase { - got, err := strconv.Unquote(t.key) - c.Assert(err, IsNil) + for _, test := range testCase { + got, err := strconv.Unquote(test.key) + re.NoError(err) s := fmt.Sprintln(RegionToHexMeta(&metapb.Region{StartKey: []byte(got)})) - c.Assert(strings.Contains(s, t.expect), IsTrue) + re.Contains(s, test.expect) // start key changed origin := NewRegionInfo(&metapb.Region{EndKey: []byte(got)}, nil) region := NewRegionInfo(&metapb.Region{StartKey: []byte(got), EndKey: []byte(got)}, nil) s = DiffRegionKeyInfo(origin, region) - c.Assert(s, Matches, ".*StartKey Changed.*") - c.Assert(strings.Contains(s, t.expect), IsTrue) + re.Regexp(".*StartKey Changed.*", s) + re.Contains(s, test.expect) // end key changed origin = NewRegionInfo(&metapb.Region{StartKey: []byte(got)}, nil) region = NewRegionInfo(&metapb.Region{StartKey: []byte(got), EndKey: []byte(got)}, nil) s = DiffRegionKeyInfo(origin, region) - c.Assert(s, Matches, ".*EndKey Changed.*") - c.Assert(strings.Contains(s, t.expect), IsTrue) + re.Regexp(".*EndKey Changed.*", s) + re.Contains(s, test.expect) } } -func (*testRegionKey) TestSetRegion(c *C) { +func TestSetRegion(t *testing.T) { + re := require.New(t) regions := NewRegionsInfo() for i := 0; i < 100; i++ { peer1 := &metapb.Peer{StoreId: uint64(i%5 + 1), Id: uint64(i*5 + 1)} @@ -495,9 +478,9 @@ func 
(*testRegionKey) TestSetRegion(c *C) { region.learners = append(region.learners, peer2) region.pendingPeers = append(region.pendingPeers, peer3) regions.SetRegion(region) - checkRegions(c, regions) - c.Assert(regions.tree.length(), Equals, 97) - c.Assert(regions.GetRegions(), HasLen, 97) + checkRegions(re, regions) + re.Equal(97, regions.tree.length()) + re.Len(regions.GetRegions(), 97) regions.SetRegion(region) peer1 = &metapb.Peer{StoreId: uint64(2), Id: uint64(101)} @@ -512,21 +495,21 @@ func (*testRegionKey) TestSetRegion(c *C) { region.learners = append(region.learners, peer2) region.pendingPeers = append(region.pendingPeers, peer3) regions.SetRegion(region) - checkRegions(c, regions) - c.Assert(regions.tree.length(), Equals, 97) - c.Assert(regions.GetRegions(), HasLen, 97) + checkRegions(re, regions) + re.Equal(97, regions.tree.length()) + re.Len(regions.GetRegions(), 97) // Test remove overlaps. region = region.Clone(WithStartKey([]byte(fmt.Sprintf("%20d", 175))), WithNewRegionID(201)) - c.Assert(regions.GetRegion(21), NotNil) - c.Assert(regions.GetRegion(18), NotNil) + re.NotNil(regions.GetRegion(21)) + re.NotNil(regions.GetRegion(18)) regions.SetRegion(region) - checkRegions(c, regions) - c.Assert(regions.tree.length(), Equals, 96) - c.Assert(regions.GetRegions(), HasLen, 96) - c.Assert(regions.GetRegion(201), NotNil) - c.Assert(regions.GetRegion(21), IsNil) - c.Assert(regions.GetRegion(18), IsNil) + checkRegions(re, regions) + re.Equal(96, regions.tree.length()) + re.Len(regions.GetRegions(), 96) + re.NotNil(regions.GetRegion(201)) + re.Nil(regions.GetRegion(21)) + re.Nil(regions.GetRegion(18)) // Test update keys and size of region. region = region.Clone( @@ -536,17 +519,18 @@ func (*testRegionKey) TestSetRegion(c *C) { SetWrittenKeys(10), SetReportInterval(5)) regions.SetRegion(region) - checkRegions(c, regions) - c.Assert(regions.tree.length(), Equals, 96) - c.Assert(regions.GetRegions(), HasLen, 96) - c.Assert(regions.GetRegion(201), NotNil) - c.Assert(regions.tree.TotalSize(), Equals, int64(30)) + checkRegions(re, regions) + re.Equal(96, regions.tree.length()) + re.Len(regions.GetRegions(), 96) + re.NotNil(regions.GetRegion(201)) + re.Equal(int64(30), regions.tree.TotalSize()) bytesRate, keysRate := regions.tree.TotalWriteRate() - c.Assert(bytesRate, Equals, float64(8)) - c.Assert(keysRate, Equals, float64(2)) + re.Equal(float64(8), bytesRate) + re.Equal(float64(2), keysRate) } -func (*testRegionKey) TestShouldRemoveFromSubTree(c *C) { +func TestShouldRemoveFromSubTree(t *testing.T) { + re := require.New(t) peer1 := &metapb.Peer{StoreId: uint64(1), Id: uint64(1)} peer2 := &metapb.Peer{StoreId: uint64(2), Id: uint64(2)} peer3 := &metapb.Peer{StoreId: uint64(3), Id: uint64(3)} @@ -564,28 +548,28 @@ func (*testRegionKey) TestShouldRemoveFromSubTree(c *C) { StartKey: []byte(fmt.Sprintf("%20d", 10)), EndKey: []byte(fmt.Sprintf("%20d", 20)), }, peer1) - c.Assert(region.peersEqualTo(origin), IsTrue) + re.True(region.peersEqualTo(origin)) region.leader = peer2 - c.Assert(region.peersEqualTo(origin), IsFalse) + re.False(region.peersEqualTo(origin)) region.leader = peer1 region.pendingPeers = append(region.pendingPeers, peer4) - c.Assert(region.peersEqualTo(origin), IsFalse) + re.False(region.peersEqualTo(origin)) region.pendingPeers = nil region.learners = append(region.learners, peer2) - c.Assert(region.peersEqualTo(origin), IsFalse) + re.False(region.peersEqualTo(origin)) origin.learners = append(origin.learners, peer2, peer3) region.learners = append(region.learners, peer4) - 
c.Assert(region.peersEqualTo(origin), IsTrue) + re.True(region.peersEqualTo(origin)) region.voters[2].StoreId = 4 - c.Assert(region.peersEqualTo(origin), IsFalse) + re.False(region.peersEqualTo(origin)) } -func checkRegions(c *C, regions *RegionsInfo) { +func checkRegions(re *require.Assertions, regions *RegionsInfo) { leaderMap := make(map[uint64]uint64) followerMap := make(map[uint64]uint64) learnerMap := make(map[uint64]uint64) @@ -619,16 +603,16 @@ func checkRegions(c *C, regions *RegionsInfo) { } } for key, value := range regions.leaders { - c.Assert(value.length(), Equals, int(leaderMap[key])) + re.Equal(int(leaderMap[key]), value.length()) } for key, value := range regions.followers { - c.Assert(value.length(), Equals, int(followerMap[key])) + re.Equal(int(followerMap[key]), value.length()) } for key, value := range regions.learners { - c.Assert(value.length(), Equals, int(learnerMap[key])) + re.Equal(int(learnerMap[key]), value.length()) } for key, value := range regions.pendingPeers { - c.Assert(value.length(), Equals, int(pendingPeerMap[key])) + re.Equal(int(pendingPeerMap[key]), value.length()) } } diff --git a/server/core/region_tree_test.go b/server/core/region_tree_test.go index 92c26744abf..0f813717fcb 100644 --- a/server/core/region_tree_test.go +++ b/server/core/region_tree_test.go @@ -19,16 +19,13 @@ import ( "math/rand" "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testRegionSuite{}) - -type testRegionSuite struct{} - -func (s *testRegionSuite) TestRegionInfo(c *C) { +func TestRegionInfo(t *testing.T) { + re := require.New(t) n := uint64(3) peers := make([]*metapb.Peer, 0, n) @@ -51,104 +48,108 @@ func (s *testRegionSuite) TestRegionInfo(c *C) { WithPendingPeers([]*metapb.Peer{pendingPeer})) r := info.Clone() - c.Assert(r, DeepEquals, info) + re.Equal(info, r) for i := uint64(0); i < n; i++ { - c.Assert(r.GetPeer(i), Equals, r.meta.Peers[i]) + re.Equal(r.meta.Peers[i], r.GetPeer(i)) } - c.Assert(r.GetPeer(n), IsNil) - c.Assert(r.GetDownPeer(n), IsNil) - c.Assert(r.GetDownPeer(downPeer.GetId()), DeepEquals, downPeer) - c.Assert(r.GetPendingPeer(n), IsNil) - c.Assert(r.GetPendingPeer(pendingPeer.GetId()), DeepEquals, pendingPeer) + re.Nil(r.GetPeer(n)) + re.Nil(r.GetDownPeer(n)) + re.Equal(downPeer, r.GetDownPeer(downPeer.GetId())) + re.Nil(r.GetPendingPeer(n)) + re.Equal(pendingPeer, r.GetPendingPeer(pendingPeer.GetId())) for i := uint64(0); i < n; i++ { - c.Assert(r.GetStorePeer(i).GetStoreId(), Equals, i) + re.Equal(i, r.GetStorePeer(i).GetStoreId()) } - c.Assert(r.GetStorePeer(n), IsNil) + re.Nil(r.GetStorePeer(n)) removePeer := &metapb.Peer{ Id: n, StoreId: n, } r = r.Clone(SetPeers(append(r.meta.Peers, removePeer))) - c.Assert(DiffRegionPeersInfo(info, r), Matches, "Add peer.*") - c.Assert(DiffRegionPeersInfo(r, info), Matches, "Remove peer.*") - c.Assert(r.GetStorePeer(n), DeepEquals, removePeer) + re.Regexp("Add peer.*", DiffRegionPeersInfo(info, r)) + re.Regexp("Remove peer.*", DiffRegionPeersInfo(r, info)) + re.Equal(removePeer, r.GetStorePeer(n)) r = r.Clone(WithRemoveStorePeer(n)) - c.Assert(DiffRegionPeersInfo(r, info), Equals, "") - c.Assert(r.GetStorePeer(n), IsNil) + re.Equal("", DiffRegionPeersInfo(r, info)) + re.Nil(r.GetStorePeer(n)) r = r.Clone(WithStartKey([]byte{0})) - c.Assert(DiffRegionKeyInfo(r, info), Matches, "StartKey Changed.*") + re.Regexp("StartKey Changed.*", DiffRegionKeyInfo(r, info)) r = 
r.Clone(WithEndKey([]byte{1})) - c.Assert(DiffRegionKeyInfo(r, info), Matches, ".*EndKey Changed.*") + re.Regexp(".*EndKey Changed.*", DiffRegionKeyInfo(r, info)) stores := r.GetStoreIds() - c.Assert(stores, HasLen, int(n)) + re.Len(stores, int(n)) for i := uint64(0); i < n; i++ { _, ok := stores[i] - c.Assert(ok, IsTrue) + re.True(ok) } followers := r.GetFollowers() - c.Assert(followers, HasLen, int(n-1)) + re.Len(followers, int(n-1)) for i := uint64(1); i < n; i++ { - c.Assert(followers[peers[i].GetStoreId()], DeepEquals, peers[i]) + re.Equal(peers[i], followers[peers[i].GetStoreId()]) } } -func (s *testRegionSuite) TestRegionItem(c *C) { +func TestRegionItem(t *testing.T) { + re := require.New(t) item := newRegionItem([]byte("b"), []byte{}) - c.Assert(item.Less(newRegionItem([]byte("a"), []byte{})), IsFalse) - c.Assert(item.Less(newRegionItem([]byte("b"), []byte{})), IsFalse) - c.Assert(item.Less(newRegionItem([]byte("c"), []byte{})), IsTrue) + re.False(item.Less(newRegionItem([]byte("a"), []byte{}))) + re.False(item.Less(newRegionItem([]byte("b"), []byte{}))) + re.True(item.Less(newRegionItem([]byte("c"), []byte{}))) - c.Assert(item.Contains([]byte("a")), IsFalse) - c.Assert(item.Contains([]byte("b")), IsTrue) - c.Assert(item.Contains([]byte("c")), IsTrue) + re.False(item.Contains([]byte("a"))) + re.True(item.Contains([]byte("b"))) + re.True(item.Contains([]byte("c"))) item = newRegionItem([]byte("b"), []byte("d")) - c.Assert(item.Contains([]byte("a")), IsFalse) - c.Assert(item.Contains([]byte("b")), IsTrue) - c.Assert(item.Contains([]byte("c")), IsTrue) - c.Assert(item.Contains([]byte("d")), IsFalse) + re.False(item.Contains([]byte("a"))) + re.True(item.Contains([]byte("b"))) + re.True(item.Contains([]byte("c"))) + re.False(item.Contains([]byte("d"))) } -func (s *testRegionSuite) newRegionWithStat(start, end string, size, keys int64) *RegionInfo { +func newRegionWithStat(start, end string, size, keys int64) *RegionInfo { region := NewTestRegionInfo([]byte(start), []byte(end)) region.approximateSize, region.approximateKeys = size, keys return region } -func (s *testRegionSuite) TestRegionTreeStat(c *C) { +func TestRegionTreeStat(t *testing.T) { + re := require.New(t) tree := newRegionTree() - c.Assert(tree.totalSize, Equals, int64(0)) - updateNewItem(tree, s.newRegionWithStat("a", "b", 1, 2)) - c.Assert(tree.totalSize, Equals, int64(1)) - updateNewItem(tree, s.newRegionWithStat("b", "c", 3, 4)) - c.Assert(tree.totalSize, Equals, int64(4)) - updateNewItem(tree, s.newRegionWithStat("b", "e", 5, 6)) - c.Assert(tree.totalSize, Equals, int64(6)) - tree.remove(s.newRegionWithStat("a", "b", 1, 2)) - c.Assert(tree.totalSize, Equals, int64(5)) - tree.remove(s.newRegionWithStat("f", "g", 1, 2)) - c.Assert(tree.totalSize, Equals, int64(5)) + re.Equal(int64(0), tree.totalSize) + updateNewItem(tree, newRegionWithStat("a", "b", 1, 2)) + re.Equal(int64(1), tree.totalSize) + updateNewItem(tree, newRegionWithStat("b", "c", 3, 4)) + re.Equal(int64(4), tree.totalSize) + updateNewItem(tree, newRegionWithStat("b", "e", 5, 6)) + re.Equal(int64(6), tree.totalSize) + tree.remove(newRegionWithStat("a", "b", 1, 2)) + re.Equal(int64(5), tree.totalSize) + tree.remove(newRegionWithStat("f", "g", 1, 2)) + re.Equal(int64(5), tree.totalSize) } -func (s *testRegionSuite) TestRegionTreeMerge(c *C) { +func TestRegionTreeMerge(t *testing.T) { + re := require.New(t) tree := newRegionTree() - updateNewItem(tree, s.newRegionWithStat("a", "b", 1, 2)) - updateNewItem(tree, s.newRegionWithStat("b", "c", 3, 4)) - 
c.Assert(tree.totalSize, Equals, int64(4)) - updateNewItem(tree, s.newRegionWithStat("a", "c", 5, 5)) - c.Assert(tree.totalSize, Equals, int64(5)) + updateNewItem(tree, newRegionWithStat("a", "b", 1, 2)) + updateNewItem(tree, newRegionWithStat("b", "c", 3, 4)) + re.Equal(int64(4), tree.totalSize) + updateNewItem(tree, newRegionWithStat("a", "c", 5, 5)) + re.Equal(int64(5), tree.totalSize) } -func (s *testRegionSuite) TestRegionTree(c *C) { +func TestRegionTree(t *testing.T) { + re := require.New(t) tree := newRegionTree() - c.Assert(tree.search([]byte("a")), IsNil) + re.Nil(tree.search([]byte("a"))) regionA := NewTestRegionInfo([]byte("a"), []byte("b")) regionB := NewTestRegionInfo([]byte("b"), []byte("c")) @@ -157,86 +158,87 @@ func (s *testRegionSuite) TestRegionTree(c *C) { updateNewItem(tree, regionA) updateNewItem(tree, regionC) - c.Assert(tree.search([]byte{}), IsNil) - c.Assert(tree.search([]byte("a")), Equals, regionA) - c.Assert(tree.search([]byte("b")), IsNil) - c.Assert(tree.search([]byte("c")), Equals, regionC) - c.Assert(tree.search([]byte("d")), IsNil) + re.Nil(tree.search([]byte{})) + re.Equal(regionA, tree.search([]byte("a"))) + re.Nil(tree.search([]byte("b"))) + re.Equal(regionC, tree.search([]byte("c"))) + re.Nil(tree.search([]byte("d"))) // search previous region - c.Assert(tree.searchPrev([]byte("a")), IsNil) - c.Assert(tree.searchPrev([]byte("b")), IsNil) - c.Assert(tree.searchPrev([]byte("c")), IsNil) + re.Nil(tree.searchPrev([]byte("a"))) + re.Nil(tree.searchPrev([]byte("b"))) + re.Nil(tree.searchPrev([]byte("c"))) updateNewItem(tree, regionB) // search previous region - c.Assert(tree.searchPrev([]byte("c")), Equals, regionB) - c.Assert(tree.searchPrev([]byte("b")), Equals, regionA) + re.Equal(regionB, tree.searchPrev([]byte("c"))) + re.Equal(regionA, tree.searchPrev([]byte("b"))) tree.remove(regionC) updateNewItem(tree, regionD) - c.Assert(tree.search([]byte{}), IsNil) - c.Assert(tree.search([]byte("a")), Equals, regionA) - c.Assert(tree.search([]byte("b")), Equals, regionB) - c.Assert(tree.search([]byte("c")), IsNil) - c.Assert(tree.search([]byte("d")), Equals, regionD) + re.Nil(tree.search([]byte{})) + re.Equal(regionA, tree.search([]byte("a"))) + re.Equal(regionB, tree.search([]byte("b"))) + re.Nil(tree.search([]byte("c"))) + re.Equal(regionD, tree.search([]byte("d"))) // check get adjacent regions prev, next := tree.getAdjacentRegions(regionA) - c.Assert(prev, IsNil) - c.Assert(next.region, Equals, regionB) + re.Nil(prev) + re.Equal(regionB, next.region) prev, next = tree.getAdjacentRegions(regionB) - c.Assert(prev.region, Equals, regionA) - c.Assert(next.region, Equals, regionD) + re.Equal(regionA, prev.region) + re.Equal(regionD, next.region) prev, next = tree.getAdjacentRegions(regionC) - c.Assert(prev.region, Equals, regionB) - c.Assert(next.region, Equals, regionD) + re.Equal(regionB, prev.region) + re.Equal(regionD, next.region) prev, next = tree.getAdjacentRegions(regionD) - c.Assert(prev.region, Equals, regionB) - c.Assert(next, IsNil) + re.Equal(regionB, prev.region) + re.Nil(next) // region with the same range and different region id will not be delete. 
region0 := newRegionItem([]byte{}, []byte("a")).region updateNewItem(tree, region0) - c.Assert(tree.search([]byte{}), Equals, region0) + re.Equal(region0, tree.search([]byte{})) anotherRegion0 := newRegionItem([]byte{}, []byte("a")).region anotherRegion0.meta.Id = 123 tree.remove(anotherRegion0) - c.Assert(tree.search([]byte{}), Equals, region0) + re.Equal(region0, tree.search([]byte{})) // overlaps with 0, A, B, C. region0D := newRegionItem([]byte(""), []byte("d")).region updateNewItem(tree, region0D) - c.Assert(tree.search([]byte{}), Equals, region0D) - c.Assert(tree.search([]byte("a")), Equals, region0D) - c.Assert(tree.search([]byte("b")), Equals, region0D) - c.Assert(tree.search([]byte("c")), Equals, region0D) - c.Assert(tree.search([]byte("d")), Equals, regionD) + re.Equal(region0D, tree.search([]byte{})) + re.Equal(region0D, tree.search([]byte("a"))) + re.Equal(region0D, tree.search([]byte("b"))) + re.Equal(region0D, tree.search([]byte("c"))) + re.Equal(regionD, tree.search([]byte("d"))) // overlaps with D. regionE := newRegionItem([]byte("e"), []byte{}).region updateNewItem(tree, regionE) - c.Assert(tree.search([]byte{}), Equals, region0D) - c.Assert(tree.search([]byte("a")), Equals, region0D) - c.Assert(tree.search([]byte("b")), Equals, region0D) - c.Assert(tree.search([]byte("c")), Equals, region0D) - c.Assert(tree.search([]byte("d")), IsNil) - c.Assert(tree.search([]byte("e")), Equals, regionE) + re.Equal(region0D, tree.search([]byte{})) + re.Equal(region0D, tree.search([]byte("a"))) + re.Equal(region0D, tree.search([]byte("b"))) + re.Equal(region0D, tree.search([]byte("c"))) + re.Nil(tree.search([]byte("d"))) + re.Equal(regionE, tree.search([]byte("e"))) } -func updateRegions(c *C, tree *regionTree, regions []*RegionInfo) { +func updateRegions(re *require.Assertions, tree *regionTree, regions []*RegionInfo) { for _, region := range regions { updateNewItem(tree, region) - c.Assert(tree.search(region.GetStartKey()), Equals, region) + re.Equal(region, tree.search(region.GetStartKey())) if len(region.GetEndKey()) > 0 { end := region.GetEndKey()[0] - c.Assert(tree.search([]byte{end - 1}), Equals, region) - c.Assert(tree.search([]byte{end + 1}), Not(Equals), region) + re.Equal(region, tree.search([]byte{end - 1})) + re.NotEqual(region, tree.search([]byte{end + 1})) } } } -func (s *testRegionSuite) TestRegionTreeSplitAndMerge(c *C) { +func TestRegionTreeSplitAndMerge(t *testing.T) { + re := require.New(t) tree := newRegionTree() regions := []*RegionInfo{newRegionItem([]byte{}, []byte{}).region} @@ -246,13 +248,13 @@ func (s *testRegionSuite) TestRegionTreeSplitAndMerge(c *C) { // Split. for i := 0; i < n; i++ { regions = SplitRegions(regions) - updateRegions(c, tree, regions) + updateRegions(re, tree, regions) } // Merge. for i := 0; i < n; i++ { regions = MergeRegions(regions) - updateRegions(c, tree, regions) + updateRegions(re, tree, regions) } // Split twice and merge once. 
@@ -262,19 +264,20 @@ func (s *testRegionSuite) TestRegionTreeSplitAndMerge(c *C) { } else { regions = SplitRegions(regions) } - updateRegions(c, tree, regions) + updateRegions(re, tree, regions) } } -func (s *testRegionSuite) TestRandomRegion(c *C) { +func TestRandomRegion(t *testing.T) { + re := require.New(t) tree := newRegionTree() r := tree.RandomRegion(nil) - c.Assert(r, IsNil) + re.Nil(r) regionA := NewTestRegionInfo([]byte(""), []byte("g")) updateNewItem(tree, regionA) ra := tree.RandomRegion([]KeyRange{NewKeyRange("", "")}) - c.Assert(ra, DeepEquals, regionA) + re.Equal(regionA, ra) regionB := NewTestRegionInfo([]byte("g"), []byte("n")) regionC := NewTestRegionInfo([]byte("n"), []byte("t")) @@ -284,70 +287,71 @@ func (s *testRegionSuite) TestRandomRegion(c *C) { updateNewItem(tree, regionD) rb := tree.RandomRegion([]KeyRange{NewKeyRange("g", "n")}) - c.Assert(rb, DeepEquals, regionB) + re.Equal(regionB, rb) rc := tree.RandomRegion([]KeyRange{NewKeyRange("n", "t")}) - c.Assert(rc, DeepEquals, regionC) + re.Equal(regionC, rc) rd := tree.RandomRegion([]KeyRange{NewKeyRange("t", "")}) - c.Assert(rd, DeepEquals, regionD) - - re := tree.RandomRegion([]KeyRange{NewKeyRange("", "a")}) - c.Assert(re, IsNil) - re = tree.RandomRegion([]KeyRange{NewKeyRange("o", "s")}) - c.Assert(re, IsNil) - re = tree.RandomRegion([]KeyRange{NewKeyRange("", "a")}) - c.Assert(re, IsNil) - re = tree.RandomRegion([]KeyRange{NewKeyRange("z", "")}) - c.Assert(re, IsNil) - - checkRandomRegion(c, tree, []*RegionInfo{regionA, regionB, regionC, regionD}, []KeyRange{NewKeyRange("", "")}) - checkRandomRegion(c, tree, []*RegionInfo{regionA, regionB}, []KeyRange{NewKeyRange("", "n")}) - checkRandomRegion(c, tree, []*RegionInfo{regionC, regionD}, []KeyRange{NewKeyRange("n", "")}) - checkRandomRegion(c, tree, []*RegionInfo{}, []KeyRange{NewKeyRange("h", "s")}) - checkRandomRegion(c, tree, []*RegionInfo{regionB, regionC}, []KeyRange{NewKeyRange("a", "z")}) + re.Equal(regionD, rd) + + rf := tree.RandomRegion([]KeyRange{NewKeyRange("", "a")}) + re.Nil(rf) + rf = tree.RandomRegion([]KeyRange{NewKeyRange("o", "s")}) + re.Nil(rf) + rf = tree.RandomRegion([]KeyRange{NewKeyRange("", "a")}) + re.Nil(rf) + rf = tree.RandomRegion([]KeyRange{NewKeyRange("z", "")}) + re.Nil(rf) + + checkRandomRegion(re, tree, []*RegionInfo{regionA, regionB, regionC, regionD}, []KeyRange{NewKeyRange("", "")}) + checkRandomRegion(re, tree, []*RegionInfo{regionA, regionB}, []KeyRange{NewKeyRange("", "n")}) + checkRandomRegion(re, tree, []*RegionInfo{regionC, regionD}, []KeyRange{NewKeyRange("n", "")}) + checkRandomRegion(re, tree, []*RegionInfo{}, []KeyRange{NewKeyRange("h", "s")}) + checkRandomRegion(re, tree, []*RegionInfo{regionB, regionC}, []KeyRange{NewKeyRange("a", "z")}) } -func (s *testRegionSuite) TestRandomRegionDiscontinuous(c *C) { +func TestRandomRegionDiscontinuous(t *testing.T) { + re := require.New(t) tree := newRegionTree() r := tree.RandomRegion([]KeyRange{NewKeyRange("c", "f")}) - c.Assert(r, IsNil) + re.Nil(r) // test for single region regionA := NewTestRegionInfo([]byte("c"), []byte("f")) updateNewItem(tree, regionA) ra := tree.RandomRegion([]KeyRange{NewKeyRange("c", "e")}) - c.Assert(ra, IsNil) + re.Nil(ra) ra = tree.RandomRegion([]KeyRange{NewKeyRange("c", "f")}) - c.Assert(ra, DeepEquals, regionA) + re.Equal(regionA, ra) ra = tree.RandomRegion([]KeyRange{NewKeyRange("c", "g")}) - c.Assert(ra, DeepEquals, regionA) + re.Equal(regionA, ra) ra = tree.RandomRegion([]KeyRange{NewKeyRange("a", "e")}) - c.Assert(ra, IsNil) + re.Nil(ra) ra 
= tree.RandomRegion([]KeyRange{NewKeyRange("a", "f")}) - c.Assert(ra, DeepEquals, regionA) + re.Equal(regionA, ra) ra = tree.RandomRegion([]KeyRange{NewKeyRange("a", "g")}) - c.Assert(ra, DeepEquals, regionA) + re.Equal(regionA, ra) regionB := NewTestRegionInfo([]byte("n"), []byte("x")) updateNewItem(tree, regionB) rb := tree.RandomRegion([]KeyRange{NewKeyRange("g", "x")}) - c.Assert(rb, DeepEquals, regionB) + re.Equal(regionB, rb) rb = tree.RandomRegion([]KeyRange{NewKeyRange("g", "y")}) - c.Assert(rb, DeepEquals, regionB) + re.Equal(regionB, rb) rb = tree.RandomRegion([]KeyRange{NewKeyRange("n", "y")}) - c.Assert(rb, DeepEquals, regionB) + re.Equal(regionB, rb) rb = tree.RandomRegion([]KeyRange{NewKeyRange("o", "y")}) - c.Assert(rb, IsNil) + re.Nil(rb) regionC := NewTestRegionInfo([]byte("z"), []byte("")) updateNewItem(tree, regionC) rc := tree.RandomRegion([]KeyRange{NewKeyRange("y", "")}) - c.Assert(rc, DeepEquals, regionC) + re.Equal(regionC, rc) regionD := NewTestRegionInfo([]byte(""), []byte("a")) updateNewItem(tree, regionD) rd := tree.RandomRegion([]KeyRange{NewKeyRange("", "b")}) - c.Assert(rd, DeepEquals, regionD) + re.Equal(regionD, rd) - checkRandomRegion(c, tree, []*RegionInfo{regionA, regionB, regionC, regionD}, []KeyRange{NewKeyRange("", "")}) + checkRandomRegion(re, tree, []*RegionInfo{regionA, regionB, regionC, regionD}, []KeyRange{NewKeyRange("", "")}) } func updateNewItem(tree *regionTree, region *RegionInfo) { @@ -355,7 +359,7 @@ func updateNewItem(tree *regionTree, region *RegionInfo) { tree.update(item) } -func checkRandomRegion(c *C, tree *regionTree, regions []*RegionInfo, ranges []KeyRange) { +func checkRandomRegion(re *require.Assertions, tree *regionTree, regions []*RegionInfo, ranges []KeyRange) { keys := make(map[string]struct{}) for i := 0; i < 10000 && len(keys) < len(regions); i++ { re := tree.RandomRegion(ranges) @@ -369,9 +373,9 @@ func checkRandomRegion(c *C, tree *regionTree, regions []*RegionInfo, ranges []K } for _, region := range regions { _, ok := keys[string(region.GetStartKey())] - c.Assert(ok, IsTrue) + re.True(ok) } - c.Assert(keys, HasLen, len(regions)) + re.Len(keys, len(regions)) } func newRegionItem(start, end []byte) *regionItem { diff --git a/server/core/store_stats_test.go b/server/core/store_stats_test.go index 7b046e3d0c6..82598fd9347 100644 --- a/server/core/store_stats_test.go +++ b/server/core/store_stats_test.go @@ -15,16 +15,15 @@ package core import ( - . 
"github.com/pingcap/check" + "testing" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testStoreStatsSuite{}) - -type testStoreStatsSuite struct{} - -func (s *testStoreStatsSuite) TestStoreStats(c *C) { +func TestStoreStats(t *testing.T) { + re := require.New(t) G := uint64(1024 * 1024 * 1024) meta := &metapb.Store{Id: 1, State: metapb.StoreState_Up} store := NewStoreInfo(meta, SetStoreStats(&pdpb.StoreStats{ @@ -33,11 +32,11 @@ func (s *testStoreStatsSuite) TestStoreStats(c *C) { Available: 150 * G, })) - c.Assert(store.GetCapacity(), Equals, 200*G) - c.Assert(store.GetUsedSize(), Equals, 50*G) - c.Assert(store.GetAvailable(), Equals, 150*G) - c.Assert(store.GetAvgAvailable(), Equals, 150*G) - c.Assert(store.GetAvailableDeviation(), Equals, uint64(0)) + re.Equal(200*G, store.GetCapacity()) + re.Equal(50*G, store.GetUsedSize()) + re.Equal(150*G, store.GetAvailable()) + re.Equal(150*G, store.GetAvgAvailable()) + re.Equal(uint64(0), store.GetAvailableDeviation()) store = store.Clone(SetStoreStats(&pdpb.StoreStats{ Capacity: 200 * G, @@ -45,9 +44,9 @@ func (s *testStoreStatsSuite) TestStoreStats(c *C) { Available: 160 * G, })) - c.Assert(store.GetAvailable(), Equals, 160*G) - c.Assert(store.GetAvgAvailable(), Greater, 150*G) - c.Assert(store.GetAvgAvailable(), Less, 160*G) - c.Assert(store.GetAvailableDeviation(), Greater, uint64(0)) - c.Assert(store.GetAvailableDeviation(), Less, 10*G) + re.Equal(160*G, store.GetAvailable()) + re.Greater(store.GetAvgAvailable(), 150*G) + re.Less(store.GetAvgAvailable(), 160*G) + re.Greater(store.GetAvailableDeviation(), uint64(0)) + re.Less(store.GetAvailableDeviation(), 10*G) } diff --git a/server/core/store_test.go b/server/core/store_test.go index 2bed3783f9e..b311315a29e 100644 --- a/server/core/store_test.go +++ b/server/core/store_test.go @@ -17,18 +17,16 @@ package core import ( "math" "sync" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testDistinctScoreSuite{}) - -type testDistinctScoreSuite struct{} - -func (s *testDistinctScoreSuite) TestDistinctScore(c *C) { +func TestDistinctScore(t *testing.T) { + re := require.New(t) labels := []string{"zone", "rack", "host"} zones := []string{"z1", "z2", "z3"} racks := []string{"r1", "r2", "r3"} @@ -54,19 +52,15 @@ func (s *testDistinctScoreSuite) TestDistinctScore(c *C) { // Number of stores in the same rack but in different hosts. 
numHosts := k score := (numZones*replicaBaseScore+numRacks)*replicaBaseScore + numHosts - c.Assert(DistinctScore(labels, stores, store), Equals, float64(score)) + re.Equal(float64(score), DistinctScore(labels, stores, store)) } } } store := NewStoreInfoWithLabel(100, 1, nil) - c.Assert(DistinctScore(labels, stores, store), Equals, float64(0)) + re.Equal(float64(0), DistinctScore(labels, stores, store)) } -var _ = Suite(&testConcurrencySuite{}) - -type testConcurrencySuite struct{} - -func (s *testConcurrencySuite) TestCloneStore(c *C) { +func TestCloneStore(t *testing.T) { meta := &metapb.Store{Id: 1, Address: "mock://tikv-1", Labels: []*metapb.StoreLabel{{Key: "zone", Value: "z1"}, {Key: "host", Value: "h1"}}} store := NewStoreInfo(meta) start := time.Now() @@ -96,11 +90,8 @@ func (s *testConcurrencySuite) TestCloneStore(c *C) { wg.Wait() } -var _ = Suite(&testStoreSuite{}) - -type testStoreSuite struct{} - -func (s *testStoreSuite) TestRegionScore(c *C) { +func TestRegionScore(t *testing.T) { + re := require.New(t) stats := &pdpb.StoreStats{} stats.Capacity = 512 * (1 << 20) // 512 MB stats.Available = 100 * (1 << 20) // 100 MB @@ -113,22 +104,24 @@ func (s *testStoreSuite) TestRegionScore(c *C) { ) score := store.RegionScore("v1", 0.7, 0.9, 0) // Region score should never be NaN, or /store API would fail. - c.Assert(math.IsNaN(score), IsFalse) + re.False(math.IsNaN(score)) } -func (s *testStoreSuite) TestLowSpaceRatio(c *C) { +func TestLowSpaceRatio(t *testing.T) { + re := require.New(t) store := NewStoreInfoWithLabel(1, 20, nil) store.rawStats.Capacity = initialMinSpace << 4 store.rawStats.Available = store.rawStats.Capacity >> 3 - c.Assert(store.IsLowSpace(0.8), IsFalse) + re.False(store.IsLowSpace(0.8)) store.regionCount = 31 - c.Assert(store.IsLowSpace(0.8), IsTrue) + re.True(store.IsLowSpace(0.8)) store.rawStats.Available = store.rawStats.Capacity >> 2 - c.Assert(store.IsLowSpace(0.8), IsFalse) + re.False(store.IsLowSpace(0.8)) } -func (s *testStoreSuite) TestLowSpaceScoreV2(c *C) { +func TestLowSpaceScoreV2(t *testing.T) { + re := require.New(t) testdata := []struct { bigger *StoreInfo small *StoreInfo @@ -172,6 +165,6 @@ func (s *testStoreSuite) TestLowSpaceScoreV2(c *C) { for _, v := range testdata { score1 := v.bigger.regionScoreV2(0, 0.8) score2 := v.small.regionScoreV2(0, 0.8) - c.Assert(score1, Greater, score2) + re.Greater(score1, score2) } } From ec1fbdaafd6e23d6041ba126765e0f65321d8f85 Mon Sep 17 00:00:00 2001 From: LLThomas Date: Thu, 9 Jun 2022 12:20:30 +0800 Subject: [PATCH 32/82] filter: migrate test framework to testify (#5133) ref tikv/pd#4813 Signed-off-by: LLThomas --- server/schedule/filter/candidates_test.go | 50 +++++++------ server/schedule/filter/filters_test.go | 87 +++++++++++------------ 2 files changed, 66 insertions(+), 71 deletions(-) diff --git a/server/schedule/filter/candidates_test.go b/server/schedule/filter/candidates_test.go index 86fd35e739e..5150bed9b66 100644 --- a/server/schedule/filter/candidates_test.go +++ b/server/schedule/filter/candidates_test.go @@ -15,8 +15,9 @@ package filter import ( - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" + "testing" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" @@ -55,45 +56,42 @@ func (f idFilter) Target(opt *config.PersistOptions, store *core.StoreInfo) bool return f(store.GetID()) } -type testCandidatesSuite struct{} - -var _ = Suite(&testCandidatesSuite{}) - -func (s *testCandidatesSuite) TestCandidates(c *C) { - cs := s.newCandidates(1, 2, 3, 4, 5) +func TestCandidates(t *testing.T) { + re := require.New(t) + cs := newTestCandidates(1, 2, 3, 4, 5) cs.FilterSource(nil, idFilter(func(id uint64) bool { return id > 2 })) - s.check(c, cs, 3, 4, 5) + check(re, cs, 3, 4, 5) cs.FilterTarget(nil, idFilter(func(id uint64) bool { return id%2 == 1 })) - s.check(c, cs, 3, 5) + check(re, cs, 3, 5) cs.FilterTarget(nil, idFilter(func(id uint64) bool { return id > 100 })) - s.check(c, cs) + check(re, cs) store := cs.PickFirst() - c.Assert(store, IsNil) + re.Nil(store) store = cs.RandomPick() - c.Assert(store, IsNil) + re.Nil(store) - cs = s.newCandidates(1, 3, 5, 7, 6, 2, 4) + cs = newTestCandidates(1, 3, 5, 7, 6, 2, 4) cs.Sort(idComparer) - s.check(c, cs, 1, 2, 3, 4, 5, 6, 7) + check(re, cs, 1, 2, 3, 4, 5, 6, 7) store = cs.PickFirst() - c.Assert(store.GetID(), Equals, uint64(1)) + re.Equal(uint64(1), store.GetID()) cs.Reverse() - s.check(c, cs, 7, 6, 5, 4, 3, 2, 1) + check(re, cs, 7, 6, 5, 4, 3, 2, 1) store = cs.PickFirst() - c.Assert(store.GetID(), Equals, uint64(7)) + re.Equal(uint64(7), store.GetID()) cs.Shuffle() cs.Sort(idComparer) - s.check(c, cs, 1, 2, 3, 4, 5, 6, 7) + check(re, cs, 1, 2, 3, 4, 5, 6, 7) store = cs.RandomPick() - c.Assert(store.GetID(), Greater, uint64(0)) - c.Assert(store.GetID(), Less, uint64(8)) + re.Greater(store.GetID(), uint64(0)) + re.Less(store.GetID(), uint64(8)) - cs = s.newCandidates(10, 15, 23, 20, 33, 32, 31) + cs = newTestCandidates(10, 15, 23, 20, 33, 32, 31) cs.Sort(idComparer).Reverse().Top(idComparer2) - s.check(c, cs, 33, 32, 31) + check(re, cs, 33, 32, 31) } -func (s *testCandidatesSuite) newCandidates(ids ...uint64) *StoreCandidates { +func newTestCandidates(ids ...uint64) *StoreCandidates { stores := make([]*core.StoreInfo, 0, len(ids)) for _, id := range ids { stores = append(stores, core.NewStoreInfo(&metapb.Store{Id: id})) @@ -101,9 +99,9 @@ func (s *testCandidatesSuite) newCandidates(ids ...uint64) *StoreCandidates { return NewCandidates(stores) } -func (s *testCandidatesSuite) check(c *C, candidates *StoreCandidates, ids ...uint64) { - c.Assert(candidates.Stores, HasLen, len(ids)) +func check(re *require.Assertions, candidates *StoreCandidates, ids ...uint64) { + re.Len(candidates.Stores, len(ids)) for i, s := range candidates.Stores { - c.Assert(s.GetID(), Equals, ids[i]) + re.Equal(ids[i], s.GetID()) } } diff --git a/server/schedule/filter/filters_test.go b/server/schedule/filter/filters_test.go index 0b44b8cb258..31da16f6ff6 100644 --- a/server/schedule/filter/filters_test.go +++ b/server/schedule/filter/filters_test.go @@ -18,35 +18,17 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule/placement" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testFiltersSuite{}) - -type testFiltersSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testFiltersSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testFiltersSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testFiltersSuite) TestDistinctScoreFilter(c *C) { +func TestDistinctScoreFilter(t *testing.T) { + re := require.New(t) labels := []string{"zone", "rack", "host"} allStores := []*core.StoreInfo{ core.NewStoreInfoWithLabel(1, 1, map[string]string{"zone": "z1", "rack": "r1", "host": "h1"}), @@ -75,14 +57,18 @@ func (s *testFiltersSuite) TestDistinctScoreFilter(c *C) { } ls := NewLocationSafeguard("", labels, stores, allStores[tc.source-1]) li := NewLocationImprover("", labels, stores, allStores[tc.source-1]) - c.Assert(ls.Target(config.NewTestOptions(), allStores[tc.target-1]), Equals, tc.safeGuardRes) - c.Assert(li.Target(config.NewTestOptions(), allStores[tc.target-1]), Equals, tc.improverRes) + re.Equal(tc.safeGuardRes, ls.Target(config.NewTestOptions(), allStores[tc.target-1])) + re.Equal(tc.improverRes, li.Target(config.NewTestOptions(), allStores[tc.target-1])) } } -func (s *testFiltersSuite) TestLabelConstraintsFilter(c *C) { +func TestLabelConstraintsFilter(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + opt := config.NewTestOptions() - testCluster := mockcluster.NewCluster(s.ctx, opt) + testCluster := mockcluster.NewCluster(ctx, opt) store := core.NewStoreInfoWithLabel(1, 1, map[string]string{"id": "1"}) testCases := []struct { @@ -103,14 +89,18 @@ func (s *testFiltersSuite) TestLabelConstraintsFilter(c *C) { } for _, tc := range testCases { filter := NewLabelConstaintFilter("", []placement.LabelConstraint{{Key: tc.key, Op: placement.LabelConstraintOp(tc.op), Values: tc.values}}) - c.Assert(filter.Source(testCluster.GetOpts(), store), Equals, tc.res) + re.Equal(tc.res, filter.Source(testCluster.GetOpts(), store)) } } -func (s *testFiltersSuite) TestRuleFitFilter(c *C) { +func TestRuleFitFilter(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + opt := config.NewTestOptions() opt.SetPlacementRuleEnabled(false) - testCluster := mockcluster.NewCluster(s.ctx, opt) + testCluster := mockcluster.NewCluster(ctx, opt) testCluster.SetLocationLabels([]string{"zone"}) testCluster.SetEnablePlacementRules(true) region := core.NewRegionInfo(&metapb.Region{Peers: []*metapb.Peer{ @@ -139,12 +129,13 @@ func (s *testFiltersSuite) TestRuleFitFilter(c *C) { } for _, tc := range testCases { filter := newRuleFitFilter("", testCluster.GetBasicCluster(), testCluster.GetRuleManager(), region, 1) - c.Assert(filter.Source(testCluster.GetOpts(), testCluster.GetStore(tc.storeID)), Equals, tc.sourceRes) - c.Assert(filter.Target(testCluster.GetOpts(), testCluster.GetStore(tc.storeID)), Equals, tc.targetRes) + re.Equal(tc.sourceRes, filter.Source(testCluster.GetOpts(), testCluster.GetStore(tc.storeID))) + re.Equal(tc.targetRes, filter.Target(testCluster.GetOpts(), testCluster.GetStore(tc.storeID))) } } -func (s *testFiltersSuite) 
TestStoreStateFilter(c *C) { +func TestStoreStateFilter(t *testing.T) { + re := require.New(t) filters := []Filter{ &StoreStateFilter{TransferLeader: true}, &StoreStateFilter{MoveRegion: true}, @@ -162,8 +153,8 @@ func (s *testFiltersSuite) TestStoreStateFilter(c *C) { check := func(store *core.StoreInfo, testCases []testCase) { for _, tc := range testCases { - c.Assert(filters[tc.filterIdx].Source(opt, store), Equals, tc.sourceRes) - c.Assert(filters[tc.filterIdx].Target(opt, store), Equals, tc.targetRes) + re.Equal(tc.sourceRes, filters[tc.filterIdx].Source(opt, store)) + re.Equal(tc.targetRes, filters[tc.filterIdx].Target(opt, store)) } } @@ -195,9 +186,13 @@ func (s *testFiltersSuite) TestStoreStateFilter(c *C) { check(store, testCases) } -func (s *testFiltersSuite) TestIsolationFilter(c *C) { +func TestIsolationFilter(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + opt := config.NewTestOptions() - testCluster := mockcluster.NewCluster(s.ctx, opt) + testCluster := mockcluster.NewCluster(ctx, opt) testCluster.SetLocationLabels([]string{"zone", "rack", "host"}) allStores := []struct { storeID uint64 @@ -256,16 +251,20 @@ func (s *testFiltersSuite) TestIsolationFilter(c *C) { for _, tc := range testCases { filter := NewIsolationFilter("", tc.isolationLevel, testCluster.GetLocationLabels(), testCluster.GetRegionStores(tc.region)) for idx, store := range allStores { - c.Assert(filter.Source(testCluster.GetOpts(), testCluster.GetStore(store.storeID)), Equals, tc.sourceRes[idx]) - c.Assert(filter.Target(testCluster.GetOpts(), testCluster.GetStore(store.storeID)), Equals, tc.targetRes[idx]) + re.Equal(tc.sourceRes[idx], filter.Source(testCluster.GetOpts(), testCluster.GetStore(store.storeID))) + re.Equal(tc.targetRes[idx], filter.Target(testCluster.GetOpts(), testCluster.GetStore(store.storeID))) } } } -func (s *testFiltersSuite) TestPlacementGuard(c *C) { +func TestPlacementGuard(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + opt := config.NewTestOptions() opt.SetPlacementRuleEnabled(false) - testCluster := mockcluster.NewCluster(s.ctx, opt) + testCluster := mockcluster.NewCluster(ctx, opt) testCluster.SetLocationLabels([]string{"zone"}) testCluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) testCluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1"}) @@ -279,13 +278,11 @@ func (s *testFiltersSuite) TestPlacementGuard(c *C) { }}, &metapb.Peer{StoreId: 1, Id: 1}) store := testCluster.GetStore(1) - c.Assert(NewPlacementSafeguard("", testCluster.GetOpts(), testCluster.GetBasicCluster(), testCluster.GetRuleManager(), region, store), - FitsTypeOf, - NewLocationSafeguard("", []string{"zone"}, testCluster.GetRegionStores(region), store)) + re.IsType(NewLocationSafeguard("", []string{"zone"}, testCluster.GetRegionStores(region), store), + NewPlacementSafeguard("", testCluster.GetOpts(), testCluster.GetBasicCluster(), testCluster.GetRuleManager(), region, store)) testCluster.SetEnablePlacementRules(true) - c.Assert(NewPlacementSafeguard("", testCluster.GetOpts(), testCluster.GetBasicCluster(), testCluster.GetRuleManager(), region, store), - FitsTypeOf, - newRuleFitFilter("", testCluster.GetBasicCluster(), testCluster.GetRuleManager(), region, 1)) + re.IsType(newRuleFitFilter("", testCluster.GetBasicCluster(), testCluster.GetRuleManager(), region, 1), + NewPlacementSafeguard("", testCluster.GetOpts(), testCluster.GetBasicCluster(), 
testCluster.GetRuleManager(), region, store)) } func BenchmarkCloneRegionTest(b *testing.B) { From 5ace930a4bcc33cce752e6174ff3c1c5746119be Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 9 Jun 2022 14:54:31 +0800 Subject: [PATCH 33/82] server, tests: remove EnableZap (#5119) close tikv/pd#5118 Remove `EnableZap`. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- server/api/server_test.go | 1 - server/server.go | 16 ++-------------- server/server_test.go | 1 - tests/compatibility/version_upgrade_test.go | 1 - tests/dashboard/race_test.go | 3 --- tests/dashboard/service_test.go | 3 --- tests/pdbackup/backup_test.go | 5 ----- tests/pdctl/cluster/cluster_test.go | 5 ----- tests/pdctl/config/config_test.go | 5 ----- tests/pdctl/global_test.go | 4 ---- tests/pdctl/health/health_test.go | 5 ----- tests/pdctl/hot/hot_test.go | 5 ----- tests/pdctl/label/label_test.go | 5 ----- tests/pdctl/log/log_test.go | 1 - tests/pdctl/member/member_test.go | 5 ----- tests/pdctl/operator/operator_test.go | 5 ----- tests/pdctl/region/region_test.go | 5 ----- tests/pdctl/scheduler/scheduler_test.go | 2 -- tests/pdctl/store/store_test.go | 5 ----- tests/pdctl/tso/tso_test.go | 5 ----- tests/server/api/api_test.go | 6 ------ tests/server/cluster/cluster_test.go | 1 - tests/server/cluster/cluster_work_test.go | 2 -- tests/server/global_config/global_config_test.go | 1 - tests/server/id/id_test.go | 2 -- tests/server/join/join_fail/join_fail_test.go | 5 ----- tests/server/join/join_test.go | 1 - tests/server/member/member_test.go | 1 - tests/server/region_syncer/region_syncer_test.go | 2 -- tests/server/server_test.go | 2 -- tests/server/storage/hot_region_storage_test.go | 5 ----- tests/server/tso/allocator_test.go | 2 -- tests/server/tso/consistency_test.go | 2 -- tests/server/tso/global_tso_test.go | 2 -- tests/server/tso/manager_test.go | 2 -- tests/server/tso/tso_test.go | 2 -- tests/server/watch/leader_watch_test.go | 2 -- 37 files changed, 2 insertions(+), 125 deletions(-) diff --git a/server/api/server_test.go b/server/api/server_test.go index 8d9f1b4c227..a4c6b6de6fb 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -62,7 +62,6 @@ var ( ) func TestAPIServer(t *testing.T) { - server.EnableZap = true TestingT(t) } diff --git a/server/server.go b/server/server.go index c27941f7c85..a63c8c52525 100644 --- a/server/server.go +++ b/server/server.go @@ -82,12 +82,8 @@ const ( pdClusterIDPath = "/pd/cluster_id" ) -var ( - // EnableZap enable the zap logger in embed etcd. - EnableZap = false - // EtcdStartTimeout the timeout of the startup etcd. - EtcdStartTimeout = time.Minute * 5 -) +// EtcdStartTimeout the timeout of the startup etcd. +var EtcdStartTimeout = time.Minute * 5 // Server is the pd server. // nolint @@ -282,14 +278,6 @@ func CreateServer(ctx context.Context, cfg *config.Config, serviceBuilders ...Ha diagnosticspb.RegisterDiagnosticsServer(gs, s) } s.etcdCfg = etcdCfg - if EnableZap { - // The etcd master version has removed embed.Config.SetupLogging. - // Now logger is set up automatically based on embed.Config.Logger, - // Use zap logger in the test, otherwise will panic. 
- // Reference: https://go.etcd.io/etcd/blob/master/embed/config_logging.go#L45 - s.etcdCfg.Logger = "zap" - s.etcdCfg.LogOutputs = []string{"stdout"} - } s.lg = cfg.GetZapLogger() s.logProps = cfg.GetZapLogProperties() return s, nil diff --git a/server/server_test.go b/server/server_test.go index f433ac2dc31..c6c14fe011f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -33,7 +33,6 @@ import ( ) func TestServer(t *testing.T) { - EnableZap = true TestingT(t) } diff --git a/tests/compatibility/version_upgrade_test.go b/tests/compatibility/version_upgrade_test.go index 03e34697084..2fcdc1bf5d5 100644 --- a/tests/compatibility/version_upgrade_test.go +++ b/tests/compatibility/version_upgrade_test.go @@ -39,7 +39,6 @@ type compatibilityTestSuite struct { func (s *compatibilityTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *compatibilityTestSuite) TearDownSuite(c *C) { diff --git a/tests/dashboard/race_test.go b/tests/dashboard/race_test.go index ee72f8dbb61..9ca82567b52 100644 --- a/tests/dashboard/race_test.go +++ b/tests/dashboard/race_test.go @@ -21,7 +21,6 @@ import ( . "github.com/pingcap/check" "github.com/tikv/pd/pkg/dashboard" - "github.com/tikv/pd/server" "github.com/tikv/pd/tests" // Register schedulers. @@ -33,14 +32,12 @@ var _ = Suite(&raceTestSuite{}) type raceTestSuite struct{} func (s *raceTestSuite) SetUpSuite(c *C) { - server.EnableZap = true dashboard.SetCheckInterval(50 * time.Millisecond) tests.WaitLeaderReturnDelay = 0 tests.WaitLeaderCheckInterval = 20 * time.Millisecond } func (s *raceTestSuite) TearDownSuite(c *C) { - server.EnableZap = false dashboard.SetCheckInterval(time.Second) tests.WaitLeaderReturnDelay = 20 * time.Millisecond tests.WaitLeaderCheckInterval = 500 * time.Millisecond diff --git a/tests/dashboard/service_test.go b/tests/dashboard/service_test.go index 911cf80be30..f11f39e466a 100644 --- a/tests/dashboard/service_test.go +++ b/tests/dashboard/service_test.go @@ -27,7 +27,6 @@ import ( "github.com/tikv/pd/pkg/dashboard" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" "github.com/tikv/pd/tests/pdctl" @@ -54,7 +53,6 @@ type dashboardTestSuite struct { } func (s *dashboardTestSuite) SetUpSuite(c *C) { - server.EnableZap = true dashboard.SetCheckInterval(10 * time.Millisecond) s.ctx, s.cancel = context.WithCancel(context.Background()) s.httpClient = &http.Client{ @@ -71,7 +69,6 @@ func (s *dashboardTestSuite) SetUpSuite(c *C) { func (s *dashboardTestSuite) TearDownSuite(c *C) { s.cancel() s.httpClient.CloseIdleConnections() - server.EnableZap = false dashboard.SetCheckInterval(time.Second) } diff --git a/tests/pdbackup/backup_test.go b/tests/pdbackup/backup_test.go index 49e348699b7..a36cf89f44d 100644 --- a/tests/pdbackup/backup_test.go +++ b/tests/pdbackup/backup_test.go @@ -23,7 +23,6 @@ import ( "time" . 
"github.com/pingcap/check" - "github.com/tikv/pd/server" "github.com/tikv/pd/tests" "github.com/tikv/pd/tools/pd-backup/pdbackup" "go.etcd.io/etcd/clientv3" @@ -37,10 +36,6 @@ var _ = Suite(&backupTestSuite{}) type backupTestSuite struct{} -func (s *backupTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *backupTestSuite) TestBackup(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/cluster/cluster_test.go b/tests/pdctl/cluster/cluster_test.go index 4f1b67b63e7..4b8cceb3bc5 100644 --- a/tests/pdctl/cluster/cluster_test.go +++ b/tests/pdctl/cluster/cluster_test.go @@ -23,7 +23,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/tikv/pd/server" clusterpkg "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/tests" "github.com/tikv/pd/tests/pdctl" @@ -38,10 +37,6 @@ var _ = Suite(&clusterTestSuite{}) type clusterTestSuite struct{} -func (s *clusterTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *clusterTestSuite) TestClusterAndPing(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index 297cc538606..ad99d583133 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -28,7 +28,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/tikv/pd/pkg/typeutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/schedule/placement" "github.com/tikv/pd/tests" @@ -44,10 +43,6 @@ var _ = Suite(&configTestSuite{}) type configTestSuite struct{} -func (s *configTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - type testItem struct { name string value interface{} diff --git a/tests/pdctl/global_test.go b/tests/pdctl/global_test.go index bb14eeafac2..de165eea600 100644 --- a/tests/pdctl/global_test.go +++ b/tests/pdctl/global_test.go @@ -37,10 +37,6 @@ var _ = Suite(&globalTestSuite{}) type globalTestSuite struct{} -func (s *globalTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *globalTestSuite) TestSendAndGetComponent(c *C) { handler := func(ctx context.Context, s *server.Server) (http.Handler, server.ServiceGroup, error) { mux := http.NewServeMux() diff --git a/tests/pdctl/health/health_test.go b/tests/pdctl/health/health_test.go index ecc9b6deb2f..06e287dcb36 100644 --- a/tests/pdctl/health/health_test.go +++ b/tests/pdctl/health/health_test.go @@ -20,7 +20,6 @@ import ( "testing" . "github.com/pingcap/check" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/tests" @@ -36,10 +35,6 @@ var _ = Suite(&healthTestSuite{}) type healthTestSuite struct{} -func (s *healthTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *healthTestSuite) TestHealth(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/hot/hot_test.go b/tests/pdctl/hot/hot_test.go index de40564c2d9..06a657df7d7 100644 --- a/tests/pdctl/hot/hot_test.go +++ b/tests/pdctl/hot/hot_test.go @@ -25,7 +25,6 @@ import ( . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" @@ -44,10 +43,6 @@ var _ = Suite(&hotTestSuite{}) type hotTestSuite struct{} -func (s *hotTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *hotTestSuite) TestHot(c *C) { statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) diff --git a/tests/pdctl/label/label_test.go b/tests/pdctl/label/label_test.go index c3cd0b105ee..50a52413e82 100644 --- a/tests/pdctl/label/label_test.go +++ b/tests/pdctl/label/label_test.go @@ -23,7 +23,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" @@ -39,10 +38,6 @@ var _ = Suite(&labelTestSuite{}) type labelTestSuite struct{} -func (s *labelTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *labelTestSuite) TestLabel(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/log/log_test.go b/tests/pdctl/log/log_test.go index 7103b842530..6499b2694c7 100644 --- a/tests/pdctl/log/log_test.go +++ b/tests/pdctl/log/log_test.go @@ -42,7 +42,6 @@ type logTestSuite struct { } func (s *logTestSuite) SetUpSuite(c *C) { - server.EnableZap = true s.ctx, s.cancel = context.WithCancel(context.Background()) var err error s.cluster, err = tests.NewTestCluster(s.ctx, 3) diff --git a/tests/pdctl/member/member_test.go b/tests/pdctl/member/member_test.go index 1dac74ce6a4..f85f2d946df 100644 --- a/tests/pdctl/member/member_test.go +++ b/tests/pdctl/member/member_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/tests" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" @@ -39,10 +38,6 @@ var _ = Suite(&memberTestSuite{}) type memberTestSuite struct{} -func (s *memberTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *memberTestSuite) TestMember(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index d6eec639770..73ae2687c80 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -23,7 +23,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/tests" @@ -39,10 +38,6 @@ var _ = Suite(&operatorTestSuite{}) type operatorTestSuite struct{} -func (s *operatorTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *operatorTestSuite) TestOperator(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/region/region_test.go b/tests/pdctl/region/region_test.go index f6142878268..dd83accea55 100644 --- a/tests/pdctl/region/region_test.go +++ b/tests/pdctl/region/region_test.go @@ -24,7 +24,6 @@ import ( . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" "github.com/tikv/pd/tests" @@ -40,10 +39,6 @@ var _ = Suite(®ionTestSuite{}) type regionTestSuite struct{} -func (s *regionTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *regionTestSuite) TestRegionKeyFormat(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index de0e523dc8e..53ed808f410 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -23,7 +23,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/versioninfo" "github.com/tikv/pd/tests" @@ -43,7 +42,6 @@ type schedulerTestSuite struct { } func (s *schedulerTestSuite) SetUpSuite(c *C) { - server.EnableZap = true s.context, s.cancel = context.WithCancel(context.Background()) } diff --git a/tests/pdctl/store/store_test.go b/tests/pdctl/store/store_test.go index 84a9a4e383d..a43d70722e8 100644 --- a/tests/pdctl/store/store_test.go +++ b/tests/pdctl/store/store_test.go @@ -23,7 +23,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/core/storelimit" @@ -41,10 +40,6 @@ var _ = Suite(&storeTestSuite{}) type storeTestSuite struct{} -func (s *storeTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *storeTestSuite) TestStore(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/tso/tso_test.go b/tests/pdctl/tso/tso_test.go index 0f67e0ff7cf..1d2cdb77dc0 100644 --- a/tests/pdctl/tso/tso_test.go +++ b/tests/pdctl/tso/tso_test.go @@ -21,7 +21,6 @@ import ( "time" . 
"github.com/pingcap/check" - "github.com/tikv/pd/server" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) @@ -34,10 +33,6 @@ var _ = Suite(&tsoTestSuite{}) type tsoTestSuite struct{} -func (s *tsoTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *tsoTestSuite) TestTSO(c *C) { cmd := pdctlCmd.GetRootCmd() diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 68af7235316..9e6248f9cec 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -133,7 +133,6 @@ type testMiddlewareSuite struct { func (s *testMiddlewareSuite) SetUpSuite(c *C) { c.Assert(failpoint.Enable("github.com/tikv/pd/server/api/enableFailpointAPI", "return(true)"), IsNil) ctx, cancel := context.WithCancel(context.Background()) - server.EnableZap = true s.cleanup = cancel cluster, err := tests.NewTestCluster(ctx, 3) c.Assert(err, IsNil) @@ -200,7 +199,6 @@ func (s *testMiddlewareSuite) TestRequestInfoMiddleware(c *C) { func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { b.StopTimer() ctx, cancel := context.WithCancel(context.Background()) - server.EnableZap = true cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() @@ -223,7 +221,6 @@ func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { b.StopTimer() ctx, cancel := context.WithCancel(context.Background()) - server.EnableZap = true cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() @@ -348,7 +345,6 @@ func (s *testMiddlewareSuite) TestAuditLocalLogBackend(c *C) { func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) { b.StopTimer() ctx, cancel := context.WithCancel(context.Background()) - server.EnableZap = true cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() @@ -371,7 +367,6 @@ func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) { func BenchmarkDoRequestWithoutLocalLogAudit(b *testing.B) { b.StopTimer() ctx, cancel := context.WithCancel(context.Background()) - server.EnableZap = true cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() @@ -400,7 +395,6 @@ type testRedirectorSuite struct { func (s *testRedirectorSuite) SetUpSuite(c *C) { ctx, cancel := context.WithCancel(context.Background()) - server.EnableZap = true s.cleanup = cancel cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, serverName string) { conf.TickInterval = typeutil.Duration{Duration: 50 * time.Millisecond} diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index bbb10c525aa..f493ed21b00 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -68,7 +68,6 @@ type clusterTestSuite struct { func (s *clusterTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true // to prevent GetStorage dashboard.SetCheckInterval(30 * time.Minute) } diff --git a/tests/server/cluster/cluster_work_test.go b/tests/server/cluster/cluster_work_test.go index a9d83ca5fcf..b3d9fdcf9e0 100644 --- a/tests/server/cluster/cluster_work_test.go +++ b/tests/server/cluster/cluster_work_test.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/core" "github.com/tikv/pd/tests" ) @@ -37,7 +36,6 @@ type 
clusterWorkerTestSuite struct { func (s *clusterWorkerTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *clusterWorkerTestSuite) TearDownSuite(c *C) { diff --git a/tests/server/global_config/global_config_test.go b/tests/server/global_config/global_config_test.go index cc4b73b8a56..87dc62e35a2 100644 --- a/tests/server/global_config/global_config_test.go +++ b/tests/server/global_config/global_config_test.go @@ -68,7 +68,6 @@ func (s *GlobalConfigTestSuite) SetUpSuite(c *C) { gsi, s.cleanup, err = server.NewTestServer(assertutil.NewChecker(func() {})) s.server = &server.GrpcServer{Server: gsi} c.Assert(err, IsNil) - server.EnableZap = true addr := s.server.GetAddr() s.client, err = grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) c.Assert(err, IsNil) diff --git a/tests/server/id/id_test.go b/tests/server/id/id_test.go index 1fb3563d039..b624ceb056f 100644 --- a/tests/server/id/id_test.go +++ b/tests/server/id/id_test.go @@ -22,7 +22,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/tests" "go.uber.org/goleak" ) @@ -46,7 +45,6 @@ type testAllocIDSuite struct { func (s *testAllocIDSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *testAllocIDSuite) TearDownSuite(c *C) { diff --git a/tests/server/join/join_fail/join_fail_test.go b/tests/server/join/join_fail/join_fail_test.go index 8fed271ea95..bc4e98abdce 100644 --- a/tests/server/join/join_fail/join_fail_test.go +++ b/tests/server/join/join_fail/join_fail_test.go @@ -22,7 +22,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/tests" "go.uber.org/goleak" ) @@ -39,10 +38,6 @@ var _ = Suite(&joinTestSuite{}) type joinTestSuite struct{} -func (s *joinTestSuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *joinTestSuite) TestFailedPDJoinInStep1(c *C) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/server/join/join_test.go b/tests/server/join/join_test.go index 520e5c817ad..8cc9cdcdb34 100644 --- a/tests/server/join/join_test.go +++ b/tests/server/join/join_test.go @@ -46,7 +46,6 @@ type joinTestSuite struct { func (s *joinTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true server.EtcdStartTimeout = 10 * time.Second } diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 2eb2ef19ddd..0215d95e5ff 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -63,7 +63,6 @@ type memberTestSuite struct { func (s *memberTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *memberTestSuite) TearDownSuite(c *C) { diff --git a/tests/server/region_syncer/region_syncer_test.go b/tests/server/region_syncer/region_syncer_test.go index dd7dc1fb5f1..c4c91806c9f 100644 --- a/tests/server/region_syncer/region_syncer_test.go +++ b/tests/server/region_syncer/region_syncer_test.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" 
"github.com/tikv/pd/tests" @@ -48,7 +47,6 @@ type regionSyncerTestSuite struct { func (s *regionSyncerTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *regionSyncerTestSuite) TearDownSuite(c *C) { diff --git a/tests/server/server_test.go b/tests/server/server_test.go index 56b357f7fcd..f75ef4e15f0 100644 --- a/tests/server/server_test.go +++ b/tests/server/server_test.go @@ -21,7 +21,6 @@ import ( . "github.com/pingcap/check" "github.com/tikv/pd/pkg/tempurl" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" "go.uber.org/goleak" @@ -47,7 +46,6 @@ type serverTestSuite struct { func (s *serverTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *serverTestSuite) TearDownSuite(c *C) { diff --git a/tests/server/storage/hot_region_storage_test.go b/tests/server/storage/hot_region_storage_test.go index 78fdb5632c6..5a11f8c23c4 100644 --- a/tests/server/storage/hot_region_storage_test.go +++ b/tests/server/storage/hot_region_storage_test.go @@ -22,7 +22,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/statistics" @@ -39,10 +38,6 @@ var _ = Suite(&hotRegionHistorySuite{}) type hotRegionHistorySuite struct{} -func (s *hotRegionHistorySuite) SetUpSuite(c *C) { - server.EnableZap = true -} - func (s *hotRegionHistorySuite) TestHotRegionStorage(c *C) { statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) diff --git a/tests/server/tso/allocator_test.go b/tests/server/tso/allocator_test.go index b3b2002f2b0..c7bb38e5d9a 100644 --- a/tests/server/tso/allocator_test.go +++ b/tests/server/tso/allocator_test.go @@ -28,7 +28,6 @@ import ( "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/slice" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/tso" "github.com/tikv/pd/tests" @@ -43,7 +42,6 @@ type testAllocatorSuite struct { func (s *testAllocatorSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *testAllocatorSuite) TearDownSuite(c *C) { diff --git a/tests/server/tso/consistency_test.go b/tests/server/tso/consistency_test.go index 974c0bb71c2..170a1b4e9a8 100644 --- a/tests/server/tso/consistency_test.go +++ b/tests/server/tso/consistency_test.go @@ -28,7 +28,6 @@ import ( "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/pkg/tsoutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/tso" "github.com/tikv/pd/tests" @@ -51,7 +50,6 @@ func (s *testTSOConsistencySuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) s.dcClientMap = make(map[string]pdpb.PDClient) s.tsPool = make(map[uint64]struct{}) - server.EnableZap = true } func (s *testTSOConsistencySuite) TearDownSuite(c *C) { diff --git a/tests/server/tso/global_tso_test.go b/tests/server/tso/global_tso_test.go index e67ed1d3798..1086751fa08 100644 --- a/tests/server/tso/global_tso_test.go +++ b/tests/server/tso/global_tso_test.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/tikv/pd/pkg/grpcutil" 
"github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/tso" "github.com/tikv/pd/tests" ) @@ -51,7 +50,6 @@ type testNormalGlobalTSOSuite struct { func (s *testNormalGlobalTSOSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *testNormalGlobalTSOSuite) TearDownSuite(c *C) { diff --git a/tests/server/tso/manager_test.go b/tests/server/tso/manager_test.go index 68edc073445..26fa07cc1d5 100644 --- a/tests/server/tso/manager_test.go +++ b/tests/server/tso/manager_test.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/failpoint" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/tso" "github.com/tikv/pd/tests" @@ -42,7 +41,6 @@ type testManagerSuite struct { func (s *testManagerSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *testManagerSuite) TearDownSuite(c *C) { diff --git a/tests/server/tso/tso_test.go b/tests/server/tso/tso_test.go index b46a7b902aa..27bc53d5652 100644 --- a/tests/server/tso/tso_test.go +++ b/tests/server/tso/tso_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" ) @@ -39,7 +38,6 @@ type testTSOSuite struct { func (s *testTSOSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *testTSOSuite) TearDownSuite(c *C) { diff --git a/tests/server/watch/leader_watch_test.go b/tests/server/watch/leader_watch_test.go index 160e089ecb7..88d1470d733 100644 --- a/tests/server/watch/leader_watch_test.go +++ b/tests/server/watch/leader_watch_test.go @@ -22,7 +22,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/tikv/pd/pkg/testutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" "go.uber.org/goleak" @@ -45,7 +44,6 @@ type watchTestSuite struct { func (s *watchTestSuite) SetUpSuite(c *C) { s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true } func (s *watchTestSuite) TearDownSuite(c *C) { From f82e15ed161acf1955d55fa0c9738a81b60bed74 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 9 Jun 2022 15:06:31 +0800 Subject: [PATCH 34/82] tests: testify the autoscaling/compatibility/dashboard/pdbackup tests (#5120) ref tikv/pd#4813 Testify the autoscaling/compatibility/dashboard/pdbackup tests. Signed-off-by: JmPotato --- tests/autoscaling/autoscaling_test.go | 23 ++---- tests/compatibility/version_upgrade_test.go | 71 +++++++--------- tests/dashboard/race_test.go | 40 +++++---- tests/dashboard/service_test.go | 91 ++++++++++----------- tests/pdbackup/backup_test.go | 31 +++---- 5 files changed, 113 insertions(+), 143 deletions(-) diff --git a/tests/autoscaling/autoscaling_test.go b/tests/autoscaling/autoscaling_test.go index 62fcf9ee886..ce60f648136 100644 --- a/tests/autoscaling/autoscaling_test.go +++ b/tests/autoscaling/autoscaling_test.go @@ -20,37 +20,30 @@ import ( "net/http" "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/tests" "go.uber.org/goleak" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&apiTestSuite{}) - -type apiTestSuite struct{} - -func (s *apiTestSuite) TestAPI(c *C) { +func TestAPI(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) var jsonStr = []byte(` { @@ -102,7 +95,7 @@ func (s *apiTestSuite) TestAPI(c *C) { ] }`) resp, err := http.Post(leaderServer.GetAddr()+"/autoscaling", "application/json", bytes.NewBuffer(jsonStr)) - c.Assert(err, IsNil) + re.NoError(err) defer resp.Body.Close() - c.Assert(resp.StatusCode, Equals, 200) + re.Equal(200, resp.StatusCode) } diff --git a/tests/compatibility/version_upgrade_test.go b/tests/compatibility/version_upgrade_test.go index 2fcdc1bf5d5..e600a848b3a 100644 --- a/tests/compatibility/version_upgrade_test.go +++ b/tests/compatibility/version_upgrade_test.go @@ -19,42 +19,26 @@ import ( "testing" "github.com/coreos/go-semver/semver" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server" "github.com/tikv/pd/tests" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&compatibilityTestSuite{}) - -type compatibilityTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *compatibilityTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *compatibilityTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *compatibilityTestSuite) TestStoreRegister(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestStoreRegister(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) putStoreRequest := &pdpb.PutStoreRequest{ Header: &pdpb.RequestHeader{ClusterId: leaderServer.GetClusterID()}, @@ -67,21 +51,21 @@ func (s *compatibilityTestSuite) TestStoreRegister(c *C) { svr := &server.GrpcServer{Server: leaderServer.GetServer()} _, err = svr.PutStore(context.Background(), putStoreRequest) - c.Assert(err, IsNil) + re.NoError(err) // FIX ME: read v0.0.0 in sometime cluster.WaitLeader() version := leaderServer.GetClusterVersion() // Restart all PDs. 
err = cluster.StopAll() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer = cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer, NotNil) + re.NotNil(leaderServer) newVersion := leaderServer.GetClusterVersion() - c.Assert(version, Equals, newVersion) + re.Equal(version, newVersion) // putNewStore with old version putStoreRequest = &pdpb.PutStoreRequest{ @@ -93,18 +77,21 @@ func (s *compatibilityTestSuite) TestStoreRegister(c *C) { }, } _, err = svr.PutStore(context.Background(), putStoreRequest) - c.Assert(err, NotNil) + re.Error(err) } -func (s *compatibilityTestSuite) TestRollingUpgrade(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestRollingUpgrade(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) stores := []*pdpb.PutStoreRequest{ { @@ -144,9 +131,9 @@ func (s *compatibilityTestSuite) TestRollingUpgrade(c *C) { svr := &server.GrpcServer{Server: leaderServer.GetServer()} for _, store := range stores { _, err = svr.PutStore(context.Background(), store) - c.Assert(err, IsNil) + re.NoError(err) } - c.Assert(leaderServer.GetClusterVersion(), Equals, semver.Version{Major: 2, Minor: 0, Patch: 1}) + re.Equal(semver.Version{Major: 2, Minor: 0, Patch: 1}, leaderServer.GetClusterVersion()) // rolling update for i, store := range stores { if i == 0 { @@ -155,11 +142,11 @@ func (s *compatibilityTestSuite) TestRollingUpgrade(c *C) { } store.Store.Version = "2.1.0" resp, err := svr.PutStore(context.Background(), store) - c.Assert(err, IsNil) + re.NoError(err) if i != len(stores)-1 { - c.Assert(leaderServer.GetClusterVersion(), Equals, semver.Version{Major: 2, Minor: 0, Patch: 1}) - c.Assert(resp.GetHeader().GetError(), IsNil) + re.Equal(semver.Version{Major: 2, Minor: 0, Patch: 1}, leaderServer.GetClusterVersion()) + re.Nil(resp.GetHeader().GetError()) } } - c.Assert(leaderServer.GetClusterVersion(), Equals, semver.Version{Major: 2, Minor: 1}) + re.Equal(semver.Version{Major: 2, Minor: 1}, leaderServer.GetClusterVersion()) } diff --git a/tests/dashboard/race_test.go b/tests/dashboard/race_test.go index 9ca82567b52..4bb31f55ecc 100644 --- a/tests/dashboard/race_test.go +++ b/tests/dashboard/race_test.go @@ -16,10 +16,10 @@ package dashboard_test import ( "context" + "testing" "time" - . 
"github.com/pingcap/check" - + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/dashboard" "github.com/tikv/pd/tests" @@ -27,33 +27,31 @@ import ( _ "github.com/tikv/pd/server/schedulers" ) -var _ = Suite(&raceTestSuite{}) +func TestCancelDuringStarting(t *testing.T) { + prepareTestConfig() + defer resetTestConfig() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -type raceTestSuite struct{} + re := require.New(t) + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) + defer cluster.Destroy() + re.NoError(cluster.RunInitialServers()) + cluster.WaitLeader() -func (s *raceTestSuite) SetUpSuite(c *C) { + time.Sleep(60 * time.Millisecond) + cancel() +} + +func prepareTestConfig() { dashboard.SetCheckInterval(50 * time.Millisecond) tests.WaitLeaderReturnDelay = 0 tests.WaitLeaderCheckInterval = 20 * time.Millisecond } -func (s *raceTestSuite) TearDownSuite(c *C) { +func resetTestConfig() { dashboard.SetCheckInterval(time.Second) tests.WaitLeaderReturnDelay = 20 * time.Millisecond tests.WaitLeaderCheckInterval = 500 * time.Millisecond } - -func (s *raceTestSuite) TestCancelDuringStarting(c *C) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) - defer cluster.Destroy() - err = cluster.RunInitialServers() - c.Assert(err, IsNil) - cluster.WaitLeader() - - time.Sleep(60 * time.Millisecond) - cancel() -} diff --git a/tests/dashboard/service_test.go b/tests/dashboard/service_test.go index f11f39e466a..12bc4496d6c 100644 --- a/tests/dashboard/service_test.go +++ b/tests/dashboard/service_test.go @@ -22,7 +22,7 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/suite" "go.uber.org/goleak" "github.com/tikv/pd/pkg/dashboard" @@ -36,26 +36,25 @@ import ( _ "github.com/tikv/pd/server/schedulers" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&dashboardTestSuite{}) - type dashboardTestSuite struct { + suite.Suite ctx context.Context cancel context.CancelFunc httpClient *http.Client } -func (s *dashboardTestSuite) SetUpSuite(c *C) { +func TestDashboardTestSuite(t *testing.T) { + suite.Run(t, new(dashboardTestSuite)) +} + +func (suite *dashboardTestSuite) SetupSuite() { dashboard.SetCheckInterval(10 * time.Millisecond) - s.ctx, s.cancel = context.WithCancel(context.Background()) - s.httpClient = &http.Client{ + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.httpClient = &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { // ErrUseLastResponse can be returned by Client.CheckRedirect hooks to // control how redirects are processed. 
If returned, the next request @@ -66,73 +65,73 @@ func (s *dashboardTestSuite) SetUpSuite(c *C) { } } -func (s *dashboardTestSuite) TearDownSuite(c *C) { - s.cancel() - s.httpClient.CloseIdleConnections() +func (suite *dashboardTestSuite) TearDownSuite() { + suite.cancel() + suite.httpClient.CloseIdleConnections() dashboard.SetCheckInterval(time.Second) } -func (s *dashboardTestSuite) TestDashboardRedirect(c *C) { - s.testDashboard(c, false) +func (suite *dashboardTestSuite) TestDashboardRedirect() { + suite.testDashboard(false) } -func (s *dashboardTestSuite) TestDashboardProxy(c *C) { - s.testDashboard(c, true) +func (suite *dashboardTestSuite) TestDashboardProxy() { + suite.testDashboard(true) } -func (s *dashboardTestSuite) checkRespCode(c *C, url string, code int) { - resp, err := s.httpClient.Get(url) - c.Assert(err, IsNil) +func (suite *dashboardTestSuite) checkRespCode(url string, code int) { + resp, err := suite.httpClient.Get(url) + suite.NoError(err) _, err = io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() - c.Assert(resp.StatusCode, Equals, code) + suite.Equal(code, resp.StatusCode) } -func (s *dashboardTestSuite) waitForConfigSync() { +func waitForConfigSync() { time.Sleep(time.Second) } -func (s *dashboardTestSuite) checkServiceIsStarted(c *C, internalProxy bool, servers map[string]*tests.TestServer, leader *tests.TestServer) string { - s.waitForConfigSync() +func (suite *dashboardTestSuite) checkServiceIsStarted(internalProxy bool, servers map[string]*tests.TestServer, leader *tests.TestServer) string { + waitForConfigSync() dashboardAddress := leader.GetServer().GetPersistOptions().GetDashboardAddress() hasServiceNode := false for _, srv := range servers { - c.Assert(srv.GetPersistOptions().GetDashboardAddress(), Equals, dashboardAddress) + suite.Equal(dashboardAddress, srv.GetPersistOptions().GetDashboardAddress()) addr := srv.GetAddr() if addr == dashboardAddress || internalProxy { - s.checkRespCode(c, fmt.Sprintf("%s/dashboard/", addr), http.StatusOK) - s.checkRespCode(c, fmt.Sprintf("%s/dashboard/api/keyvisual/heatmaps", addr), http.StatusUnauthorized) + suite.checkRespCode(fmt.Sprintf("%s/dashboard/", addr), http.StatusOK) + suite.checkRespCode(fmt.Sprintf("%s/dashboard/api/keyvisual/heatmaps", addr), http.StatusUnauthorized) if addr == dashboardAddress { hasServiceNode = true } } else { - s.checkRespCode(c, fmt.Sprintf("%s/dashboard/", addr), http.StatusTemporaryRedirect) - s.checkRespCode(c, fmt.Sprintf("%s/dashboard/api/keyvisual/heatmaps", addr), http.StatusTemporaryRedirect) + suite.checkRespCode(fmt.Sprintf("%s/dashboard/", addr), http.StatusTemporaryRedirect) + suite.checkRespCode(fmt.Sprintf("%s/dashboard/api/keyvisual/heatmaps", addr), http.StatusTemporaryRedirect) } } - c.Assert(hasServiceNode, IsTrue) + suite.True(hasServiceNode) return dashboardAddress } -func (s *dashboardTestSuite) checkServiceIsStopped(c *C, servers map[string]*tests.TestServer) { - s.waitForConfigSync() +func (suite *dashboardTestSuite) checkServiceIsStopped(servers map[string]*tests.TestServer) { + waitForConfigSync() for _, srv := range servers { - c.Assert(srv.GetPersistOptions().GetDashboardAddress(), Equals, "none") + suite.Equal("none", srv.GetPersistOptions().GetDashboardAddress()) addr := srv.GetAddr() - s.checkRespCode(c, fmt.Sprintf("%s/dashboard/", addr), http.StatusNotFound) - s.checkRespCode(c, fmt.Sprintf("%s/dashboard/api/keyvisual/heatmaps", addr), http.StatusNotFound) + suite.checkRespCode(fmt.Sprintf("%s/dashboard/", addr), 
http.StatusNotFound) + suite.checkRespCode(fmt.Sprintf("%s/dashboard/api/keyvisual/heatmaps", addr), http.StatusNotFound) } } -func (s *dashboardTestSuite) testDashboard(c *C, internalProxy bool) { - cluster, err := tests.NewTestCluster(s.ctx, 3, func(conf *config.Config, serverName string) { +func (suite *dashboardTestSuite) testDashboard(internalProxy bool) { + cluster, err := tests.NewTestCluster(suite.ctx, 3, func(conf *config.Config, serverName string) { conf.Dashboard.InternalProxy = internalProxy }) - c.Assert(err, IsNil) + suite.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + suite.NoError(err) cmd := pdctlCmd.GetRootCmd() @@ -142,7 +141,7 @@ func (s *dashboardTestSuite) testDashboard(c *C, internalProxy bool) { leaderAddr := leader.GetAddr() // auto select node - dashboardAddress1 := s.checkServiceIsStarted(c, internalProxy, servers, leader) + dashboardAddress1 := suite.checkServiceIsStarted(internalProxy, servers, leader) // pd-ctl set another addr var dashboardAddress2 string @@ -154,13 +153,13 @@ func (s *dashboardTestSuite) testDashboard(c *C, internalProxy bool) { } args := []string{"-u", leaderAddr, "config", "set", "dashboard-address", dashboardAddress2} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - s.checkServiceIsStarted(c, internalProxy, servers, leader) - c.Assert(leader.GetServer().GetPersistOptions().GetDashboardAddress(), Equals, dashboardAddress2) + suite.NoError(err) + suite.checkServiceIsStarted(internalProxy, servers, leader) + suite.Equal(dashboardAddress2, leader.GetServer().GetPersistOptions().GetDashboardAddress()) // pd-ctl set stop args = []string{"-u", leaderAddr, "config", "set", "dashboard-address", "none"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - s.checkServiceIsStopped(c, servers) + suite.NoError(err) + suite.checkServiceIsStopped(servers) } diff --git a/tests/pdbackup/backup_test.go b/tests/pdbackup/backup_test.go index a36cf89f44d..b5034742f69 100644 --- a/tests/pdbackup/backup_test.go +++ b/tests/pdbackup/backup_test.go @@ -22,27 +22,20 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/tests" "github.com/tikv/pd/tools/pd-backup/pdbackup" "go.etcd.io/etcd/clientv3" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&backupTestSuite{}) - -type backupTestSuite struct{} - -func (s *backupTestSuite) TestBackup(c *C) { +func TestBackup(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() urls := strings.Split(pdAddr, ",") @@ -52,18 +45,18 @@ func (s *backupTestSuite) TestBackup(c *C) { DialTimeout: 3 * time.Second, TLS: nil, }) - c.Assert(err, IsNil) + re.NoError(err) backupInfo, err := pdbackup.GetBackupInfo(client, pdAddr) - c.Assert(err, IsNil) - c.Assert(backupInfo, NotNil) + re.NoError(err) + re.NotNil(backupInfo) backBytes, err := json.Marshal(backupInfo) - c.Assert(err, IsNil) + re.NoError(err) var formatBuffer bytes.Buffer err = json.Indent(&formatBuffer, backBytes, "", " ") - c.Assert(err, IsNil) + re.NoError(err) newInfo := &pdbackup.BackupInfo{} err = json.Unmarshal(formatBuffer.Bytes(), newInfo) - c.Assert(err, IsNil) - c.Assert(backupInfo, DeepEquals, newInfo) + re.NoError(err) + re.Equal(newInfo, backupInfo) } From ae157b5eb647a6b1adb7138c1ae486bd3952eb1c Mon Sep 17 00:00:00 2001 From: LLThomas Date: Fri, 10 Jun 2022 11:22:30 +0800 Subject: [PATCH 35/82] storage: migrate test framework to testify (#5139) ref tikv/pd#4813 As the title says. Signed-off-by: LLThomas --- server/storage/hot_region_storage_test.go | 32 ++-- server/storage/storage_gc_test.go | 102 ++++++------ server/storage/storage_test.go | 192 +++++++++++----------- 3 files changed, 159 insertions(+), 167 deletions(-) diff --git a/server/storage/hot_region_storage_test.go b/server/storage/hot_region_storage_test.go index 7447c47f8d8..29dc4140317 100644 --- a/server/storage/hot_region_storage_test.go +++ b/server/storage/hot_region_storage_test.go @@ -21,11 +21,10 @@ import ( "math/rand" "os" "path/filepath" - "reflect" "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" ) @@ -104,21 +103,11 @@ func (m *MockPackHotRegionInfo) ClearHotRegion() { m.historyHotWrites = make([]HistoryHotRegion, 0) } -var _ = SerialSuites(&testHotRegionStorage{}) - -type testHotRegionStorage struct { - ctx context.Context - cancel context.CancelFunc -} - -func (t *testHotRegionStorage) SetUpSuite(c *C) { - t.ctx, t.cancel = context.WithCancel(context.Background()) -} - -func (t *testHotRegionStorage) TestHotRegionWrite(c *C) { +func TestHotRegionWrite(t *testing.T) { + re := require.New(t) packHotRegionInfo := &MockPackHotRegionInfo{} store, clean, err := newTestHotRegionStorage(10*time.Minute, 1, packHotRegionInfo) - c.Assert(err, IsNil) + re.NoError(err) defer clean() now := time.Now() hotRegionStorages := []HistoryHotRegion{ @@ -172,20 +161,21 @@ func (t *testHotRegionStorage) TestHotRegionWrite(c *C) { for next, err := iter.Next(); next != nil && err == nil; next, err = iter.Next() { copyHotRegionStorages[index].StartKey = core.HexRegionKeyStr([]byte(copyHotRegionStorages[index].StartKey)) copyHotRegionStorages[index].EndKey = core.HexRegionKeyStr([]byte(copyHotRegionStorages[index].EndKey)) - c.Assert(reflect.DeepEqual(©HotRegionStorages[index], next), IsTrue) + re.Equal(©HotRegionStorages[index], next) index++ } - c.Assert(err, IsNil) - c.Assert(index, Equals, 3) + re.NoError(err) + re.Equal(3, index) } -func (t *testHotRegionStorage) TestHotRegionDelete(c *C) { +func TestHotRegionDelete(t *testing.T) { + re := require.New(t) defaultRemainDay := 7 defaultDelteData := 30 deleteDate := time.Now().AddDate(0, 0, 0) packHotRegionInfo := &MockPackHotRegionInfo{} store, clean, err := newTestHotRegionStorage(10*time.Minute, uint64(defaultRemainDay), packHotRegionInfo) - c.Assert(err, IsNil) + re.NoError(err) defer clean() historyHotRegions := make([]HistoryHotRegion, 0) for i := 0; i < defaultDelteData; i++ { @@ -207,7 +197,7 @@ func (t *testHotRegionStorage) TestHotRegionDelete(c *C) { num := 0 for next, err := iter.Next(); next != nil && err == nil; next, err = iter.Next() { num++ - c.Assert(reflect.DeepEqual(next, &historyHotRegions[defaultRemainDay-num]), IsTrue) + re.Equal(&historyHotRegions[defaultRemainDay-num], next) } } diff --git a/server/storage/storage_gc_test.go b/server/storage/storage_gc_test.go index eb0e51c79d0..371dc5759f9 100644 --- a/server/storage/storage_gc_test.go +++ b/server/storage/storage_gc_test.go @@ -16,18 +16,14 @@ package storage import ( "math" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/storage/endpoint" ) -var _ = Suite(&testStorageGCSuite{}) - -type testStorageGCSuite struct { -} - func testGCSafePoints() ([]string, []uint64) { spaceIDs := []string{ "keySpace1", @@ -73,20 +69,22 @@ func testServiceSafePoints() ([]string, []*endpoint.ServiceSafePoint) { return spaceIDs, serviceSafePoints } -func (s *testStorageGCSuite) TestSaveLoadServiceSafePoint(c *C) { +func TestSaveLoadServiceSafePoint(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() testSpaceID, testSafePoints := testServiceSafePoints() for i := range testSpaceID { - c.Assert(storage.SaveServiceSafePoint(testSpaceID[i], testSafePoints[i]), IsNil) + re.NoError(storage.SaveServiceSafePoint(testSpaceID[i], testSafePoints[i])) } for i := range testSpaceID { loadedSafePoint, err := storage.LoadServiceSafePoint(testSpaceID[i], testSafePoints[i].ServiceID) - c.Assert(err, IsNil) - c.Assert(loadedSafePoint, DeepEquals, testSafePoints[i]) + re.NoError(err) + re.Equal(testSafePoints[i], loadedSafePoint) } } -func (s *testStorageGCSuite) TestLoadMinServiceSafePoint(c *C) { +func TestLoadMinServiceSafePoint(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() currentTime := time.Now() expireAt1 := currentTime.Add(100 * time.Second).Unix() @@ -101,116 +99,120 @@ func (s *testStorageGCSuite) TestLoadMinServiceSafePoint(c *C) { testKeySpace := "test" for _, serviceSafePoint := range serviceSafePoints { - c.Assert(storage.SaveServiceSafePoint(testKeySpace, serviceSafePoint), IsNil) + re.NoError(storage.SaveServiceSafePoint(testKeySpace, serviceSafePoint)) } // enabling failpoint to make expired key removal immediately observable - c.Assert(failpoint.Enable("github.com/tikv/pd/server/storage/endpoint/removeExpiredKeys", "return(true)"), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/storage/endpoint/removeExpiredKeys", "return(true)")) minSafePoint, err := storage.LoadMinServiceSafePoint(testKeySpace, currentTime) - c.Assert(err, IsNil) - c.Assert(minSafePoint, DeepEquals, serviceSafePoints[0]) + re.NoError(err) + re.Equal(serviceSafePoints[0], minSafePoint) // the safePoint with ServiceID 0 should be removed due to expiration minSafePoint2, err := storage.LoadMinServiceSafePoint(testKeySpace, currentTime.Add(150*time.Second)) - c.Assert(err, IsNil) - c.Assert(minSafePoint2, DeepEquals, serviceSafePoints[1]) + re.NoError(err) + re.Equal(serviceSafePoints[1], minSafePoint2) // verify that service safe point with ServiceID 0 has been removed ssp, err := storage.LoadServiceSafePoint(testKeySpace, "0") - c.Assert(err, IsNil) - c.Assert(ssp, IsNil) + re.NoError(err) + re.Nil(ssp) // all remaining service safePoints should be removed due to expiration ssp, err = storage.LoadMinServiceSafePoint(testKeySpace, currentTime.Add(500*time.Second)) - c.Assert(err, IsNil) - c.Assert(ssp, IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/storage/endpoint/removeExpiredKeys"), IsNil) + re.NoError(err) + re.Nil(ssp) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/storage/endpoint/removeExpiredKeys")) } -func (s *testStorageGCSuite) TestRemoveServiceSafePoint(c *C) { +func TestRemoveServiceSafePoint(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() testSpaceID, testSafePoints := testServiceSafePoints() // save service safe points for i := range testSpaceID { - 
c.Assert(storage.SaveServiceSafePoint(testSpaceID[i], testSafePoints[i]), IsNil) + re.NoError(storage.SaveServiceSafePoint(testSpaceID[i], testSafePoints[i])) } // remove saved service safe points for i := range testSpaceID { - c.Assert(storage.RemoveServiceSafePoint(testSpaceID[i], testSafePoints[i].ServiceID), IsNil) + re.NoError(storage.RemoveServiceSafePoint(testSpaceID[i], testSafePoints[i].ServiceID)) } // check that service safe points are empty for i := range testSpaceID { loadedSafePoint, err := storage.LoadServiceSafePoint(testSpaceID[i], testSafePoints[i].ServiceID) - c.Assert(err, IsNil) - c.Assert(loadedSafePoint, IsNil) + re.NoError(err) + re.Nil(loadedSafePoint) } } -func (s *testStorageGCSuite) TestSaveLoadGCSafePoint(c *C) { +func TestSaveLoadGCSafePoint(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() testSpaceIDs, testSafePoints := testGCSafePoints() for i := range testSpaceIDs { testSpaceID := testSpaceIDs[i] testSafePoint := testSafePoints[i] err := storage.SaveKeySpaceGCSafePoint(testSpaceID, testSafePoint) - c.Assert(err, IsNil) + re.NoError(err) loaded, err := storage.LoadKeySpaceGCSafePoint(testSpaceID) - c.Assert(err, IsNil) - c.Assert(loaded, Equals, testSafePoint) + re.NoError(err) + re.Equal(testSafePoint, loaded) } } -func (s *testStorageGCSuite) TestLoadAllKeySpaceGCSafePoints(c *C) { +func TestLoadAllKeySpaceGCSafePoints(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() testSpaceIDs, testSafePoints := testGCSafePoints() for i := range testSpaceIDs { err := storage.SaveKeySpaceGCSafePoint(testSpaceIDs[i], testSafePoints[i]) - c.Assert(err, IsNil) + re.NoError(err) } loadedSafePoints, err := storage.LoadAllKeySpaceGCSafePoints(true) - c.Assert(err, IsNil) + re.NoError(err) for i := range loadedSafePoints { - c.Assert(loadedSafePoints[i].SpaceID, Equals, testSpaceIDs[i]) - c.Assert(loadedSafePoints[i].SafePoint, Equals, testSafePoints[i]) + re.Equal(testSpaceIDs[i], loadedSafePoints[i].SpaceID) + re.Equal(testSafePoints[i], loadedSafePoints[i].SafePoint) } // saving some service safe points. spaceIDs, safePoints := testServiceSafePoints() for i := range spaceIDs { - c.Assert(storage.SaveServiceSafePoint(spaceIDs[i], safePoints[i]), IsNil) + re.NoError(storage.SaveServiceSafePoint(spaceIDs[i], safePoints[i])) } // verify that service safe points do not interfere with gc safe points. 
loadedSafePoints, err = storage.LoadAllKeySpaceGCSafePoints(true) - c.Assert(err, IsNil) + re.NoError(err) for i := range loadedSafePoints { - c.Assert(loadedSafePoints[i].SpaceID, Equals, testSpaceIDs[i]) - c.Assert(loadedSafePoints[i].SafePoint, Equals, testSafePoints[i]) + re.Equal(testSpaceIDs[i], loadedSafePoints[i].SpaceID) + re.Equal(testSafePoints[i], loadedSafePoints[i].SafePoint) } // verify that when withGCSafePoint set to false, returned safePoints is 0 loadedSafePoints, err = storage.LoadAllKeySpaceGCSafePoints(false) - c.Assert(err, IsNil) + re.NoError(err) for i := range loadedSafePoints { - c.Assert(loadedSafePoints[i].SpaceID, Equals, testSpaceIDs[i]) - c.Assert(loadedSafePoints[i].SafePoint, Equals, uint64(0)) + re.Equal(testSpaceIDs[i], loadedSafePoints[i].SpaceID) + re.Equal(uint64(0), loadedSafePoints[i].SafePoint) } } -func (s *testStorageGCSuite) TestLoadEmpty(c *C) { +func TestLoadEmpty(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() // loading non-existing GC safepoint should return 0 gcSafePoint, err := storage.LoadKeySpaceGCSafePoint("testKeySpace") - c.Assert(err, IsNil) - c.Assert(gcSafePoint, Equals, uint64(0)) + re.NoError(err) + re.Equal(uint64(0), gcSafePoint) // loading non-existing service safepoint should return nil serviceSafePoint, err := storage.LoadServiceSafePoint("testKeySpace", "testService") - c.Assert(err, IsNil) - c.Assert(serviceSafePoint, IsNil) + re.NoError(err) + re.Nil(serviceSafePoint) // loading empty key spaces should return empty slices safePoints, err := storage.LoadAllKeySpaceGCSafePoints(true) - c.Assert(err, IsNil) - c.Assert(safePoints, HasLen, 0) + re.NoError(err) + re.Len(safePoints, 0) } diff --git a/server/storage/storage_test.go b/server/storage/storage_test.go index 51870a62133..3310c10de74 100644 --- a/server/storage/storage_test.go +++ b/server/storage/storage_test.go @@ -23,69 +23,61 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/storage/endpoint" "go.etcd.io/etcd/clientv3" ) -func TestStorage(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testStorageSuite{}) - -type testStorageSuite struct { -} - -func (s *testStorageSuite) TestBasic(c *C) { +func TestBasic(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() - c.Assert(endpoint.StorePath(123), Equals, "raft/s/00000000000000000123") - c.Assert(endpoint.RegionPath(123), Equals, "raft/r/00000000000000000123") + re.Equal("raft/s/00000000000000000123", endpoint.StorePath(123)) + re.Equal("raft/r/00000000000000000123", endpoint.RegionPath(123)) meta := &metapb.Cluster{Id: 123} ok, err := storage.LoadMeta(meta) - c.Assert(ok, IsFalse) - c.Assert(err, IsNil) - c.Assert(storage.SaveMeta(meta), IsNil) + re.False(ok) + re.NoError(err) + re.NoError(storage.SaveMeta(meta)) newMeta := &metapb.Cluster{} ok, err = storage.LoadMeta(newMeta) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(newMeta, DeepEquals, meta) + re.True(ok) + re.NoError(err) + re.Equal(meta, newMeta) store := &metapb.Store{Id: 123} ok, err = storage.LoadStore(123, store) - c.Assert(ok, IsFalse) - c.Assert(err, IsNil) - c.Assert(storage.SaveStore(store), IsNil) + re.False(ok) + re.NoError(err) + re.NoError(storage.SaveStore(store)) newStore := &metapb.Store{} ok, err = storage.LoadStore(123, newStore) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(newStore, DeepEquals, store) + re.True(ok) + re.NoError(err) + re.Equal(store, newStore) region := &metapb.Region{Id: 123} ok, err = storage.LoadRegion(123, region) - c.Assert(ok, IsFalse) - c.Assert(err, IsNil) - c.Assert(storage.SaveRegion(region), IsNil) + re.False(ok) + re.NoError(err) + re.NoError(storage.SaveRegion(region)) newRegion := &metapb.Region{} ok, err = storage.LoadRegion(123, newRegion) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(newRegion, DeepEquals, region) + re.True(ok) + re.NoError(err) + re.Equal(region, newRegion) err = storage.DeleteRegion(region) - c.Assert(err, IsNil) + re.NoError(err) ok, err = storage.LoadRegion(123, newRegion) - c.Assert(ok, IsFalse) - c.Assert(err, IsNil) + re.False(ok) + re.NoError(err) } -func mustSaveStores(c *C, s Storage, n int) []*metapb.Store { +func mustSaveStores(re *require.Assertions, s Storage, n int) []*metapb.Store { stores := make([]*metapb.Store, 0, n) for i := 0; i < n; i++ { store := &metapb.Store{Id: uint64(i)} @@ -93,60 +85,64 @@ func mustSaveStores(c *C, s Storage, n int) []*metapb.Store { } for _, store := range stores { - c.Assert(s.SaveStore(store), IsNil) + re.NoError(s.SaveStore(store)) } return stores } -func (s *testStorageSuite) TestLoadStores(c *C) { +func TestLoadStores(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() cache := core.NewStoresInfo() n := 10 - stores := mustSaveStores(c, storage, n) - c.Assert(storage.LoadStores(cache.SetStore), IsNil) + stores := mustSaveStores(re, storage, n) + re.NoError(storage.LoadStores(cache.SetStore)) - c.Assert(cache.GetStoreCount(), Equals, n) + re.Equal(n, cache.GetStoreCount()) for _, store := range cache.GetMetaStores() { - c.Assert(store, DeepEquals, stores[store.GetId()]) + re.Equal(stores[store.GetId()], store) } } -func (s *testStorageSuite) TestStoreWeight(c *C) { +func TestStoreWeight(t *testing.T) { + re := require.New(t) storage := 
NewStorageWithMemoryBackend() cache := core.NewStoresInfo() const n = 3 - mustSaveStores(c, storage, n) - c.Assert(storage.SaveStoreWeight(1, 2.0, 3.0), IsNil) - c.Assert(storage.SaveStoreWeight(2, 0.2, 0.3), IsNil) - c.Assert(storage.LoadStores(cache.SetStore), IsNil) + mustSaveStores(re, storage, n) + re.NoError(storage.SaveStoreWeight(1, 2.0, 3.0)) + re.NoError(storage.SaveStoreWeight(2, 0.2, 0.3)) + re.NoError(storage.LoadStores(cache.SetStore)) leaderWeights := []float64{1.0, 2.0, 0.2} regionWeights := []float64{1.0, 3.0, 0.3} for i := 0; i < n; i++ { - c.Assert(cache.GetStore(uint64(i)).GetLeaderWeight(), Equals, leaderWeights[i]) - c.Assert(cache.GetStore(uint64(i)).GetRegionWeight(), Equals, regionWeights[i]) + re.Equal(leaderWeights[i], cache.GetStore(uint64(i)).GetLeaderWeight()) + re.Equal(regionWeights[i], cache.GetStore(uint64(i)).GetRegionWeight()) } } -func (s *testStorageSuite) TestLoadGCSafePoint(c *C) { +func TestLoadGCSafePoint(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() testData := []uint64{0, 1, 2, 233, 2333, 23333333333, math.MaxUint64} r, e := storage.LoadGCSafePoint() - c.Assert(r, Equals, uint64(0)) - c.Assert(e, IsNil) + re.Equal(uint64(0), r) + re.NoError(e) for _, safePoint := range testData { err := storage.SaveGCSafePoint(safePoint) - c.Assert(err, IsNil) + re.NoError(err) safePoint1, err := storage.LoadGCSafePoint() - c.Assert(err, IsNil) - c.Assert(safePoint, Equals, safePoint1) + re.NoError(err) + re.Equal(safePoint1, safePoint) } } -func (s *testStorageSuite) TestSaveServiceGCSafePoint(c *C) { +func TestSaveServiceGCSafePoint(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() expireAt := time.Now().Add(100 * time.Second).Unix() serviceSafePoints := []*endpoint.ServiceSafePoint{ @@ -156,28 +152,29 @@ func (s *testStorageSuite) TestSaveServiceGCSafePoint(c *C) { } for _, ssp := range serviceSafePoints { - c.Assert(storage.SaveServiceGCSafePoint(ssp), IsNil) + re.NoError(storage.SaveServiceGCSafePoint(ssp)) } prefix := endpoint.GCSafePointServicePrefixPath() prefixEnd := clientv3.GetPrefixRangeEnd(prefix) keys, values, err := storage.LoadRange(prefix, prefixEnd, len(serviceSafePoints)) - c.Assert(err, IsNil) - c.Assert(keys, HasLen, 3) - c.Assert(values, HasLen, 3) + re.NoError(err) + re.Len(keys, 3) + re.Len(values, 3) ssp := &endpoint.ServiceSafePoint{} for i, key := range keys { - c.Assert(strings.HasSuffix(key, serviceSafePoints[i].ServiceID), IsTrue) + re.True(strings.HasSuffix(key, serviceSafePoints[i].ServiceID)) - c.Assert(json.Unmarshal([]byte(values[i]), ssp), IsNil) - c.Assert(ssp.ServiceID, Equals, serviceSafePoints[i].ServiceID) - c.Assert(ssp.ExpiredAt, Equals, serviceSafePoints[i].ExpiredAt) - c.Assert(ssp.SafePoint, Equals, serviceSafePoints[i].SafePoint) + re.NoError(json.Unmarshal([]byte(values[i]), ssp)) + re.Equal(serviceSafePoints[i].ServiceID, ssp.ServiceID) + re.Equal(serviceSafePoints[i].ExpiredAt, ssp.ExpiredAt) + re.Equal(serviceSafePoints[i].SafePoint, ssp.SafePoint) } } -func (s *testStorageSuite) TestLoadMinServiceGCSafePoint(c *C) { +func TestLoadMinServiceGCSafePoint(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() expireAt := time.Now().Add(1000 * time.Second).Unix() serviceSafePoints := []*endpoint.ServiceSafePoint{ @@ -187,44 +184,45 @@ func (s *testStorageSuite) TestLoadMinServiceGCSafePoint(c *C) { } for _, ssp := range serviceSafePoints { - c.Assert(storage.SaveServiceGCSafePoint(ssp), IsNil) + 
re.NoError(storage.SaveServiceGCSafePoint(ssp)) } // gc_worker's safepoint will be automatically inserted when loading service safepoints. Here the returned // safepoint can be either of "gc_worker" or "2". ssp, err := storage.LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(ssp.SafePoint, Equals, uint64(2)) + re.NoError(err) + re.Equal(uint64(2), ssp.SafePoint) // Advance gc_worker's safepoint - c.Assert(storage.SaveServiceGCSafePoint(&endpoint.ServiceSafePoint{ + re.NoError(storage.SaveServiceGCSafePoint(&endpoint.ServiceSafePoint{ ServiceID: "gc_worker", ExpiredAt: math.MaxInt64, SafePoint: 10, - }), IsNil) + })) ssp, err = storage.LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(ssp.ServiceID, Equals, "2") - c.Assert(ssp.ExpiredAt, Equals, expireAt) - c.Assert(ssp.SafePoint, Equals, uint64(2)) + re.NoError(err) + re.Equal("2", ssp.ServiceID) + re.Equal(expireAt, ssp.ExpiredAt) + re.Equal(uint64(2), ssp.SafePoint) } -func (s *testStorageSuite) TestLoadRegions(c *C) { +func TestLoadRegions(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() cache := core.NewRegionsInfo() n := 10 - regions := mustSaveRegions(c, storage, n) - c.Assert(storage.LoadRegions(context.Background(), cache.SetRegion), IsNil) + regions := mustSaveRegions(re, storage, n) + re.NoError(storage.LoadRegions(context.Background(), cache.SetRegion)) - c.Assert(cache.GetRegionCount(), Equals, n) + re.Equal(n, cache.GetRegionCount()) for _, region := range cache.GetMetaRegions() { - c.Assert(region, DeepEquals, regions[region.GetId()]) + re.Equal(regions[region.GetId()], region) } } -func mustSaveRegions(c *C, s endpoint.RegionStorage, n int) []*metapb.Region { +func mustSaveRegions(re *require.Assertions, s endpoint.RegionStorage, n int) []*metapb.Region { regions := make([]*metapb.Region, 0, n) for i := 0; i < n; i++ { region := newTestRegionMeta(uint64(i)) @@ -232,7 +230,7 @@ func mustSaveRegions(c *C, s endpoint.RegionStorage, n int) []*metapb.Region { } for _, region := range regions { - c.Assert(s.SaveRegion(region), IsNil) + re.NoError(s.SaveRegion(region)) } return regions @@ -246,36 +244,38 @@ func newTestRegionMeta(regionID uint64) *metapb.Region { } } -func (s *testStorageSuite) TestLoadRegionsToCache(c *C) { +func TestLoadRegionsToCache(t *testing.T) { + re := require.New(t) storage := NewStorageWithMemoryBackend() cache := core.NewRegionsInfo() n := 10 - regions := mustSaveRegions(c, storage, n) - c.Assert(storage.LoadRegionsOnce(context.Background(), cache.SetRegion), IsNil) + regions := mustSaveRegions(re, storage, n) + re.NoError(storage.LoadRegionsOnce(context.Background(), cache.SetRegion)) - c.Assert(cache.GetRegionCount(), Equals, n) + re.Equal(n, cache.GetRegionCount()) for _, region := range cache.GetMetaRegions() { - c.Assert(region, DeepEquals, regions[region.GetId()]) + re.Equal(regions[region.GetId()], region) } n = 20 - mustSaveRegions(c, storage, n) - c.Assert(storage.LoadRegionsOnce(context.Background(), cache.SetRegion), IsNil) - c.Assert(cache.GetRegionCount(), Equals, n) + mustSaveRegions(re, storage, n) + re.NoError(storage.LoadRegionsOnce(context.Background(), cache.SetRegion)) + re.Equal(n, cache.GetRegionCount()) } -func (s *testStorageSuite) TestLoadRegionsExceedRangeLimit(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/storage/kv/withRangeLimit", "return(500)"), IsNil) +func TestLoadRegionsExceedRangeLimit(t *testing.T) { + re := require.New(t) + 
re.NoError(failpoint.Enable("github.com/tikv/pd/server/storage/kv/withRangeLimit", "return(500)")) storage := NewStorageWithMemoryBackend() cache := core.NewRegionsInfo() n := 1000 - regions := mustSaveRegions(c, storage, n) - c.Assert(storage.LoadRegions(context.Background(), cache.SetRegion), IsNil) - c.Assert(cache.GetRegionCount(), Equals, n) + regions := mustSaveRegions(re, storage, n) + re.NoError(storage.LoadRegions(context.Background(), cache.SetRegion)) + re.Equal(n, cache.GetRegionCount()) for _, region := range cache.GetMetaRegions() { - c.Assert(region, DeepEquals, regions[region.GetId()]) + re.Equal(regions[region.GetId()], region) } - c.Assert(failpoint.Disable("github.com/tikv/pd/server/storage/kv/withRangeLimit"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/storage/kv/withRangeLimit")) } From 6807f6e401b8170491dff61a8cda1b336739115e Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Fri, 10 Jun 2022 14:46:30 +0800 Subject: [PATCH 36/82] syncer: migrate test framework to testify (#5141) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/region_syncer/client_test.go | 22 ++++---- server/region_syncer/history_buffer_test.go | 59 +++++++++------------ 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/server/region_syncer/client_test.go b/server/region_syncer/client_test.go index 1f10af778ba..ca39cee4859 100644 --- a/server/region_syncer/client_test.go +++ b/server/region_syncer/client_test.go @@ -17,28 +17,26 @@ package syncer import ( "context" "os" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/storage" ) -var _ = Suite(&testClientSuite{}) - -type testClientSuite struct{} - // For issue https://github.com/tikv/pd/issues/3936 -func (t *testClientSuite) TestLoadRegion(c *C) { +func TestLoadRegion(t *testing.T) { + re := require.New(t) tempDir, err := os.MkdirTemp(os.TempDir(), "region_syncer_load_region") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(tempDir) rs, err := storage.NewStorageWithLevelDBBackend(context.Background(), tempDir, nil) - c.Assert(err, IsNil) + re.NoError(err) server := &mockServer{ ctx: context.Background(), @@ -48,9 +46,9 @@ func (t *testClientSuite) TestLoadRegion(c *C) { for i := 0; i < 30; i++ { rs.SaveRegion(&metapb.Region{Id: uint64(i) + 1}) } - c.Assert(failpoint.Enable("github.com/tikv/pd/server/storage/base_backend/slowLoadRegion", "return(true)"), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/storage/base_backend/slowLoadRegion", "return(true)")) defer func() { - c.Assert(failpoint.Disable("github.com/tikv/pd/server/storage/base_backend/slowLoadRegion"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/storage/base_backend/slowLoadRegion")) }() rc := NewRegionSyncer(server) @@ -58,8 +56,8 @@ func (t *testClientSuite) TestLoadRegion(c *C) { rc.StartSyncWithLeader("") time.Sleep(time.Second) rc.StopSyncWithLeader() - c.Assert(time.Since(start), Greater, time.Second) // make sure failpoint is injected - c.Assert(time.Since(start), Less, time.Second*2) + re.Greater(time.Since(start), time.Second) // make sure failpoint is injected + re.Less(time.Since(start), time.Second*2) } type mockServer struct { diff --git a/server/region_syncer/history_buffer_test.go 
b/server/region_syncer/history_buffer_test.go index 47fa6b66f8f..49cbebdf266 100644 --- a/server/region_syncer/history_buffer_test.go +++ b/server/region_syncer/history_buffer_test.go @@ -17,21 +17,14 @@ package syncer import ( "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/storage/kv" ) -var _ = Suite(&testHistoryBuffer{}) - -type testHistoryBuffer struct{} - -func Test(t *testing.T) { - TestingT(t) -} - -func (t *testHistoryBuffer) TestBufferSize(c *C) { +func TestBufferSize(t *testing.T) { + re := require.New(t) var regions []*core.RegionInfo for i := 0; i <= 100; i++ { regions = append(regions, core.NewRegionInfo(&metapb.Region{Id: uint64(i)}, nil)) @@ -39,23 +32,23 @@ func (t *testHistoryBuffer) TestBufferSize(c *C) { // size equals 1 h := newHistoryBuffer(1, kv.NewMemoryKV()) - c.Assert(h.len(), Equals, 0) + re.Equal(0, h.len()) for _, r := range regions { h.Record(r) } - c.Assert(h.len(), Equals, 1) - c.Assert(h.get(100), Equals, regions[h.nextIndex()-1]) - c.Assert(h.get(99), IsNil) + re.Equal(1, h.len()) + re.Equal(regions[h.nextIndex()-1], h.get(100)) + re.Nil(h.get(99)) // size equals 2 h = newHistoryBuffer(2, kv.NewMemoryKV()) for _, r := range regions { h.Record(r) } - c.Assert(h.len(), Equals, 2) - c.Assert(h.get(100), Equals, regions[h.nextIndex()-1]) - c.Assert(h.get(99), Equals, regions[h.nextIndex()-2]) - c.Assert(h.get(98), IsNil) + re.Equal(2, h.len()) + re.Equal(regions[h.nextIndex()-1], h.get(100)) + re.Equal(regions[h.nextIndex()-2], h.get(99)) + re.Nil(h.get(98)) // size equals 100 kvMem := kv.NewMemoryKV() @@ -63,33 +56,33 @@ func (t *testHistoryBuffer) TestBufferSize(c *C) { for i := 0; i < 6; i++ { h1.Record(regions[i]) } - c.Assert(h1.len(), Equals, 6) - c.Assert(h1.nextIndex(), Equals, uint64(6)) + re.Equal(6, h1.len()) + re.Equal(uint64(6), h1.nextIndex()) h1.persist() // restart the buffer h2 := newHistoryBuffer(100, kvMem) - c.Assert(h2.nextIndex(), Equals, uint64(6)) - c.Assert(h2.firstIndex(), Equals, uint64(6)) - c.Assert(h2.get(h.nextIndex()-1), IsNil) - c.Assert(h2.len(), Equals, 0) + re.Equal(uint64(6), h2.nextIndex()) + re.Equal(uint64(6), h2.firstIndex()) + re.Nil(h2.get(h.nextIndex() - 1)) + re.Equal(0, h2.len()) for _, r := range regions { index := h2.nextIndex() h2.Record(r) - c.Assert(h2.get(index), Equals, r) + re.Equal(r, h2.get(index)) } - c.Assert(h2.nextIndex(), Equals, uint64(107)) - c.Assert(h2.get(h2.nextIndex()), IsNil) + re.Equal(uint64(107), h2.nextIndex()) + re.Nil(h2.get(h2.nextIndex())) s, err := h2.kv.Load(historyKey) - c.Assert(err, IsNil) + re.NoError(err) // flush in index 106 - c.Assert(s, Equals, "106") + re.Equal("106", s) histories := h2.RecordsFrom(uint64(1)) - c.Assert(histories, HasLen, 0) + re.Len(histories, 0) histories = h2.RecordsFrom(h2.firstIndex()) - c.Assert(histories, HasLen, 100) - c.Assert(h2.firstIndex(), Equals, uint64(7)) - c.Assert(histories, DeepEquals, regions[1:]) + re.Len(histories, 100) + re.Equal(uint64(7), h2.firstIndex()) + re.Equal(regions[1:], histories) } From 0015a5b51efab6f17d2235b11e39290c5075e164 Mon Sep 17 00:00:00 2001 From: disksing Date: Sat, 11 Jun 2022 00:22:31 +0800 Subject: [PATCH 37/82] dr-autosync: cleanup configurations (#5106) ref tikv/pd#4399 Signed-off-by: disksing --- server/api/admin.go | 22 ----- server/api/router.go | 1 - server/config/config.go | 29 +++--- server/config/config_test.go | 1 - server/replication/replication_mode.go | 44 
++------- server/replication/replication_mode_test.go | 102 +++++++++----------- tests/pdctl/config/config_test.go | 5 +- 7 files changed, 66 insertions(+), 138 deletions(-) diff --git a/server/api/admin.go b/server/api/admin.go index 36419234fc9..2954874d7fd 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -116,25 +116,3 @@ func (h *adminHandler) SavePersistFile(w http.ResponseWriter, r *http.Request) { } h.rd.Text(w, http.StatusOK, "") } - -// Intentionally no swagger mark as it is supposed to be only used in -// server-to-server. -func (h *adminHandler) UpdateWaitAsyncTime(w http.ResponseWriter, r *http.Request) { - var input map[string]interface{} - if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { - return - } - memberIDValue, ok := input["member_id"].(string) - if !ok || len(memberIDValue) == 0 { - h.rd.JSON(w, http.StatusBadRequest, "invalid member id") - return - } - memberID, err := strconv.ParseUint(memberIDValue, 10, 64) - if err != nil { - h.rd.JSON(w, http.StatusBadRequest, "invalid member id") - return - } - cluster := getCluster(r) - cluster.GetReplicationMode().UpdateMemberWaitAsyncTime(memberID) - h.rd.JSON(w, http.StatusOK, nil) -} diff --git a/server/api/router.go b/server/api/router.go index 98be613a97e..e755341ebef 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -281,7 +281,6 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(clusterRouter, "/admin/cache/region/{id}", adminHandler.DeleteRegionCache, setMethods("DELETE"), setAuditBackend(localLog)) registerFunc(clusterRouter, "/admin/reset-ts", adminHandler.ResetTS, setMethods("POST"), setAuditBackend(localLog)) registerFunc(apiRouter, "/admin/persist-file/{file_name}", adminHandler.SavePersistFile, setMethods("POST"), setAuditBackend(localLog)) - registerFunc(clusterRouter, "/admin/replication_mode/wait-async", adminHandler.UpdateWaitAsyncTime, setMethods("POST"), setAuditBackend(localLog)) serviceMiddlewareHandler := newServiceMiddlewareHandler(svr, rd) registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.GetServiceMiddlewareConfig, setMethods("GET")) diff --git a/server/config/config.go b/server/config/config.go index 384c7108816..df833594f74 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -245,9 +245,8 @@ const ( defaultDashboardAddress = "auto" - defaultDRWaitStoreTimeout = time.Minute - defaultDRWaitSyncTimeout = time.Minute - defaultDRWaitAsyncTimeout = 2 * time.Minute + defaultDRWaitStoreTimeout = time.Minute + defaultDRTiKVSyncTimeoutHint = time.Minute defaultTSOSaveInterval = time.Duration(defaultLeaderLease) * time.Second // DefaultTSOUpdatePhysicalInterval is the default value of the config `TSOUpdatePhysicalInterval`. @@ -1389,26 +1388,22 @@ func NormalizeReplicationMode(m string) string { // DRAutoSyncReplicationConfig is the configuration for auto sync mode between 2 data centers. 
type DRAutoSyncReplicationConfig struct { - LabelKey string `toml:"label-key" json:"label-key"` - Primary string `toml:"primary" json:"primary"` - DR string `toml:"dr" json:"dr"` - PrimaryReplicas int `toml:"primary-replicas" json:"primary-replicas"` - DRReplicas int `toml:"dr-replicas" json:"dr-replicas"` - WaitStoreTimeout typeutil.Duration `toml:"wait-store-timeout" json:"wait-store-timeout"` - WaitSyncTimeout typeutil.Duration `toml:"wait-sync-timeout" json:"wait-sync-timeout"` - WaitAsyncTimeout typeutil.Duration `toml:"wait-async-timeout" json:"wait-async-timeout"` - PauseRegionSplit bool `toml:"pause-region-split" json:"pause-region-split,string"` + LabelKey string `toml:"label-key" json:"label-key"` + Primary string `toml:"primary" json:"primary"` + DR string `toml:"dr" json:"dr"` + PrimaryReplicas int `toml:"primary-replicas" json:"primary-replicas"` + DRReplicas int `toml:"dr-replicas" json:"dr-replicas"` + WaitStoreTimeout typeutil.Duration `toml:"wait-store-timeout" json:"wait-store-timeout"` + TiKVSyncTimeoutHint typeutil.Duration `toml:"tikv-sync-timeout-hint" json:"tikv-sync-timeout-hint"` + PauseRegionSplit bool `toml:"pause-region-split" json:"pause-region-split,string"` } func (c *DRAutoSyncReplicationConfig) adjust(meta *configMetaData) { if !meta.IsDefined("wait-store-timeout") { c.WaitStoreTimeout = typeutil.NewDuration(defaultDRWaitStoreTimeout) } - if !meta.IsDefined("wait-sync-timeout") { - c.WaitSyncTimeout = typeutil.NewDuration(defaultDRWaitSyncTimeout) - } - if !meta.IsDefined("wait-async-timeout") { - c.WaitAsyncTimeout = typeutil.NewDuration(defaultDRWaitAsyncTimeout) + if !meta.IsDefined("tikv-sync-timeout-hint") { + c.TiKVSyncTimeoutHint = typeutil.NewDuration(defaultDRTiKVSyncTimeoutHint) } } diff --git a/server/config/config_test.go b/server/config/config_test.go index 885e24d8d8b..032c0526739 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -454,7 +454,6 @@ wait-store-timeout = "120s" re.Equal(2, cfg.ReplicationMode.DRAutoSync.PrimaryReplicas) re.Equal(1, cfg.ReplicationMode.DRAutoSync.DRReplicas) re.Equal(2*time.Minute, cfg.ReplicationMode.DRAutoSync.WaitStoreTimeout.Duration) - re.Equal(time.Minute, cfg.ReplicationMode.DRAutoSync.WaitSyncTimeout.Duration) cfg = NewConfig() meta, err = toml.Decode("", &cfg) diff --git a/server/replication/replication_mode.go b/server/replication/replication_mode.go index c5fa078794c..bfd944eff92 100644 --- a/server/replication/replication_mode.go +++ b/server/replication/replication_mode.go @@ -86,19 +86,17 @@ type ModeManager struct { drSampleTotalRegion int // number of regions in sample drTotalRegion int // number of all regions - drMemberWaitAsyncTime map[uint64]time.Time // last sync time with follower nodes - drStoreStatus sync.Map + drStoreStatus sync.Map } // NewReplicationModeManager creates the replicate mode manager. 
func NewReplicationModeManager(config config.ReplicationModeConfig, storage endpoint.ReplicationStatusStorage, cluster schedule.Cluster, fileReplicater FileReplicater) (*ModeManager, error) { m := &ModeManager{ - initTime: time.Now(), - config: config, - storage: storage, - cluster: cluster, - fileReplicater: fileReplicater, - drMemberWaitAsyncTime: make(map[uint64]time.Time), + initTime: time.Now(), + config: config, + storage: storage, + cluster: cluster, + fileReplicater: fileReplicater, } switch config.ReplicationMode { case modeMajority: @@ -129,15 +127,6 @@ func (m *ModeManager) UpdateConfig(config config.ReplicationModeConfig) error { return nil } -// UpdateMemberWaitAsyncTime updates a member's wait async time. -func (m *ModeManager) UpdateMemberWaitAsyncTime(memberID uint64) { - m.Lock() - defer m.Unlock() - t := time.Now() - log.Info("udpate member wait async time", zap.Uint64("memberID", memberID), zap.Time("time", t)) - m.drMemberWaitAsyncTime[memberID] = t -} - // GetReplicationStatus returns the status to sync with tikv servers. func (m *ModeManager) GetReplicationStatus() *pb.ReplicationStatus { m.RLock() @@ -153,7 +142,7 @@ func (m *ModeManager) GetReplicationStatus() *pb.ReplicationStatus { LabelKey: m.config.DRAutoSync.LabelKey, State: pb.DRAutoSyncState(pb.DRAutoSyncState_value[strings.ToUpper(m.drAutoSync.State)]), StateId: m.drAutoSync.StateID, - WaitSyncTimeoutHint: int32(m.config.DRAutoSync.WaitSyncTimeout.Seconds()), + WaitSyncTimeoutHint: int32(m.config.DRAutoSync.TiKVSyncTimeoutHint.Seconds()), AvailableStores: m.drAutoSync.AvailableStores, PauseRegionSplit: m.config.DRAutoSync.PauseRegionSplit && m.drAutoSync.State != drStateSync, } @@ -239,23 +228,6 @@ func (m *ModeManager) loadDRAutoSync() error { return nil } -func (m *ModeManager) drCheckAsyncTimeout() bool { - m.RLock() - defer m.RUnlock() - timeout := m.config.DRAutoSync.WaitAsyncTimeout.Duration - if timeout == 0 { - return true - } - // make sure all members are timeout. - for _, t := range m.drMemberWaitAsyncTime { - if time.Since(t) <= timeout { - return false - } - } - // make sure all members that have synced with previous leader are timeout. - return time.Since(m.initTime) > timeout -} - func (m *ModeManager) drSwitchToAsyncWait(availableStores []uint64) error { m.Lock() defer m.Unlock() @@ -471,7 +443,7 @@ func (m *ModeManager) tickDR() { switch m.drGetState() { case drStateSync: // If hasMajority is false, the cluster is always unavailable. Switch to async won't help. 
- if !canSync && hasMajority && m.drCheckAsyncTimeout() { + if !canSync && hasMajority { m.drSwitchToAsyncWait(stores[primaryUp]) } case drStateAsyncWait: diff --git a/server/replication/replication_mode_test.go b/server/replication/replication_mode_test.go index 8162da599ff..ee478d4ce6f 100644 --- a/server/replication/replication_mode_test.go +++ b/server/replication/replication_mode_test.go @@ -43,13 +43,13 @@ func TestInitial(t *testing.T) { re.Equal(&pb.ReplicationStatus{Mode: pb.ReplicationMode_MAJORITY}, rep.GetReplicationStatus()) conf = config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ - LabelKey: "dr-label", - Primary: "l1", - DR: "l2", - PrimaryReplicas: 2, - DRReplicas: 1, - WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, - WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, + LabelKey: "dr-label", + Primary: "l1", + DR: "l2", + PrimaryReplicas: 2, + DRReplicas: 1, + WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, + TiKVSyncTimeoutHint: typeutil.Duration{Duration: time.Minute}, }} rep, err = NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) re.NoError(err) @@ -70,8 +70,8 @@ func TestStatus(t *testing.T) { defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ - LabelKey: "dr-label", - WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, + LabelKey: "dr-label", + TiKVSyncTimeoutHint: typeutil.Duration{Duration: time.Minute}, }} cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) @@ -165,13 +165,13 @@ func TestStateSwitch(t *testing.T) { defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ - LabelKey: "zone", - Primary: "zone1", - DR: "zone2", - PrimaryReplicas: 4, - DRReplicas: 1, - WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, - WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, + LabelKey: "zone", + Primary: "zone1", + DR: "zone2", + PrimaryReplicas: 4, + DRReplicas: 1, + WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, + TiKVSyncTimeoutHint: typeutil.Duration{Duration: time.Minute}, }} cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) replicator := newMockReplicator([]uint64{1}) @@ -352,13 +352,13 @@ func TestReplicateState(t *testing.T) { defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ - LabelKey: "zone", - Primary: "zone1", - DR: "zone2", - PrimaryReplicas: 2, - DRReplicas: 1, - WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, - WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, + LabelKey: "zone", + Primary: "zone1", + DR: "zone2", + PrimaryReplicas: 2, + DRReplicas: 1, + WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, + TiKVSyncTimeoutHint: typeutil.Duration{Duration: time.Minute}, }} cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) replicator := newMockReplicator([]uint64{1}) @@ -395,14 +395,13 @@ func TestAsynctimeout(t *testing.T) { defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: 
config.DRAutoSyncReplicationConfig{ - LabelKey: "zone", - Primary: "zone1", - DR: "zone2", - PrimaryReplicas: 2, - DRReplicas: 1, - WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, - WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, - WaitAsyncTimeout: typeutil.Duration{Duration: 2 * time.Minute}, + LabelKey: "zone", + Primary: "zone1", + DR: "zone2", + PrimaryReplicas: 2, + DRReplicas: 1, + WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, + TiKVSyncTimeoutHint: typeutil.Duration{Duration: time.Minute}, }} cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) var replicator mockFileReplicator @@ -415,19 +414,6 @@ func TestAsynctimeout(t *testing.T) { setStoreState(cluster, "up", "up", "down") rep.tickDR() - re.Equal(drStateSync, rep.drGetState()) // cannot switch state due to recently start - - rep.initTime = time.Now().Add(-3 * time.Minute) - rep.tickDR() - re.Equal(drStateAsyncWait, rep.drGetState()) - - rep.drSwitchToSync() - rep.UpdateMemberWaitAsyncTime(42) - rep.tickDR() - re.Equal(drStateSync, rep.drGetState()) // cannot switch state due to member not timeout - - rep.drMemberWaitAsyncTime[42] = time.Now().Add(-3 * time.Minute) - rep.tickDR() re.Equal(drStateAsyncWait, rep.drGetState()) } @@ -453,13 +439,13 @@ func TestRecoverProgress(t *testing.T) { store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ - LabelKey: "zone", - Primary: "zone1", - DR: "zone2", - PrimaryReplicas: 2, - DRReplicas: 1, - WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, - WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, + LabelKey: "zone", + Primary: "zone1", + DR: "zone2", + PrimaryReplicas: 2, + DRReplicas: 1, + WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, + TiKVSyncTimeoutHint: typeutil.Duration{Duration: time.Minute}, }} cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddLabelsStore(1, 1, map[string]string{}) @@ -516,13 +502,13 @@ func TestRecoverProgressWithSplitAndMerge(t *testing.T) { store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ - LabelKey: "zone", - Primary: "zone1", - DR: "zone2", - PrimaryReplicas: 2, - DRReplicas: 1, - WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, - WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, + LabelKey: "zone", + Primary: "zone1", + DR: "zone2", + PrimaryReplicas: 2, + DRReplicas: 1, + WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, + TiKVSyncTimeoutHint: typeutil.Duration{Duration: time.Minute}, }} cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddLabelsStore(1, 1, map[string]string{}) diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index ad99d583133..cb564699b53 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -566,9 +566,8 @@ func (s *configTestSuite) TestReplicationMode(c *C) { conf := config.ReplicationModeConfig{ ReplicationMode: "majority", DRAutoSync: config.DRAutoSyncReplicationConfig{ - WaitStoreTimeout: typeutil.NewDuration(time.Minute), - WaitSyncTimeout: typeutil.NewDuration(time.Minute), - WaitAsyncTimeout: typeutil.NewDuration(2 * time.Minute), + WaitStoreTimeout: typeutil.NewDuration(time.Minute), + TiKVSyncTimeoutHint: typeutil.NewDuration(time.Minute), }, } check := func() { From 
c8775b6176a03eb25cb1622fc46ba1d496872df7 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Sat, 11 Jun 2022 16:02:31 +0800 Subject: [PATCH 38/82] encryptionkm: migrate test framework to testify (#5136) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/encryptionkm/key_manager_test.go | 576 ++++++++++++------------ 1 file changed, 292 insertions(+), 284 deletions(-) diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index e9fc39e5789..5e0d864942c 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -15,7 +15,6 @@ package encryptionkm import ( - "bytes" "context" "encoding/hex" "fmt" @@ -26,8 +25,8 @@ import ( "time" "github.com/gogo/protobuf/proto" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/encryption" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/tempurl" @@ -37,14 +36,6 @@ import ( "go.etcd.io/etcd/embed" ) -func TestKeyManager(t *testing.T) { - TestingT(t) -} - -type testKeyManagerSuite struct{} - -var _ = SerialSuites(&testKeyManagerSuite{}) - const ( testMasterKey = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b530" testMasterKey2 = "8fd7e3e917c170d92f3e51a981dd7bc8fba11f3df7d8df994842f6e86f69b531" @@ -57,29 +48,29 @@ func getTestDataKey() []byte { return key } -func newTestEtcd(c *C) (client *clientv3.Client, cleanup func()) { +func newTestEtcd(re *require.Assertions) (client *clientv3.Client, cleanup func()) { cfg := embed.NewConfig() cfg.Name = "test_etcd" cfg.Dir, _ = os.MkdirTemp("/tmp", "test_etcd") cfg.Logger = "zap" pu, err := url.Parse(tempurl.Alloc()) - c.Assert(err, IsNil) + re.NoError(err) cfg.LPUrls = []url.URL{*pu} cfg.APUrls = cfg.LPUrls cu, err := url.Parse(tempurl.Alloc()) - c.Assert(err, IsNil) + re.NoError(err) cfg.LCUrls = []url.URL{*cu} cfg.ACUrls = cfg.LCUrls cfg.InitialCluster = fmt.Sprintf("%s=%s", cfg.Name, &cfg.LPUrls[0]) cfg.ClusterState = embed.ClusterStateFlagNew server, err := embed.StartEtcd(cfg) - c.Assert(err, IsNil) + re.NoError(err) <-server.Server.ReadyNotify() client, err = clientv3.New(clientv3.Config{ Endpoints: []string{cfg.LCUrls[0].String()}, }) - c.Assert(err, IsNil) + re.NoError(err) cleanup = func() { client.Close() @@ -90,16 +81,16 @@ func newTestEtcd(c *C) (client *clientv3.Client, cleanup func()) { return client, cleanup } -func newTestKeyFile(c *C, key ...string) (keyFilePath string, cleanup func()) { +func newTestKeyFile(re *require.Assertions, key ...string) (keyFilePath string, cleanup func()) { testKey := testMasterKey for _, k := range key { testKey = k } tempDir, err := os.MkdirTemp("/tmp", "test_key_file") - c.Assert(err, IsNil) + re.NoError(err) keyFilePath = tempDir + "/key" err = os.WriteFile(keyFilePath, []byte(testKey), 0600) - c.Assert(err, IsNil) + re.NoError(err) cleanup = func() { os.RemoveAll(tempDir) @@ -108,53 +99,55 @@ func newTestKeyFile(c *C, key ...string) (keyFilePath string, cleanup func()) { return keyFilePath, cleanup } -func newTestLeader(c *C, client *clientv3.Client) *election.Leadership { +func newTestLeader(re *require.Assertions, client *clientv3.Client) *election.Leadership { leader := election.NewLeadership(client, "test_leader", "test") timeout := int64(30000000) // about a year. 
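+	// Campaign with a lease far longer than any test run so the leadership held by the test never expires.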
err := leader.Campaign(timeout, "") - c.Assert(err, IsNil) + re.NoError(err) return leader } -func checkMasterKeyMeta(c *C, value []byte, meta *encryptionpb.MasterKey, ciphertextKey []byte) { +func checkMasterKeyMeta(re *require.Assertions, value []byte, meta *encryptionpb.MasterKey, ciphertextKey []byte) { content := &encryptionpb.EncryptedContent{} err := content.Unmarshal(value) - c.Assert(err, IsNil) - c.Assert(proto.Equal(content.MasterKey, meta), IsTrue) - c.Assert(bytes.Equal(content.CiphertextKey, ciphertextKey), IsTrue) + re.NoError(err) + re.True(proto.Equal(content.MasterKey, meta)) + re.Equal(content.CiphertextKey, ciphertextKey) } -func (s *testKeyManagerSuite) TestNewKeyManagerBasic(c *C) { +func TestNewKeyManagerBasic(t *testing.T) { + re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() // Use default config. config := &encryption.Config{} err := config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := NewKeyManager(client, config) - c.Assert(err, IsNil) + re.NoError(err) // Check config. - c.Assert(m.method, Equals, encryptionpb.EncryptionMethod_PLAINTEXT) - c.Assert(m.masterKeyMeta.GetPlaintext(), NotNil) + re.Equal(encryptionpb.EncryptionMethod_PLAINTEXT, m.method) + re.NotNil(m.masterKeyMeta.GetPlaintext()) // Check loaded keys. - c.Assert(m.keys.Load(), IsNil) + re.Nil(m.keys.Load()) // Check etcd KV. value, err := etcdutil.GetValue(client, EncryptionKeysPath) - c.Assert(err, IsNil) - c.Assert(value, IsNil) + re.NoError(err) + re.Nil(value) } -func (s *testKeyManagerSuite) TestNewKeyManagerWithCustomConfig(c *C) { +func TestNewKeyManagerWithCustomConfig(t *testing.T) { + re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() // Custom config rotatePeriod, err := time.ParseDuration("100h") - c.Assert(err, IsNil) + re.NoError(err) config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", DataKeyRotationPeriod: typeutil.NewDuration(rotatePeriod), @@ -166,36 +159,37 @@ func (s *testKeyManagerSuite) TestNewKeyManagerWithCustomConfig(c *C) { }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := NewKeyManager(client, config) - c.Assert(err, IsNil) + re.NoError(err) // Check config. - c.Assert(m.method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) - c.Assert(m.dataKeyRotationPeriod, Equals, rotatePeriod) - c.Assert(m.masterKeyMeta, NotNil) + re.Equal(encryptionpb.EncryptionMethod_AES128_CTR, m.method) + re.Equal(rotatePeriod, m.dataKeyRotationPeriod) + re.NotNil(m.masterKeyMeta) keyFileMeta := m.masterKeyMeta.GetFile() - c.Assert(keyFileMeta, NotNil) - c.Assert(keyFileMeta.Path, Equals, config.MasterKey.MasterKeyFileConfig.FilePath) + re.NotNil(keyFileMeta) + re.Equal(config.MasterKey.MasterKeyFileConfig.FilePath, keyFileMeta.Path) // Check loaded keys. - c.Assert(m.keys.Load(), IsNil) + re.Nil(m.keys.Load()) // Check etcd KV. value, err := etcdutil.GetValue(client, EncryptionKeysPath) - c.Assert(err, IsNil) - c.Assert(value, IsNil) + re.NoError(err) + re.Nil(value) } -func (s *testKeyManagerSuite) TestNewKeyManagerLoadKeys(c *C) { +func TestNewKeyManagerLoadKeys(t *testing.T) { + re := require.New(t) // Initialize. 
- client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Use default config. config := &encryption.Config{} err := config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Store initial keys in etcd. masterKeyMeta := newMasterKey(keyFile) keys := &encryptionpb.KeyDictionary{ @@ -210,39 +204,40 @@ func (s *testKeyManagerSuite) TestNewKeyManagerLoadKeys(c *C) { }, } err = saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := NewKeyManager(client, config) - c.Assert(err, IsNil) + re.NoError(err) // Check config. - c.Assert(m.method, Equals, encryptionpb.EncryptionMethod_PLAINTEXT) - c.Assert(m.masterKeyMeta.GetPlaintext(), NotNil) + re.Equal(encryptionpb.EncryptionMethod_PLAINTEXT, m.method) + re.NotNil(m.masterKeyMeta.GetPlaintext()) // Check loaded keys. - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) // Check etcd KV. resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, keys)) } -func (s *testKeyManagerSuite) TestGetCurrentKey(c *C) { +func TestGetCurrentKey(t *testing.T) { + re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() // Use default config. config := &encryption.Config{} err := config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := NewKeyManager(client, config) - c.Assert(err, IsNil) + re.NoError(err) // Test encryption disabled. currentKeyID, currentKey, err := m.GetCurrentKey() - c.Assert(err, IsNil) - c.Assert(currentKeyID, Equals, uint64(disableEncryptionKeyID)) - c.Assert(currentKey, IsNil) + re.NoError(err) + re.Equal(uint64(disableEncryptionKeyID), currentKeyID) + re.Nil(currentKey) // Test normal case. keys := &encryptionpb.KeyDictionary{ CurrentKeyId: 123, @@ -257,9 +252,9 @@ func (s *testKeyManagerSuite) TestGetCurrentKey(c *C) { } m.keys.Store(keys) currentKeyID, currentKey, err = m.GetCurrentKey() - c.Assert(err, IsNil) - c.Assert(currentKeyID, Equals, keys.CurrentKeyId) - c.Assert(proto.Equal(currentKey, keys.Keys[keys.CurrentKeyId]), IsTrue) + re.NoError(err) + re.Equal(keys.CurrentKeyId, currentKeyID) + re.True(proto.Equal(currentKey, keys.Keys[keys.CurrentKeyId])) // Test current key missing. keys = &encryptionpb.KeyDictionary{ CurrentKeyId: 123, @@ -267,16 +262,17 @@ func (s *testKeyManagerSuite) TestGetCurrentKey(c *C) { } m.keys.Store(keys) _, _, err = m.GetCurrentKey() - c.Assert(err, NotNil) + re.Error(err) } -func (s *testKeyManagerSuite) TestGetKey(c *C) { +func TestGetKey(t *testing.T) { + re := require.New(t) // Initialize. 
- client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Store initial keys in etcd. masterKeyMeta := newMasterKey(keyFile) keys := &encryptionpb.KeyDictionary{ @@ -297,18 +293,18 @@ func (s *testKeyManagerSuite) TestGetKey(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Use default config. config := &encryption.Config{} err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := NewKeyManager(client, config) - c.Assert(err, IsNil) + re.NoError(err) // Get existing key. key, err := m.GetKey(uint64(123)) - c.Assert(err, IsNil) - c.Assert(proto.Equal(key, keys.Keys[123]), IsTrue) + re.NoError(err) + re.True(proto.Equal(key, keys.Keys[123])) // Get key that require a reload. // Deliberately cancel watcher, delete a key and check if it has reloaded. loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) @@ -317,21 +313,22 @@ func (s *testKeyManagerSuite) TestGetKey(c *C) { m.keys.Store(loadedKeys) m.mu.keysRevision = 0 key, err = m.GetKey(uint64(456)) - c.Assert(err, IsNil) - c.Assert(proto.Equal(key, keys.Keys[456]), IsTrue) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(key, keys.Keys[456])) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) // Get non-existing key. _, err = m.GetKey(uint64(789)) - c.Assert(err, NotNil) + re.Error(err) } -func (s *testKeyManagerSuite) TestLoadKeyEmpty(c *C) { +func TestLoadKeyEmpty(t *testing.T) { + re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Store initial keys in etcd. masterKeyMeta := newMasterKey(keyFile) keys := &encryptionpb.KeyDictionary{ @@ -346,29 +343,30 @@ func (s *testKeyManagerSuite) TestLoadKeyEmpty(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Use default config. config := &encryption.Config{} err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := NewKeyManager(client, config) - c.Assert(err, IsNil) + re.NoError(err) // Simulate keys get deleted. _, err = client.Delete(context.Background(), EncryptionKeysPath) - c.Assert(err, IsNil) - c.Assert(m.loadKeys(), NotNil) + re.NoError(err) + re.NotNil(m.loadKeys()) } -func (s *testKeyManagerSuite) TestWatcher(c *C) { +func TestWatcher(t *testing.T) { + re := require.New(t) // Initialize. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Listen on watcher event @@ -380,15 +378,15 @@ func (s *testKeyManagerSuite) TestWatcher(c *C) { // Use default config. config := &encryption.Config{} err := config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) + re.NoError(err) go m.StartBackgroundLoop(ctx) _, err = m.GetKey(123) - c.Assert(err, NotNil) + re.Error(err) _, err = m.GetKey(456) - c.Assert(err, NotNil) + re.Error(err) // Update keys in etcd masterKeyMeta := newMasterKey(keyFile) keys := &encryptionpb.KeyDictionary{ @@ -403,13 +401,13 @@ func (s *testKeyManagerSuite) TestWatcher(c *C) { }, } err = saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) <-reloadEvent key, err := m.GetKey(123) - c.Assert(err, IsNil) - c.Assert(proto.Equal(key, keys.Keys[123]), IsTrue) + re.NoError(err) + re.True(proto.Equal(key, keys.Keys[123])) _, err = m.GetKey(456) - c.Assert(err, NotNil) + re.Error(err) // Update again keys = &encryptionpb.KeyDictionary{ CurrentKeyId: 456, @@ -429,48 +427,50 @@ func (s *testKeyManagerSuite) TestWatcher(c *C) { }, } err = saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) <-reloadEvent key, err = m.GetKey(123) - c.Assert(err, IsNil) - c.Assert(proto.Equal(key, keys.Keys[123]), IsTrue) + re.NoError(err) + re.True(proto.Equal(key, keys.Keys[123])) key, err = m.GetKey(456) - c.Assert(err, IsNil) - c.Assert(proto.Equal(key, keys.Keys[456]), IsTrue) + re.NoError(err) + re.True(proto.Equal(key, keys.Keys[456])) } -func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionOff(c *C) { +func TestSetLeadershipWithEncryptionOff(t *testing.T) { + re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() // Use default config. config := &encryption.Config{} err := config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := NewKeyManager(client, config) - c.Assert(err, IsNil) - c.Assert(m.keys.Load(), IsNil) + re.NoError(err) + re.Nil(m.keys.Load()) // Set leadership - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check encryption stays off. - c.Assert(m.keys.Load(), IsNil) + re.Nil(m.keys.Load()) value, err := etcdutil.GetValue(client, EncryptionKeysPath) - c.Assert(err, IsNil) - c.Assert(value, IsNil) + re.NoError(err) + re.Nil(value) } -func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionEnabling(c *C) { +func TestSetLeadershipWithEncryptionEnabling(t *testing.T) { + re := require.New(t) // Initialize. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Listen on watcher event @@ -490,41 +490,42 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionEnabling(c *C) { }, } err := config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(m.keys.Load(), IsNil) + re.NoError(err) + re.Nil(m.keys.Load()) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check encryption is on and persisted. <-reloadEvent - c.Assert(m.keys.Load(), NotNil) + re.NotNil(m.keys.Load()) currentKeyID, currentKey, err := m.GetCurrentKey() - c.Assert(err, IsNil) + re.NoError(err) method, err := config.GetMethod() - c.Assert(err, IsNil) - c.Assert(currentKey.Method, Equals, method) + re.NoError(err) + re.Equal(method, currentKey.Method) loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) - c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) + re.True(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey)) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) + re.NoError(err) + re.True(proto.Equal(loadedKeys, storedKeys)) } -func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) { +func TestSetLeadershipWithEncryptionMethodChanged(t *testing.T) { + re := require.New(t) // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Mock time @@ -555,7 +556,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Config with different encrption method. config := &encryption.Config{ DataEncryptionMethod: "aes256-ctr", @@ -567,41 +568,42 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionMethodChanged(c *C) }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check encryption method is updated. 
<-reloadEvent - c.Assert(m.keys.Load(), NotNil) + re.NotNil(m.keys.Load()) currentKeyID, currentKey, err := m.GetCurrentKey() - c.Assert(err, IsNil) - c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES256_CTR) - c.Assert(currentKey.Key, HasLen, 32) + re.NoError(err) + re.Equal(encryptionpb.EncryptionMethod_AES256_CTR, currentKey.Method) + re.Len(currentKey.Key, 32) loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) - c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) - c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) + re.Equal(currentKeyID, loadedKeys.CurrentKeyId) + re.True(proto.Equal(loadedKeys.Keys[123], keys.Keys[123])) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) + re.NoError(err) + re.True(proto.Equal(loadedKeys, storedKeys)) } -func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { +func TestSetLeadershipWithCurrentKeyExposed(t *testing.T) { + re := require.New(t) // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Mock time @@ -626,7 +628,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Config with different encrption method. config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", @@ -638,42 +640,43 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExposed(c *C) { }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check encryption method is updated. 
<-reloadEvent - c.Assert(m.keys.Load(), NotNil) + re.NotNil(m.keys.Load()) currentKeyID, currentKey, err := m.GetCurrentKey() - c.Assert(err, IsNil) - c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) - c.Assert(currentKey.Key, HasLen, 16) - c.Assert(currentKey.WasExposed, IsFalse) + re.NoError(err) + re.Equal(encryptionpb.EncryptionMethod_AES128_CTR, currentKey.Method) + re.Len(currentKey.Key, 16) + re.False(currentKey.WasExposed) loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) - c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) - c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) + re.Equal(currentKeyID, loadedKeys.CurrentKeyId) + re.True(proto.Equal(loadedKeys.Keys[123], keys.Keys[123])) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) + re.NoError(err) + re.True(proto.Equal(loadedKeys, storedKeys)) } -func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { +func TestSetLeadershipWithCurrentKeyExpired(t *testing.T) { + re := require.New(t) // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Mock time @@ -698,10 +701,10 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Config with 100s rotation period. rotationPeriod, err := time.ParseDuration("100s") - c.Assert(err, IsNil) + re.NoError(err) config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", DataKeyRotationPeriod: typeutil.NewDuration(rotationPeriod), @@ -713,45 +716,46 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithCurrentKeyExpired(c *C) { }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check encryption method is updated. 
<-reloadEvent - c.Assert(m.keys.Load(), NotNil) + re.NotNil(m.keys.Load()) currentKeyID, currentKey, err := m.GetCurrentKey() - c.Assert(err, IsNil) - c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) - c.Assert(currentKey.Key, HasLen, 16) - c.Assert(currentKey.WasExposed, IsFalse) - c.Assert(currentKey.CreationTime, Equals, uint64(helper.now().Unix())) + re.NoError(err) + re.Equal(encryptionpb.EncryptionMethod_AES128_CTR, currentKey.Method) + re.Len(currentKey.Key, 16) + re.False(currentKey.WasExposed) + re.Equal(uint64(helper.now().Unix()), currentKey.CreationTime) loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) - c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) - c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) + re.Equal(currentKeyID, loadedKeys.CurrentKeyId) + re.True(proto.Equal(loadedKeys.Keys[123], keys.Keys[123])) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(loadedKeys, storedKeys), IsTrue) + re.NoError(err) + re.True(proto.Equal(loadedKeys, storedKeys)) } -func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { +func TestSetLeadershipWithMasterKeyChanged(t *testing.T) { + re := require.New(t) // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - keyFile2, cleanupKeyFile2 := newTestKeyFile(c, testMasterKey2) + keyFile2, cleanupKeyFile2 := newTestKeyFile(re, testMasterKey2) defer cleanupKeyFile2() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Mock time @@ -776,7 +780,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Config with a different master key. config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", @@ -788,35 +792,36 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithMasterKeyChanged(c *C) { }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check keys are the same, but encrypted with the new master key. 
<-reloadEvent - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, keys)) meta, err := config.GetMasterKeyMeta() - c.Assert(err, IsNil) - checkMasterKeyMeta(c, resp.Kvs[0].Value, meta, nil) + re.NoError(err) + checkMasterKeyMeta(re, resp.Kvs[0].Value, meta, nil) } -func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) { +func TestSetLeadershipMasterKeyWithCiphertextKey(t *testing.T) { + re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Mock time @@ -831,10 +836,10 @@ func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) ) (*encryption.MasterKey, error) { if newMasterKeyCalled < 2 { // initial load and save. no ciphertextKey - c.Assert(ciphertext, IsNil) + re.Nil(ciphertext) } else if newMasterKeyCalled == 2 { // called by loadKeys after saveKeys - c.Assert(bytes.Equal(ciphertext, outputCiphertextKey), IsTrue) + re.Equal(ciphertext, outputCiphertextKey) } newMasterKeyCalled += 1 return encryption.NewCustomMasterKeyForTest(outputMasterKey, outputCiphertextKey), nil @@ -853,7 +858,7 @@ func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Config with a different master key. config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", @@ -865,37 +870,38 @@ func (s *testKeyManagerSuite) TestSetLeadershipMasterKeyWithCiphertextKey(c *C) }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) - c.Assert(newMasterKeyCalled, Equals, 3) + re.NoError(err) + re.Equal(3, newMasterKeyCalled) // Check if keys are the same - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, keys)) meta, err := config.GetMasterKeyMeta() - c.Assert(err, IsNil) + re.NoError(err) // Check ciphertext key is stored with keys. 
- checkMasterKeyMeta(c, resp.Kvs[0].Value, meta, outputCiphertextKey) + checkMasterKeyMeta(re, resp.Kvs[0].Value, meta, outputCiphertextKey) } -func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { +func TestSetLeadershipWithEncryptionDisabling(t *testing.T) { + re := require.New(t) // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Listen on watcher event @@ -918,41 +924,42 @@ func (s *testKeyManagerSuite) TestSetLeadershipWithEncryptionDisabling(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Use default config. config := &encryption.Config{} err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check encryption is disabled <-reloadEvent expectedKeys := proto.Clone(keys).(*encryptionpb.KeyDictionary) expectedKeys.CurrentKeyId = disableEncryptionKeyID expectedKeys.Keys[123].WasExposed = true - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), expectedKeys), IsTrue) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), expectedKeys)) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, expectedKeys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, expectedKeys)) } -func (s *testKeyManagerSuite) TestKeyRotation(c *C) { +func TestKeyRotation(t *testing.T) { + re := require.New(t) // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Mock time @@ -986,10 +993,10 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Config with 100s rotation period. rotationPeriod, err := time.ParseDuration("100s") - c.Assert(err, IsNil) + re.NoError(err) config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", DataKeyRotationPeriod: typeutil.NewDuration(rotationPeriod), @@ -1001,22 +1008,22 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. 
m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check keys - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, keys)) // Advance time and trigger ticker atomic.AddInt64(&mockNow, int64(101)) mockTick <- time.Unix(atomic.LoadInt64(&mockNow), 0) @@ -1024,32 +1031,33 @@ func (s *testKeyManagerSuite) TestKeyRotation(c *C) { <-reloadEvent // Check key is rotated. currentKeyID, currentKey, err := m.GetCurrentKey() - c.Assert(err, IsNil) - c.Assert(currentKeyID, Not(Equals), uint64(123)) - c.Assert(currentKey.Method, Equals, encryptionpb.EncryptionMethod_AES128_CTR) - c.Assert(currentKey.Key, HasLen, 16) - c.Assert(currentKey.CreationTime, Equals, uint64(mockNow)) - c.Assert(currentKey.WasExposed, IsFalse) + re.NoError(err) + re.NotEqual(uint64(123), currentKeyID) + re.Equal(encryptionpb.EncryptionMethod_AES128_CTR, currentKey.Method) + re.Len(currentKey.Key, 16) + re.Equal(uint64(mockNow), currentKey.CreationTime) + re.False(currentKey.WasExposed) loadedKeys := m.keys.Load().(*encryptionpb.KeyDictionary) - c.Assert(loadedKeys.CurrentKeyId, Equals, currentKeyID) - c.Assert(proto.Equal(loadedKeys.Keys[123], keys.Keys[123]), IsTrue) - c.Assert(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey), IsTrue) + re.Equal(currentKeyID, loadedKeys.CurrentKeyId) + re.True(proto.Equal(loadedKeys.Keys[123], keys.Keys[123])) + re.True(proto.Equal(loadedKeys.Keys[currentKeyID], currentKey)) resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err = extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, loadedKeys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, loadedKeys)) } -func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { +func TestKeyRotationConflict(t *testing.T) { + re := require.New(t) // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(c) + client, cleanupEtcd := newTestEtcd(re) defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(c) + keyFile, cleanupKeyFile := newTestKeyFile(re) defer cleanupKeyFile() - leadership := newTestLeader(c, client) + leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() // Mock time @@ -1093,10 +1101,10 @@ func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { }, } err := saveKeys(leadership, masterKeyMeta, keys, defaultKeyManagerHelper()) - c.Assert(err, IsNil) + re.NoError(err) // Config with 100s rotation period. 
rotationPeriod, err := time.ParseDuration("100s") - c.Assert(err, IsNil) + re.NoError(err) config := &encryption.Config{ DataEncryptionMethod: "aes128-ctr", DataKeyRotationPeriod: typeutil.NewDuration(rotationPeriod), @@ -1108,22 +1116,22 @@ func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { }, } err = config.Adjust() - c.Assert(err, IsNil) + re.NoError(err) // Create the key manager. m, err := newKeyManagerImpl(client, config, helper) - c.Assert(err, IsNil) - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) go m.StartBackgroundLoop(ctx) // Set leadership err = m.SetLeadership(leadership) - c.Assert(err, IsNil) + re.NoError(err) // Check keys - c.Assert(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys), IsTrue) + re.True(proto.Equal(m.keys.Load().(*encryptionpb.KeyDictionary), keys)) resp, err := etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err := extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, keys)) // Invalidate leader after leader check. atomic.StoreInt32(&shouldResetLeader, 1) atomic.StoreInt32(&shouldListenSaveKeysFailure, 1) @@ -1134,10 +1142,10 @@ func (s *testKeyManagerSuite) TestKeyRotationConflict(c *C) { <-saveKeysFailureEvent // Check keys is unchanged. resp, err = etcdutil.EtcdKVGet(client, EncryptionKeysPath) - c.Assert(err, IsNil) + re.NoError(err) storedKeys, err = extractKeysFromKV(resp.Kvs[0], defaultKeyManagerHelper()) - c.Assert(err, IsNil) - c.Assert(proto.Equal(storedKeys, keys), IsTrue) + re.NoError(err) + re.True(proto.Equal(storedKeys, keys)) } func newMasterKey(keyFile string) *encryptionpb.MasterKey { From 83f14760e1c2cb53d68667f3192036ade7c5dcc0 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 13 Jun 2022 13:32:32 +0800 Subject: [PATCH 39/82] server: fix the bug that causes wrong statistics for over/undersized regions (#5137) close tikv/pd#5107 Fix the bug that causes wrong statistics for over/undersized regions. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- server/cluster/cluster.go | 5 ++++ server/cluster/cluster_test.go | 39 ++++++++++++++++++++++++++ server/config/persist_options.go | 14 +++++++++ server/core/region.go | 5 ++++ server/statistics/region_collection.go | 35 ++++++++++++++++++++--- 5 files changed, 94 insertions(+), 4 deletions(-) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 7c6f0c3701b..7b39a29d580 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -792,6 +792,11 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { // Mark isNew if the region in cache does not have leader. isNew, saveKV, saveCache, needSync := regionGuide(region, origin) if !saveKV && !saveCache && !isNew { + // Due to some config changes need to update the region stats as well, + // so we do some extra checks here. 
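+		// Even if this heartbeat changes neither the cache nor the persisted KV, a change to the
+		// size/merge thresholds can still flip the oversized/undersized statistics, so re-observe
+		// the region whenever RegionStatsNeedUpdate reports a mismatch.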
+ if c.regionStats != nil && c.regionStats.RegionStatsNeedUpdate(region) { + c.regionStats.Observe(region, c.getRegionStoresLocked(region)) + } return nil } diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index c91899c662d..530abff2b87 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -881,6 +881,45 @@ func (s *testClusterInfoSuite) TestRegionFlowChanged(c *C) { c.Assert(newRegion.GetBytesRead(), Equals, uint64(1000)) } +func (s *testClusterInfoSuite) TestRegionSizeChanged(c *C) { + _, opt, err := newTestScheduleConfig() + c.Assert(err, IsNil) + cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + cluster.regionStats = statistics.NewRegionStatistics(cluster.GetOpts(), cluster.ruleManager, cluster.storeConfigManager) + region := newTestRegions(1, 3, 3)[0] + cluster.opt.GetMaxMergeRegionKeys() + curMaxMergeSize := int64(cluster.opt.GetMaxMergeRegionSize()) + curMaxMergeKeys := int64(cluster.opt.GetMaxMergeRegionKeys()) + region = region.Clone( + core.WithLeader(region.GetPeers()[2]), + core.SetApproximateSize(curMaxMergeSize-1), + core.SetApproximateKeys(curMaxMergeKeys-1), + core.SetFromHeartbeat(true), + ) + cluster.processRegionHeartbeat(region) + regionID := region.GetID() + c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsTrue) + // Test ApproximateSize and ApproximateKeys change. + region = region.Clone( + core.WithLeader(region.GetPeers()[2]), + core.SetApproximateSize(curMaxMergeSize+1), + core.SetApproximateKeys(curMaxMergeKeys+1), + core.SetFromHeartbeat(true), + ) + cluster.processRegionHeartbeat(region) + c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsFalse) + // Test MaxMergeRegionSize and MaxMergeRegionKeys change. + cluster.opt.SetMaxMergeRegionSize((uint64(curMaxMergeSize + 2))) + cluster.opt.SetMaxMergeRegionKeys((uint64(curMaxMergeKeys + 2))) + cluster.processRegionHeartbeat(region) + c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsTrue) + cluster.opt.SetMaxMergeRegionSize((uint64(curMaxMergeSize))) + cluster.opt.SetMaxMergeRegionKeys((uint64(curMaxMergeKeys))) + cluster.processRegionHeartbeat(region) + c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsFalse) +} + func (s *testClusterInfoSuite) TestConcurrentReportBucket(c *C) { _, opt, err := newTestScheduleConfig() c.Assert(err, IsNil) diff --git a/server/config/persist_options.go b/server/config/persist_options.go index 5882c123947..fe7203722c2 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -260,6 +260,20 @@ func (o *PersistOptions) SetSplitMergeInterval(splitMergeInterval time.Duration) o.SetScheduleConfig(v) } +// SetMaxMergeRegionSize sets the max merge region size. +func (o *PersistOptions) SetMaxMergeRegionSize(maxMergeRegionSize uint64) { + v := o.GetScheduleConfig().Clone() + v.MaxMergeRegionSize = maxMergeRegionSize + o.SetScheduleConfig(v) +} + +// SetMaxMergeRegionKeys sets the max merge region keys. +func (o *PersistOptions) SetMaxMergeRegionKeys(maxMergeRegionKeys uint64) { + v := o.GetScheduleConfig().Clone() + v.MaxMergeRegionKeys = maxMergeRegionKeys + o.SetScheduleConfig(v) +} + // SetStoreLimit sets a store limit for a given type and rate. 
func (o *PersistOptions) SetStoreLimit(storeID uint64, typ storelimit.Type, ratePerMin float64) { v := o.GetScheduleConfig().Clone() diff --git a/server/core/region.go b/server/core/region.go index 6ec75d4fef5..cc688712ad8 100644 --- a/server/core/region.go +++ b/server/core/region.go @@ -229,6 +229,11 @@ func (r *RegionInfo) NeedMerge(mergeSize int64, mergeKeys int64) bool { return r.GetApproximateSize() <= mergeSize && r.GetApproximateKeys() <= mergeKeys } +// IsOversized indicates whether the region is oversized. +func (r *RegionInfo) IsOversized(maxSize int64, maxKeys int64) bool { + return r.GetApproximateSize() >= maxSize || r.GetApproximateKeys() >= maxKeys +} + // GetTerm returns the current term of the region func (r *RegionInfo) GetTerm() uint64 { return r.term diff --git a/server/statistics/region_collection.go b/server/statistics/region_collection.go index 807a93d87a6..1c46d7acdda 100644 --- a/server/statistics/region_collection.go +++ b/server/statistics/region_collection.go @@ -101,6 +101,14 @@ func (r *RegionStatistics) GetRegionStatsByType(typ RegionStatisticType) []*core return res } +// IsRegionStatsType returns whether the status of the region is the given type. +func (r *RegionStatistics) IsRegionStatsType(regionID uint64, typ RegionStatisticType) bool { + r.RLock() + defer r.RUnlock() + _, exist := r.stats[typ][regionID] + return exist +} + // GetOfflineRegionStatsByType gets the status of the offline region by types. The regions here need to be cloned, otherwise, it may cause data race problems. func (r *RegionStatistics) GetOfflineRegionStatsByType(typ RegionStatisticType) []*core.RegionInfo { r.RLock() @@ -128,6 +136,18 @@ func (r *RegionStatistics) deleteOfflineEntry(deleteIndex RegionStatisticType, r } } +// RegionStatsNeedUpdate checks whether the region's status need to be updated +// due to some special state types. +func (r *RegionStatistics) RegionStatsNeedUpdate(region *core.RegionInfo) bool { + regionID := region.GetID() + if r.IsRegionStatsType(regionID, OversizedRegion) != + region.IsOversized(int64(r.storeConfigManager.GetStoreConfig().GetRegionMaxSize()), int64(r.storeConfigManager.GetStoreConfig().GetRegionMaxKeys())) { + return true + } + return r.IsRegionStatsType(regionID, UndersizedRegion) != + region.NeedMerge(int64(r.opt.GetMaxMergeRegionSize()), int64(r.opt.GetMaxMergeRegionKeys())) +} + // Observe records the current regions' status. func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.StoreInfo) { r.Lock() @@ -169,6 +189,9 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store } } + // Better to make sure once any of these conditions changes, it will trigger the heartbeat `save_cache`. + // Otherwise, the state may be out-of-date for a long time, which needs another way to apply the change ASAP. + // For example, see `RegionStatsNeedUpdate` above to know how `OversizedRegion` and ``UndersizedRegion` are updated. 
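+	// OversizedRegion is judged against the TiKV store limits (region-max-size / region-max-keys):
+	// exceeding either one is enough. UndersizedRegion is judged against the merge thresholds
+	// (max-merge-region-size / max-merge-region-keys): the region must be under both to qualify.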
conditions := map[RegionStatisticType]bool{ MissPeer: len(region.GetPeers()) < desiredReplicas, ExtraPeer: len(region.GetPeers()) > desiredReplicas, @@ -176,10 +199,14 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store PendingPeer: len(region.GetPendingPeers()) > 0, LearnerPeer: len(region.GetLearners()) > 0, EmptyRegion: region.GetApproximateSize() <= core.EmptyRegionApproximateSize, - OversizedRegion: region.GetApproximateSize() >= int64(r.storeConfigManager.GetStoreConfig().GetRegionMaxSize()) || - region.GetApproximateKeys() >= int64(r.storeConfigManager.GetStoreConfig().GetRegionMaxKeys()), - UndersizedRegion: region.NeedMerge(int64(r.opt.GetScheduleConfig().MaxMergeRegionSize), - int64(r.opt.GetScheduleConfig().MaxMergeRegionKeys)), + OversizedRegion: region.IsOversized( + int64(r.storeConfigManager.GetStoreConfig().GetRegionMaxSize()), + int64(r.storeConfigManager.GetStoreConfig().GetRegionMaxKeys()), + ), + UndersizedRegion: region.NeedMerge( + int64(r.opt.GetMaxMergeRegionSize()), + int64(r.opt.GetMaxMergeRegionKeys()), + ), } for typ, c := range conditions { From 1c2a4da9aed281634ea36ab8ae5a0cd1960b639e Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Mon, 13 Jun 2022 15:04:33 +0800 Subject: [PATCH 40/82] syncer: fix the wrong gRPC code usage (#5142) ref tikv/pd#5122 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/region_syncer/client.go | 4 ++-- server/region_syncer/client_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/server/region_syncer/client.go b/server/region_syncer/client.go index 3f2419ad720..ba9fd2edcd6 100644 --- a/server/region_syncer/client.go +++ b/server/region_syncer/client.go @@ -90,7 +90,7 @@ func (s *RegionSyncer) syncRegion(ctx context.Context, conn *grpc.ClientConn) (C cli := pdpb.NewPDClient(conn) syncStream, err := cli.SyncRegions(ctx) if err != nil { - return nil, errs.ErrGRPCCreateStream.Wrap(err).FastGenWithCause() + return nil, err } err = syncStream.Send(&pdpb.SyncRegionRequest{ Header: &pdpb.RequestHeader{ClusterId: s.server.ClusterID()}, @@ -98,7 +98,7 @@ func (s *RegionSyncer) syncRegion(ctx context.Context, conn *grpc.ClientConn) (C StartIndex: s.history.GetNextIndex(), }) if err != nil { - return nil, errs.ErrGRPCSend.Wrap(err).FastGenWithCause() + return nil, err } return syncStream, nil diff --git a/server/region_syncer/client_test.go b/server/region_syncer/client_test.go index ca39cee4859..b63deaae3e0 100644 --- a/server/region_syncer/client_test.go +++ b/server/region_syncer/client_test.go @@ -27,6 +27,8 @@ import ( "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/storage" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // For issue https://github.com/tikv/pd/issues/3936 @@ -60,6 +62,29 @@ func TestLoadRegion(t *testing.T) { re.Less(time.Since(start), time.Second*2) } +func TestErrorCode(t *testing.T) { + re := require.New(t) + tempDir, err := os.MkdirTemp(os.TempDir(), "region_syncer_err") + re.NoError(err) + defer os.RemoveAll(tempDir) + rs, err := storage.NewStorageWithLevelDBBackend(context.Background(), tempDir, nil) + re.NoError(err) + server := &mockServer{ + ctx: context.Background(), + storage: storage.NewCoreStorage(storage.NewStorageWithMemoryBackend(), rs), + bc: core.NewBasicCluster(), + } + ctx, cancel := context.WithCancel(context.TODO()) + rc := NewRegionSyncer(server) + conn, err := grpcutil.GetClientConn(ctx, "127.0.0.1", nil) + re.NoError(err) + cancel() + 
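+	// The context is canceled before syncRegion runs, so the failure should surface as a raw
+	// gRPC status with codes.Canceled instead of being wrapped into a PD error.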
_, err = rc.syncRegion(ctx, conn) + ev, ok := status.FromError(err) + re.True(ok) + re.Equal(codes.Canceled, ev.Code()) +} + type mockServer struct { ctx context.Context member, leader *pdpb.Member From 958d687c32389ce5a06eb1dc3303d5d440328d6e Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Mon, 13 Jun 2022 15:24:34 +0800 Subject: [PATCH 41/82] statistics: migrate test framework to testify (#5140) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/statistics/hot_peer_cache_test.go | 310 ++++++++++---------- server/statistics/kind_test.go | 43 ++- server/statistics/region_collection_test.go | 153 +++++----- server/statistics/store_collection_test.go | 46 ++- server/statistics/store_test.go | 18 +- server/statistics/topn_test.go | 72 ++--- 6 files changed, 314 insertions(+), 328 deletions(-) diff --git a/server/statistics/hot_peer_cache_test.go b/server/statistics/hot_peer_cache_test.go index 347e2a423d8..c021f05df3f 100644 --- a/server/statistics/hot_peer_cache_test.go +++ b/server/statistics/hot_peer_cache_test.go @@ -20,27 +20,24 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/movingaverage" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testHotPeerCache{}) - -type testHotPeerCache struct{} - -func (t *testHotPeerCache) TestStoreTimeUnsync(c *C) { +func TestStoreTimeUnsync(t *testing.T) { + re := require.New(t) cache := NewHotPeerCache(Write) intervals := []uint64{120, 60} for _, interval := range intervals { region := buildRegion(Write, 3, interval) - checkAndUpdate(c, cache, region, 3) + checkAndUpdate(re, cache, region, 3) { stats := cache.RegionStats(0) - c.Assert(stats, HasLen, 3) + re.Len(stats, 3) for _, s := range stats { - c.Assert(s, HasLen, 1) + re.Len(s, 1) } } } @@ -62,7 +59,8 @@ type testCacheCase struct { actionType ActionType } -func (t *testHotPeerCache) TestCache(c *C) { +func TestCache(t *testing.T) { + re := require.New(t) tests := []*testCacheCase{ {Read, transferLeader, 3, Update}, {Read, movePeer, 4, Remove}, @@ -71,26 +69,22 @@ func (t *testHotPeerCache) TestCache(c *C) { {Write, movePeer, 4, Remove}, {Write, addReplica, 4, Remove}, } - for _, t := range tests { - testCache(c, t) - } -} - -func testCache(c *C, t *testCacheCase) { - defaultSize := map[RWType]int{ - Read: 3, // all peers - Write: 3, // all peers - } - cache := NewHotPeerCache(t.kind) - region := buildRegion(t.kind, 3, 60) - checkAndUpdate(c, cache, region, defaultSize[t.kind]) - checkHit(c, cache, region, t.kind, Add) // all peers are new - - srcStore, region := schedule(c, t.operator, region, 10) - res := checkAndUpdate(c, cache, region, t.expect) - checkHit(c, cache, region, t.kind, Update) // hit cache - if t.expect != defaultSize[t.kind] { - checkOp(c, res, srcStore, t.actionType) + for _, test := range tests { + defaultSize := map[RWType]int{ + Read: 3, // all peers + Write: 3, // all peers + } + cache := NewHotPeerCache(test.kind) + region := buildRegion(test.kind, 3, 60) + checkAndUpdate(re, cache, region, defaultSize[test.kind]) + checkHit(re, cache, region, test.kind, Add) // all peers are new + + srcStore, region := schedule(re, test.operator, region, 10) + res := checkAndUpdate(re, cache, region, test.expect) + checkHit(re, cache, region, test.kind, Update) // hit cache + if test.expect != defaultSize[test.kind] { + checkOp(re, res, srcStore, test.actionType) + } } } @@ -127,35 +121,35 @@ func updateFlow(cache *hotPeerCache, res 
[]*HotPeerStat) []*HotPeerStat { return res } -type check func(c *C, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) +type check func(re *require.Assertions, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) -func checkAndUpdate(c *C, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) { +func checkAndUpdate(re *require.Assertions, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) { res = checkFlow(cache, region, region.GetPeers()) if len(expect) != 0 { - c.Assert(res, HasLen, expect[0]) + re.Len(res, expect[0]) } return updateFlow(cache, res) } // Check and update peers in the specified order that old item that he items that have not expired come first, and the items that have expired come second. // This order is also similar to the previous version. By the way the order in now version is random. -func checkAndUpdateWithOrdering(c *C, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) { +func checkAndUpdateWithOrdering(re *require.Assertions, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) { res = checkFlow(cache, region, orderingPeers(cache, region)) if len(expect) != 0 { - c.Assert(res, HasLen, expect[0]) + re.Len(res, expect[0]) } return updateFlow(cache, res) } -func checkAndUpdateSkipOne(c *C, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) { +func checkAndUpdateSkipOne(re *require.Assertions, cache *hotPeerCache, region *core.RegionInfo, expect ...int) (res []*HotPeerStat) { res = checkFlow(cache, region, region.GetPeers()[1:]) if len(expect) != 0 { - c.Assert(res, HasLen, expect[0]) + re.Len(res, expect[0]) } return updateFlow(cache, res) } -func checkHit(c *C, cache *hotPeerCache, region *core.RegionInfo, kind RWType, actionType ActionType) { +func checkHit(re *require.Assertions, cache *hotPeerCache, region *core.RegionInfo, kind RWType, actionType ActionType) { var peers []*metapb.Peer if kind == Read { peers = []*metapb.Peer{region.GetLeader()} @@ -164,15 +158,15 @@ func checkHit(c *C, cache *hotPeerCache, region *core.RegionInfo, kind RWType, a } for _, peer := range peers { item := cache.getOldHotPeerStat(region.GetID(), peer.StoreId) - c.Assert(item, NotNil) - c.Assert(item.actionType, Equals, actionType) + re.NotNil(item) + re.Equal(actionType, item.actionType) } } -func checkOp(c *C, ret []*HotPeerStat, storeID uint64, actionType ActionType) { +func checkOp(re *require.Assertions, ret []*HotPeerStat, storeID uint64, actionType ActionType) { for _, item := range ret { if item.StoreID == storeID { - c.Assert(item.actionType, Equals, actionType) + re.Equal(actionType, item.actionType) return } } @@ -192,7 +186,7 @@ func checkIntervalSum(cache *hotPeerCache, region *core.RegionInfo) bool { } // checkIntervalSumContinuous checks whether the interval sum of the peer is continuous. 
-func checkIntervalSumContinuous(c *C, intervalSums map[uint64]int, rets []*HotPeerStat, interval uint64) { +func checkIntervalSumContinuous(re *require.Assertions, intervalSums map[uint64]int, rets []*HotPeerStat, interval uint64) { for _, ret := range rets { if ret.actionType == Remove { delete(intervalSums, ret.StoreID) @@ -201,27 +195,27 @@ func checkIntervalSumContinuous(c *C, intervalSums map[uint64]int, rets []*HotPe new := int(ret.getIntervalSum() / 1000000000) if ret.source == direct { if old, ok := intervalSums[ret.StoreID]; ok { - c.Assert(new, Equals, (old+int(interval))%RegionHeartBeatReportInterval) + re.Equal((old+int(interval))%RegionHeartBeatReportInterval, new) } } intervalSums[ret.StoreID] = new } } -func schedule(c *C, operator operator, region *core.RegionInfo, targets ...uint64) (srcStore uint64, _ *core.RegionInfo) { +func schedule(re *require.Assertions, operator operator, region *core.RegionInfo, targets ...uint64) (srcStore uint64, _ *core.RegionInfo) { switch operator { case transferLeader: _, newLeader := pickFollower(region) return region.GetLeader().StoreId, region.Clone(core.WithLeader(newLeader)) case movePeer: - c.Assert(targets, HasLen, 1) + re.Len(targets, 1) index, _ := pickFollower(region) srcStore := region.GetPeers()[index].StoreId region := region.Clone(core.WithAddPeer(&metapb.Peer{Id: targets[0]*10 + 1, StoreId: targets[0]})) region = region.Clone(core.WithRemoveStorePeer(srcStore)) return srcStore, region case addReplica: - c.Assert(targets, HasLen, 1) + re.Len(targets, 1) region := region.Clone(core.WithAddPeer(&metapb.Peer{Id: targets[0]*10 + 1, StoreId: targets[0]})) return 0, region case removeReplica: @@ -307,7 +301,8 @@ func newPeers(n int, pid genID, sid genID) []*metapb.Peer { return peers } -func (t *testHotPeerCache) TestUpdateHotPeerStat(c *C) { +func TestUpdateHotPeerStat(t *testing.T) { + re := require.New(t) cache := NewHotPeerCache(Read) // we statistic read peer info from store heartbeat rather than region heartbeat m := RegionHeartBeatReportInterval / StoreHeartBeatReportInterval @@ -315,69 +310,70 @@ func (t *testHotPeerCache) TestUpdateHotPeerStat(c *C) { // skip interval=0 newItem := &HotPeerStat{actionType: Update, thresholds: []float64{0.0, 0.0, 0.0}, Kind: Read} newItem = cache.updateHotPeerStat(nil, newItem, nil, []float64{0.0, 0.0, 0.0}, 0) - c.Check(newItem, IsNil) + re.Nil(newItem) // new peer, interval is larger than report interval, but no hot newItem = &HotPeerStat{actionType: Update, thresholds: []float64{1.0, 1.0, 1.0}, Kind: Read} newItem = cache.updateHotPeerStat(nil, newItem, nil, []float64{0.0, 0.0, 0.0}, 10*time.Second) - c.Check(newItem, IsNil) + re.Nil(newItem) // new peer, interval is less than report interval newItem = &HotPeerStat{actionType: Update, thresholds: []float64{0.0, 0.0, 0.0}, Kind: Read} newItem = cache.updateHotPeerStat(nil, newItem, nil, []float64{60.0, 60.0, 60.0}, 4*time.Second) - c.Check(newItem, NotNil) - c.Check(newItem.HotDegree, Equals, 0) - c.Check(newItem.AntiCount, Equals, 0) + re.NotNil(newItem) + re.Equal(0, newItem.HotDegree) + re.Equal(0, newItem.AntiCount) // sum of interval is less than report interval oldItem := newItem newItem = cache.updateHotPeerStat(nil, newItem, oldItem, []float64{60.0, 60.0, 60.0}, 4*time.Second) - c.Check(newItem.HotDegree, Equals, 0) - c.Check(newItem.AntiCount, Equals, 0) + re.Equal(0, newItem.HotDegree) + re.Equal(0, newItem.AntiCount) // sum of interval is larger than report interval, and hot oldItem = newItem newItem = cache.updateHotPeerStat(nil, 
newItem, oldItem, []float64{60.0, 60.0, 60.0}, 4*time.Second) - c.Check(newItem.HotDegree, Equals, 1) - c.Check(newItem.AntiCount, Equals, 2*m) + re.Equal(1, newItem.HotDegree) + re.Equal(2*m, newItem.AntiCount) // sum of interval is less than report interval oldItem = newItem newItem = cache.updateHotPeerStat(nil, newItem, oldItem, []float64{60.0, 60.0, 60.0}, 4*time.Second) - c.Check(newItem.HotDegree, Equals, 1) - c.Check(newItem.AntiCount, Equals, 2*m) + re.Equal(1, newItem.HotDegree) + re.Equal(2*m, newItem.AntiCount) // sum of interval is larger than report interval, and hot oldItem = newItem newItem = cache.updateHotPeerStat(nil, newItem, oldItem, []float64{60.0, 60.0, 60.0}, 10*time.Second) - c.Check(newItem.HotDegree, Equals, 2) - c.Check(newItem.AntiCount, Equals, 2*m) + re.Equal(2, newItem.HotDegree) + re.Equal(2*m, newItem.AntiCount) // sum of interval is larger than report interval, and cold oldItem = newItem newItem.thresholds = []float64{10.0, 10.0, 10.0} newItem = cache.updateHotPeerStat(nil, newItem, oldItem, []float64{60.0, 60.0, 60.0}, 10*time.Second) - c.Check(newItem.HotDegree, Equals, 1) - c.Check(newItem.AntiCount, Equals, 2*m-1) + re.Equal(1, newItem.HotDegree) + re.Equal(2*m-1, newItem.AntiCount) // sum of interval is larger than report interval, and cold for i := 0; i < 2*m-1; i++ { oldItem = newItem newItem = cache.updateHotPeerStat(nil, newItem, oldItem, []float64{60.0, 60.0, 60.0}, 10*time.Second) } - c.Check(newItem.HotDegree, Less, 0) - c.Check(newItem.AntiCount, Equals, 0) - c.Check(newItem.actionType, Equals, Remove) + re.Less(newItem.HotDegree, 0) + re.Equal(0, newItem.AntiCount) + re.Equal(Remove, newItem.actionType) } -func (t *testHotPeerCache) TestThresholdWithUpdateHotPeerStat(c *C) { +func TestThresholdWithUpdateHotPeerStat(t *testing.T) { + re := require.New(t) byteRate := minHotThresholds[RegionReadBytes] * 2 expectThreshold := byteRate * HotThresholdRatio - t.testMetrics(c, 120., byteRate, expectThreshold) - t.testMetrics(c, 60., byteRate, expectThreshold) - t.testMetrics(c, 30., byteRate, expectThreshold) - t.testMetrics(c, 17., byteRate, expectThreshold) - t.testMetrics(c, 1., byteRate, expectThreshold) + testMetrics(re, 120., byteRate, expectThreshold) + testMetrics(re, 60., byteRate, expectThreshold) + testMetrics(re, 30., byteRate, expectThreshold) + testMetrics(re, 17., byteRate, expectThreshold) + testMetrics(re, 1., byteRate, expectThreshold) } -func (t *testHotPeerCache) testMetrics(c *C, interval, byteRate, expectThreshold float64) { +func testMetrics(re *require.Assertions, interval, byteRate, expectThreshold float64) { cache := NewHotPeerCache(Read) storeID := uint64(1) - c.Assert(byteRate, GreaterEqual, minHotThresholds[RegionReadBytes]) + re.GreaterOrEqual(byteRate, minHotThresholds[RegionReadBytes]) for i := uint64(1); i < TopNN+10; i++ { var oldItem *HotPeerStat for { @@ -401,14 +397,15 @@ func (t *testHotPeerCache) testMetrics(c *C, interval, byteRate, expectThreshold } thresholds := cache.calcHotThresholds(storeID) if i < TopNN { - c.Assert(thresholds[RegionReadBytes], Equals, minHotThresholds[RegionReadBytes]) + re.Equal(minHotThresholds[RegionReadBytes], thresholds[RegionReadBytes]) } else { - c.Assert(thresholds[RegionReadBytes], Equals, expectThreshold) + re.Equal(expectThreshold, thresholds[RegionReadBytes]) } } } -func (t *testHotPeerCache) TestRemoveFromCache(c *C) { +func TestRemoveFromCache(t *testing.T) { + re := require.New(t) peerCount := 3 interval := uint64(5) checkers := []check{checkAndUpdate, 
checkAndUpdateWithOrdering} @@ -418,29 +415,30 @@ func (t *testHotPeerCache) TestRemoveFromCache(c *C) { // prepare intervalSums := make(map[uint64]int) for i := 1; i <= 200; i++ { - rets := checker(c, cache, region) - checkIntervalSumContinuous(c, intervalSums, rets, interval) + rets := checker(re, cache, region) + checkIntervalSumContinuous(re, intervalSums, rets, interval) } // make the interval sum of peers are different - checkAndUpdateSkipOne(c, cache, region) + checkAndUpdateSkipOne(re, cache, region) checkIntervalSum(cache, region) // check whether cold cache is cleared var isClear bool intervalSums = make(map[uint64]int) region = region.Clone(core.SetWrittenBytes(0), core.SetWrittenKeys(0), core.SetWrittenQuery(0)) for i := 1; i <= 200; i++ { - rets := checker(c, cache, region) - checkIntervalSumContinuous(c, intervalSums, rets, interval) + rets := checker(re, cache, region) + checkIntervalSumContinuous(re, intervalSums, rets, interval) if len(cache.storesOfRegion[region.GetID()]) == 0 { isClear = true break } } - c.Assert(isClear, IsTrue) + re.True(isClear) } } -func (t *testHotPeerCache) TestRemoveFromCacheRandom(c *C) { +func TestRemoveFromCacheRandom(t *testing.T) { + re := require.New(t) peerCounts := []int{3, 5} intervals := []uint64{120, 60, 10, 5} checkers := []check{checkAndUpdate, checkAndUpdateWithOrdering} @@ -455,12 +453,12 @@ func (t *testHotPeerCache) TestRemoveFromCacheRandom(c *C) { step := func(i int) { tmp := uint64(0) if i%5 == 0 { - tmp, region = schedule(c, removeReplica, region) + tmp, region = schedule(re, removeReplica, region) } - rets := checker(c, cache, region) - checkIntervalSumContinuous(c, intervalSums, rets, interval) + rets := checker(re, cache, region) + checkIntervalSumContinuous(re, intervalSums, rets, interval) if i%5 == 0 { - _, region = schedule(c, addReplica, region, target) + _, region = schedule(re, addReplica, region, target) target = tmp } } @@ -473,9 +471,9 @@ func (t *testHotPeerCache) TestRemoveFromCacheRandom(c *C) { } } if interval < RegionHeartBeatReportInterval { - c.Assert(checkIntervalSum(cache, region), IsTrue) + re.True(checkIntervalSum(cache, region)) } - c.Assert(cache.storesOfRegion[region.GetID()], HasLen, peerCount) + re.Len(cache.storesOfRegion[region.GetID()], peerCount) // check whether cold cache is cleared var isClear bool @@ -488,119 +486,98 @@ func (t *testHotPeerCache) TestRemoveFromCacheRandom(c *C) { break } } - c.Assert(isClear, IsTrue) + re.True(isClear) } } } } -func checkCoolDown(c *C, cache *hotPeerCache, region *core.RegionInfo, expect bool) { +func checkCoolDown(re *require.Assertions, cache *hotPeerCache, region *core.RegionInfo, expect bool) { item := cache.getOldHotPeerStat(region.GetID(), region.GetLeader().GetStoreId()) - c.Assert(item.IsNeedCoolDownTransferLeader(3), Equals, expect) + re.Equal(expect, item.IsNeedCoolDownTransferLeader(3)) } -func (t *testHotPeerCache) TestCoolDownTransferLeader(c *C) { +func TestCoolDownTransferLeader(t *testing.T) { + re := require.New(t) cache := NewHotPeerCache(Read) region := buildRegion(Read, 3, 60) moveLeader := func() { - _, region = schedule(c, movePeer, region, 10) - checkAndUpdate(c, cache, region) - checkCoolDown(c, cache, region, false) - _, region = schedule(c, transferLeader, region, 10) - checkAndUpdate(c, cache, region) - checkCoolDown(c, cache, region, true) + _, region = schedule(re, movePeer, region, 10) + checkAndUpdate(re, cache, region) + checkCoolDown(re, cache, region, false) + _, region = schedule(re, transferLeader, region, 10) + 
checkAndUpdate(re, cache, region) + checkCoolDown(re, cache, region, true) } transferLeader := func() { - _, region = schedule(c, transferLeader, region) - checkAndUpdate(c, cache, region) - checkCoolDown(c, cache, region, true) + _, region = schedule(re, transferLeader, region) + checkAndUpdate(re, cache, region) + checkCoolDown(re, cache, region, true) } movePeer := func() { - _, region = schedule(c, movePeer, region, 10) - checkAndUpdate(c, cache, region) - checkCoolDown(c, cache, region, false) + _, region = schedule(re, movePeer, region, 10) + checkAndUpdate(re, cache, region) + checkCoolDown(re, cache, region, false) } addReplica := func() { - _, region = schedule(c, addReplica, region, 10) - checkAndUpdate(c, cache, region) - checkCoolDown(c, cache, region, false) + _, region = schedule(re, addReplica, region, 10) + checkAndUpdate(re, cache, region) + checkCoolDown(re, cache, region, false) } removeReplica := func() { - _, region = schedule(c, removeReplica, region, 10) - checkAndUpdate(c, cache, region) - checkCoolDown(c, cache, region, false) + _, region = schedule(re, removeReplica, region, 10) + checkAndUpdate(re, cache, region) + checkCoolDown(re, cache, region, false) } cases := []func(){moveLeader, transferLeader, movePeer, addReplica, removeReplica} for _, runCase := range cases { cache = NewHotPeerCache(Read) region = buildRegion(Read, 3, 60) for i := 1; i <= 200; i++ { - checkAndUpdate(c, cache, region) + checkAndUpdate(re, cache, region) } - checkCoolDown(c, cache, region, false) + checkCoolDown(re, cache, region, false) runCase() } } // See issue #4510 -func (t *testHotPeerCache) TestCacheInherit(c *C) { +func TestCacheInherit(t *testing.T) { + re := require.New(t) cache := NewHotPeerCache(Read) region := buildRegion(Read, 3, 10) // prepare for i := 1; i <= 200; i++ { - checkAndUpdate(c, cache, region) + checkAndUpdate(re, cache, region) } // move peer newStoreID := uint64(10) - _, region = schedule(c, addReplica, region, newStoreID) - checkAndUpdate(c, cache, region) - newStoreID, region = schedule(c, removeReplica, region) - rets := checkAndUpdate(c, cache, region) + _, region = schedule(re, addReplica, region, newStoreID) + checkAndUpdate(re, cache, region) + newStoreID, region = schedule(re, removeReplica, region) + rets := checkAndUpdate(re, cache, region) for _, ret := range rets { if ret.actionType != Remove { flow := ret.GetLoads()[RegionReadBytes] - c.Assert(flow, Equals, float64(region.GetBytesRead()/ReadReportInterval)) + re.Equal(float64(region.GetBytesRead()/ReadReportInterval), flow) } } // new flow newFlow := region.GetBytesRead() * 10 region = region.Clone(core.SetReadBytes(newFlow)) for i := 1; i <= 200; i++ { - checkAndUpdate(c, cache, region) + checkAndUpdate(re, cache, region) } // move peer - _, region = schedule(c, addReplica, region, newStoreID) - checkAndUpdate(c, cache, region) - _, region = schedule(c, removeReplica, region) - rets = checkAndUpdate(c, cache, region) + _, region = schedule(re, addReplica, region, newStoreID) + checkAndUpdate(re, cache, region) + _, region = schedule(re, removeReplica, region) + rets = checkAndUpdate(re, cache, region) for _, ret := range rets { if ret.actionType != Remove { flow := ret.GetLoads()[RegionReadBytes] - c.Assert(flow, Equals, float64(newFlow/ReadReportInterval)) - } - } -} - -func BenchmarkCheckRegionFlow(b *testing.B) { - cache := NewHotPeerCache(Read) - region := buildRegion(Read, 3, 10) - peerInfos := make([]*core.PeerInfo, 0) - for _, peer := range region.GetPeers() { - peerInfo := 
core.NewPeerInfo(peer, region.GetLoads(), 10) - peerInfos = append(peerInfos, peerInfo) - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - items := make([]*HotPeerStat, 0) - for _, peerInfo := range peerInfos { - item := cache.checkPeerFlow(peerInfo, region) - if item != nil { - items = append(items, item) - } - } - for _, ret := range items { - cache.updateStat(ret) + re.Equal(float64(newFlow/ReadReportInterval), flow) } } } @@ -610,7 +587,7 @@ type testMovingAverageCase struct { expect []float64 } -func checkMovingAverage(c *C, testCase *testMovingAverageCase) { +func checkMovingAverage(re *require.Assertions, testCase *testMovingAverageCase) { interval := 1 * time.Second tm := movingaverage.NewTimeMedian(DefaultAotSize, DefaultWriteMfSize, interval) var results []float64 @@ -618,11 +595,11 @@ func checkMovingAverage(c *C, testCase *testMovingAverageCase) { tm.Add(data, interval) results = append(results, tm.Get()) } - c.Assert(results, DeepEquals, testCase.expect) + re.Equal(testCase.expect, results) } -// -func (t *testHotPeerCache) TestUnstableData(c *C) { +func TestUnstableData(t *testing.T) { + re := require.New(t) cases := []*testMovingAverageCase{ { report: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, @@ -650,6 +627,29 @@ func (t *testHotPeerCache) TestUnstableData(c *C) { }, } for i := range cases { - checkMovingAverage(c, cases[i]) + checkMovingAverage(re, cases[i]) + } +} + +func BenchmarkCheckRegionFlow(b *testing.B) { + cache := NewHotPeerCache(Read) + region := buildRegion(Read, 3, 10) + peerInfos := make([]*core.PeerInfo, 0) + for _, peer := range region.GetPeers() { + peerInfo := core.NewPeerInfo(peer, region.GetLoads(), 10) + peerInfos = append(peerInfos, peerInfo) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + items := make([]*HotPeerStat, 0) + for _, peerInfo := range peerInfos { + item := cache.checkPeerFlow(peerInfo, region) + if item != nil { + items = append(items, item) + } + } + for _, ret := range items { + cache.updateStat(ret) + } } } diff --git a/server/statistics/kind_test.go b/server/statistics/kind_test.go index 86e9a77e10b..ccde182eefe 100644 --- a/server/statistics/kind_test.go +++ b/server/statistics/kind_test.go @@ -15,17 +15,16 @@ package statistics import ( - . 
"github.com/pingcap/check" + "testing" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testRegionInfoSuite{}) - -type testRegionInfoSuite struct{} - -func (s *testRegionInfoSuite) TestGetLoads(c *C) { +func TestGetLoads(t *testing.T) { + re := require.New(t) queryStats := &pdpb.QueryStats{ Get: 5, Coprocessor: 6, @@ -45,24 +44,24 @@ func (s *testRegionInfoSuite) TestGetLoads(c *C) { core.SetWrittenKeys(4), core.SetQueryStats(queryStats)) loads := regionA.GetLoads() - c.Assert(loads, HasLen, int(RegionStatCount)) - c.Assert(float64(regionA.GetBytesRead()), Equals, loads[RegionReadBytes]) - c.Assert(float64(regionA.GetKeysRead()), Equals, loads[RegionReadKeys]) - c.Assert(float64(regionA.GetReadQueryNum()), Equals, loads[RegionReadQuery]) + re.Len(loads, int(RegionStatCount)) + re.Equal(float64(regionA.GetBytesRead()), loads[RegionReadBytes]) + re.Equal(float64(regionA.GetKeysRead()), loads[RegionReadKeys]) + re.Equal(float64(regionA.GetReadQueryNum()), loads[RegionReadQuery]) readQuery := float64(queryStats.Coprocessor + queryStats.Get + queryStats.Scan) - c.Assert(float64(regionA.GetReadQueryNum()), Equals, readQuery) - c.Assert(float64(regionA.GetBytesWritten()), Equals, loads[RegionWriteBytes]) - c.Assert(float64(regionA.GetKeysWritten()), Equals, loads[RegionWriteKeys]) - c.Assert(float64(regionA.GetWriteQueryNum()), Equals, loads[RegionWriteQuery]) + re.Equal(float64(regionA.GetReadQueryNum()), readQuery) + re.Equal(float64(regionA.GetBytesWritten()), loads[RegionWriteBytes]) + re.Equal(float64(regionA.GetKeysWritten()), loads[RegionWriteKeys]) + re.Equal(float64(regionA.GetWriteQueryNum()), loads[RegionWriteQuery]) writeQuery := float64(queryStats.Put + queryStats.Delete + queryStats.DeleteRange + queryStats.AcquirePessimisticLock + queryStats.Rollback + queryStats.Prewrite + queryStats.Commit) - c.Assert(float64(regionA.GetWriteQueryNum()), Equals, writeQuery) + re.Equal(float64(regionA.GetWriteQueryNum()), writeQuery) loads = regionA.GetWriteLoads() - c.Assert(loads, HasLen, int(RegionStatCount)) - c.Assert(0.0, Equals, loads[RegionReadBytes]) - c.Assert(0.0, Equals, loads[RegionReadKeys]) - c.Assert(0.0, Equals, loads[RegionReadQuery]) - c.Assert(float64(regionA.GetBytesWritten()), Equals, loads[RegionWriteBytes]) - c.Assert(float64(regionA.GetKeysWritten()), Equals, loads[RegionWriteKeys]) - c.Assert(float64(regionA.GetWriteQueryNum()), Equals, loads[RegionWriteQuery]) + re.Len(loads, int(RegionStatCount)) + re.Equal(0.0, loads[RegionReadBytes]) + re.Equal(0.0, loads[RegionReadKeys]) + re.Equal(0.0, loads[RegionReadQuery]) + re.Equal(float64(regionA.GetBytesWritten()), loads[RegionWriteBytes]) + re.Equal(float64(regionA.GetKeysWritten()), loads[RegionWriteKeys]) + re.Equal(float64(regionA.GetWriteQueryNum()), loads[RegionWriteQuery]) } diff --git a/server/statistics/region_collection_test.go b/server/statistics/region_collection_test.go index eb100e958fd..932c35f139e 100644 --- a/server/statistics/region_collection_test.go +++ b/server/statistics/region_collection_test.go @@ -17,36 +17,21 @@ package statistics import ( "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule/placement" "github.com/tikv/pd/server/storage" - "github.com/tikv/pd/server/storage/endpoint" ) -func TestStatistics(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRegionStatisticsSuite{}) - -type testRegionStatisticsSuite struct { - store endpoint.RuleStorage - manager *placement.RuleManager -} - -func (t *testRegionStatisticsSuite) SetUpTest(c *C) { - t.store = storage.NewStorageWithMemoryBackend() - var err error - t.manager = placement.NewRuleManager(t.store, nil, nil) - err = t.manager.Initialize(3, []string{"zone", "rack", "host"}) - c.Assert(err, IsNil) -} - -func (t *testRegionStatisticsSuite) TestRegionStatistics(c *C) { +func TestRegionStatistics(t *testing.T) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() + manager := placement.NewRuleManager(store, nil, nil) + err := manager.Initialize(3, []string{"zone", "rack", "host"}) + re.NoError(err) opt := config.NewTestOptions() opt.SetPlacementRuleEnabled(false) peers := []*metapb.Peer{ @@ -80,14 +65,14 @@ func (t *testRegionStatisticsSuite) TestRegionStatistics(c *C) { r2 := &metapb.Region{Id: 2, Peers: peers[0:2], StartKey: []byte("cc"), EndKey: []byte("dd")} region1 := core.NewRegionInfo(r1, peers[0]) region2 := core.NewRegionInfo(r2, peers[0]) - regionStats := NewRegionStatistics(opt, t.manager, nil) + regionStats := NewRegionStatistics(opt, manager, nil) regionStats.Observe(region1, stores) - c.Assert(regionStats.stats[ExtraPeer], HasLen, 1) - c.Assert(regionStats.stats[LearnerPeer], HasLen, 1) - c.Assert(regionStats.stats[EmptyRegion], HasLen, 1) - c.Assert(regionStats.stats[UndersizedRegion], HasLen, 1) - c.Assert(regionStats.offlineStats[ExtraPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[LearnerPeer], HasLen, 1) + re.Len(regionStats.stats[ExtraPeer], 1) + re.Len(regionStats.stats[LearnerPeer], 1) + re.Len(regionStats.stats[EmptyRegion], 1) + re.Len(regionStats.stats[UndersizedRegion], 1) + re.Len(regionStats.offlineStats[ExtraPeer], 1) + re.Len(regionStats.offlineStats[LearnerPeer], 1) region1 = region1.Clone( core.WithDownPeers(downPeers), @@ -95,58 +80,63 @@ func (t *testRegionStatisticsSuite) TestRegionStatistics(c *C) { core.SetApproximateSize(144), ) regionStats.Observe(region1, stores) - c.Assert(regionStats.stats[ExtraPeer], HasLen, 1) - c.Assert(regionStats.stats[MissPeer], HasLen, 0) - c.Assert(regionStats.stats[DownPeer], HasLen, 1) - c.Assert(regionStats.stats[PendingPeer], HasLen, 1) - c.Assert(regionStats.stats[LearnerPeer], HasLen, 1) - c.Assert(regionStats.stats[EmptyRegion], HasLen, 0) - c.Assert(regionStats.stats[OversizedRegion], HasLen, 1) - c.Assert(regionStats.stats[UndersizedRegion], HasLen, 0) - c.Assert(regionStats.offlineStats[ExtraPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[MissPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[DownPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[PendingPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[LearnerPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[OfflinePeer], HasLen, 1) + re.Len(regionStats.stats[ExtraPeer], 1) + re.Len(regionStats.stats[MissPeer], 0) + re.Len(regionStats.stats[DownPeer], 1) + re.Len(regionStats.stats[PendingPeer], 1) + re.Len(regionStats.stats[LearnerPeer], 1) + re.Len(regionStats.stats[EmptyRegion], 0) + 
re.Len(regionStats.stats[OversizedRegion], 1) + re.Len(regionStats.stats[UndersizedRegion], 0) + re.Len(regionStats.offlineStats[ExtraPeer], 1) + re.Len(regionStats.offlineStats[MissPeer], 0) + re.Len(regionStats.offlineStats[DownPeer], 1) + re.Len(regionStats.offlineStats[PendingPeer], 1) + re.Len(regionStats.offlineStats[LearnerPeer], 1) + re.Len(regionStats.offlineStats[OfflinePeer], 1) region2 = region2.Clone(core.WithDownPeers(downPeers[0:1])) regionStats.Observe(region2, stores[0:2]) - c.Assert(regionStats.stats[ExtraPeer], HasLen, 1) - c.Assert(regionStats.stats[MissPeer], HasLen, 1) - c.Assert(regionStats.stats[DownPeer], HasLen, 2) - c.Assert(regionStats.stats[PendingPeer], HasLen, 1) - c.Assert(regionStats.stats[LearnerPeer], HasLen, 1) - c.Assert(regionStats.stats[OversizedRegion], HasLen, 1) - c.Assert(regionStats.stats[UndersizedRegion], HasLen, 1) - c.Assert(regionStats.offlineStats[ExtraPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[MissPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[DownPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[PendingPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[LearnerPeer], HasLen, 1) - c.Assert(regionStats.offlineStats[OfflinePeer], HasLen, 1) + re.Len(regionStats.stats[ExtraPeer], 1) + re.Len(regionStats.stats[MissPeer], 1) + re.Len(regionStats.stats[DownPeer], 2) + re.Len(regionStats.stats[PendingPeer], 1) + re.Len(regionStats.stats[LearnerPeer], 1) + re.Len(regionStats.stats[OversizedRegion], 1) + re.Len(regionStats.stats[UndersizedRegion], 1) + re.Len(regionStats.offlineStats[ExtraPeer], 1) + re.Len(regionStats.offlineStats[MissPeer], 0) + re.Len(regionStats.offlineStats[DownPeer], 1) + re.Len(regionStats.offlineStats[PendingPeer], 1) + re.Len(regionStats.offlineStats[LearnerPeer], 1) + re.Len(regionStats.offlineStats[OfflinePeer], 1) region1 = region1.Clone(core.WithRemoveStorePeer(7)) regionStats.Observe(region1, stores[0:3]) - c.Assert(regionStats.stats[ExtraPeer], HasLen, 0) - c.Assert(regionStats.stats[MissPeer], HasLen, 1) - c.Assert(regionStats.stats[DownPeer], HasLen, 2) - c.Assert(regionStats.stats[PendingPeer], HasLen, 1) - c.Assert(regionStats.stats[LearnerPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[ExtraPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[MissPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[DownPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[PendingPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[LearnerPeer], HasLen, 0) - c.Assert(regionStats.offlineStats[OfflinePeer], HasLen, 0) + re.Len(regionStats.stats[ExtraPeer], 0) + re.Len(regionStats.stats[MissPeer], 1) + re.Len(regionStats.stats[DownPeer], 2) + re.Len(regionStats.stats[PendingPeer], 1) + re.Len(regionStats.stats[LearnerPeer], 0) + re.Len(regionStats.offlineStats[ExtraPeer], 0) + re.Len(regionStats.offlineStats[MissPeer], 0) + re.Len(regionStats.offlineStats[DownPeer], 0) + re.Len(regionStats.offlineStats[PendingPeer], 0) + re.Len(regionStats.offlineStats[LearnerPeer], 0) + re.Len(regionStats.offlineStats[OfflinePeer], 0) store3 = stores[3].Clone(core.UpStore()) stores[3] = store3 regionStats.Observe(region1, stores) - c.Assert(regionStats.stats[OfflinePeer], HasLen, 0) + re.Len(regionStats.stats[OfflinePeer], 0) } -func (t *testRegionStatisticsSuite) TestRegionStatisticsWithPlacementRule(c *C) { +func TestRegionStatisticsWithPlacementRule(t *testing.T) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() + manager := placement.NewRuleManager(store, nil, nil) + err := 
manager.Initialize(3, []string{"zone", "rack", "host"}) + re.NoError(err) opt := config.NewTestOptions() opt.SetPlacementRuleEnabled(true) peers := []*metapb.Peer{ @@ -173,20 +163,21 @@ func (t *testRegionStatisticsSuite) TestRegionStatisticsWithPlacementRule(c *C) region2 := core.NewRegionInfo(r2, peers[0]) region3 := core.NewRegionInfo(r3, peers[0]) region4 := core.NewRegionInfo(r4, peers[0]) - regionStats := NewRegionStatistics(opt, t.manager, nil) + regionStats := NewRegionStatistics(opt, manager, nil) // r2 didn't match the rules regionStats.Observe(region2, stores) - c.Assert(regionStats.stats[MissPeer], HasLen, 1) + re.Len(regionStats.stats[MissPeer], 1) regionStats.Observe(region3, stores) // r3 didn't match the rules - c.Assert(regionStats.stats[ExtraPeer], HasLen, 1) + re.Len(regionStats.stats[ExtraPeer], 1) regionStats.Observe(region4, stores) // r4 match the rules - c.Assert(regionStats.stats[MissPeer], HasLen, 1) - c.Assert(regionStats.stats[ExtraPeer], HasLen, 1) + re.Len(regionStats.stats[MissPeer], 1) + re.Len(regionStats.stats[ExtraPeer], 1) } -func (t *testRegionStatisticsSuite) TestRegionLabelIsolationLevel(c *C) { +func TestRegionLabelIsolationLevel(t *testing.T) { + re := require.New(t) locationLabels := []string{"zone", "rack", "host"} labelLevelStats := NewLabelStatistics() labelsSet := [][]map[string]string{ @@ -256,7 +247,7 @@ func (t *testRegionStatisticsSuite) TestRegionLabelIsolationLevel(c *C) { region := core.NewRegionInfo(&metapb.Region{Id: uint64(regionID)}, nil) label := GetRegionLabelIsolation(stores, locationLabels) labelLevelStats.Observe(region, stores, locationLabels) - c.Assert(label, Equals, res) + re.Equal(res, label) regionID++ } @@ -264,16 +255,16 @@ func (t *testRegionStatisticsSuite) TestRegionLabelIsolationLevel(c *C) { f(labels, res[i], locationLabels) } for i, res := range counter { - c.Assert(labelLevelStats.labelCounter[i], Equals, res) + re.Equal(res, labelLevelStats.labelCounter[i]) } label := GetRegionLabelIsolation(nil, locationLabels) - c.Assert(label, Equals, nonIsolation) + re.Equal(nonIsolation, label) label = GetRegionLabelIsolation(nil, nil) - c.Assert(label, Equals, nonIsolation) + re.Equal(nonIsolation, label) store := core.NewStoreInfo(&metapb.Store{Id: 1, Address: "mock://tikv-1"}, core.SetStoreLabels([]*metapb.StoreLabel{{Key: "foo", Value: "bar"}})) label = GetRegionLabelIsolation([]*core.StoreInfo{store}, locationLabels) - c.Assert(label, Equals, "zone") + re.Equal("zone", label) regionID = 1 res = []string{"rack", "none", "zone", "rack", "none", "rack", "none"} @@ -284,6 +275,6 @@ func (t *testRegionStatisticsSuite) TestRegionLabelIsolationLevel(c *C) { f(labels, res[i], locationLabels) } for i, res := range counter { - c.Assert(labelLevelStats.labelCounter[i], Equals, res) + re.Equal(res, labelLevelStats.labelCounter[i]) } } diff --git a/server/statistics/store_collection_test.go b/server/statistics/store_collection_test.go index f93a4b54bb5..388fc13b27e 100644 --- a/server/statistics/store_collection_test.go +++ b/server/statistics/store_collection_test.go @@ -15,19 +15,17 @@ package statistics import ( + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testStoreStatisticsSuite{}) - -type testStoreStatisticsSuite struct{} - -func (t *testStoreStatisticsSuite) TestStoreStatistics(c *C) { +func TestStoreStatistics(t *testing.T) { + re := require.New(t) opt := config.NewTestOptions() rep := opt.GetReplicationConfig().Clone() rep.LocationLabels = []string{"zone", "host"} @@ -62,22 +60,22 @@ func (t *testStoreStatisticsSuite) TestStoreStatistics(c *C) { } stats := storeStats.stats - c.Assert(stats.Up, Equals, 6) - c.Assert(stats.Preparing, Equals, 7) - c.Assert(stats.Serving, Equals, 0) - c.Assert(stats.Removing, Equals, 1) - c.Assert(stats.Removed, Equals, 1) - c.Assert(stats.Down, Equals, 1) - c.Assert(stats.Offline, Equals, 1) - c.Assert(stats.RegionCount, Equals, 0) - c.Assert(stats.Unhealthy, Equals, 0) - c.Assert(stats.Disconnect, Equals, 0) - c.Assert(stats.Tombstone, Equals, 1) - c.Assert(stats.LowSpace, Equals, 8) - c.Assert(stats.LabelCounter["zone:z1"], Equals, 2) - c.Assert(stats.LabelCounter["zone:z2"], Equals, 2) - c.Assert(stats.LabelCounter["zone:z3"], Equals, 2) - c.Assert(stats.LabelCounter["host:h1"], Equals, 4) - c.Assert(stats.LabelCounter["host:h2"], Equals, 4) - c.Assert(stats.LabelCounter["zone:unknown"], Equals, 2) + re.Equal(6, stats.Up) + re.Equal(7, stats.Preparing) + re.Equal(0, stats.Serving) + re.Equal(1, stats.Removing) + re.Equal(1, stats.Removed) + re.Equal(1, stats.Down) + re.Equal(1, stats.Offline) + re.Equal(0, stats.RegionCount) + re.Equal(0, stats.Unhealthy) + re.Equal(0, stats.Disconnect) + re.Equal(1, stats.Tombstone) + re.Equal(8, stats.LowSpace) + re.Equal(2, stats.LabelCounter["zone:z1"]) + re.Equal(2, stats.LabelCounter["zone:z2"]) + re.Equal(2, stats.LabelCounter["zone:z3"]) + re.Equal(4, stats.LabelCounter["host:h1"]) + re.Equal(4, stats.LabelCounter["host:h2"]) + re.Equal(2, stats.LabelCounter["zone:unknown"]) } diff --git a/server/statistics/store_test.go b/server/statistics/store_test.go index e3247ea1c46..89508be41b7 100644 --- a/server/statistics/store_test.go +++ b/server/statistics/store_test.go @@ -15,26 +15,24 @@ package statistics import ( + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testStoreSuite{}) - -type testStoreSuite struct{} - -func (s *testStoreSuite) TestFilterUnhealtyStore(c *C) { +func TestFilterUnhealtyStore(t *testing.T) { + re := require.New(t) stats := NewStoresStats() cluster := core.NewBasicCluster() for i := uint64(1); i <= 5; i++ { cluster.PutStore(core.NewStoreInfo(&metapb.Store{Id: i}, core.SetLastHeartbeatTS(time.Now()))) stats.Observe(i, &pdpb.StoreStats{}) } - c.Assert(stats.GetStoresLoads(), HasLen, 5) + re.Len(stats.GetStoresLoads(), 5) cluster.PutStore(cluster.GetStore(1).Clone(core.SetLastHeartbeatTS(time.Now().Add(-24 * time.Hour)))) cluster.PutStore(cluster.GetStore(2).Clone(core.TombstoneStore())) @@ -42,7 +40,7 @@ func (s *testStoreSuite) TestFilterUnhealtyStore(c *C) { stats.FilterUnhealthyStore(cluster) loads := stats.GetStoresLoads() - c.Assert(loads, HasLen, 2) - c.Assert(loads[4], NotNil) - c.Assert(loads[5], NotNil) + re.Len(loads, 2) + re.NotNil(loads[4]) + re.NotNil(loads[5]) } diff --git a/server/statistics/topn_test.go b/server/statistics/topn_test.go index 0bf1c4e4d21..fa9e4ebd5f1 100644 --- a/server/statistics/topn_test.go +++ b/server/statistics/topn_test.go @@ -17,15 +17,12 @@ package statistics import ( "math/rand" "sort" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testTopNSuite{}) - -type testTopNSuite struct{} - type item struct { id uint64 values []float64 @@ -39,21 +36,22 @@ func (it *item) Less(k int, than TopNItem) bool { return it.values[k] < than.(*item).values[k] } -func (s *testTopNSuite) TestPut(c *C) { +func TestPut(t *testing.T) { + re := require.New(t) const Total, N = 10000, 50 tn := NewTopN(DimLen, N, 1*time.Hour) - putPerm(c, tn, Total, func(x int) float64 { + putPerm(re, tn, Total, func(x int) float64 { return float64(-x) + 1 }, false /*insert*/) - putPerm(c, tn, Total, func(x int) float64 { + putPerm(re, tn, Total, func(x int) float64 { return float64(-x) }, true /*update*/) // check GetTopNMin for k := 0; k < DimLen; k++ { - c.Assert(tn.GetTopNMin(k).(*item).values[k], Equals, float64(1-N)) + re.Equal(float64(1-N), tn.GetTopNMin(k).(*item).values[k]) } { @@ -65,7 +63,7 @@ func (s *testTopNSuite) TestPut(c *C) { } // check update worked for i, v := range topns { - c.Assert(v, Equals, float64(-i)) + re.Equal(float64(-i), v) } } @@ -78,7 +76,7 @@ func (s *testTopNSuite) TestPut(c *C) { } // check update worked for i, v := range all { - c.Assert(v, Equals, float64(-i)) + re.Equal(float64(-i), v) } } @@ -96,19 +94,19 @@ func (s *testTopNSuite) TestPut(c *C) { } sort.Sort(sort.Reverse(sort.Float64Slice(all))) - c.Assert(topn, DeepEquals, all[:N]) + re.Equal(all[:N], topn) } } // check Get for i := uint64(0); i < Total; i++ { it := tn.Get(i).(*item) - c.Assert(it.id, Equals, i) - c.Assert(it.values[0], Equals, -float64(i)) + re.Equal(i, it.id) + re.Equal(-float64(i), it.values[0]) } } -func putPerm(c *C, tn *TopN, total int, f func(x int) float64, isUpdate bool) { +func putPerm(re *require.Assertions, tn *TopN, total int, f func(x int) float64, isUpdate bool) { { // insert dims := make([][]int, DimLen) for k := 0; k < DimLen; k++ { @@ -122,16 +120,17 @@ func putPerm(c *C, tn *TopN, total int, f func(x int) float64, isUpdate bool) { for k := 0; k < DimLen; k++ { item.values[k] = f(dims[k][i]) } - c.Assert(tn.Put(item), Equals, isUpdate) + 
re.Equal(isUpdate, tn.Put(item)) } } } -func (s *testTopNSuite) TestRemove(c *C) { +func TestRemove(t *testing.T) { + re := require.New(t) const Total, N = 10000, 50 tn := NewTopN(DimLen, N, 1*time.Hour) - putPerm(c, tn, Total, func(x int) float64 { + putPerm(re, tn, Total, func(x int) float64 { return float64(-x) }, false /*insert*/) @@ -139,28 +138,28 @@ func (s *testTopNSuite) TestRemove(c *C) { for i := 0; i < Total; i++ { if i%3 != 0 { it := tn.Remove(uint64(i)).(*item) - c.Assert(it.id, Equals, uint64(i)) + re.Equal(uint64(i), it.id) } } // check Remove worked for i := 0; i < Total; i++ { if i%3 != 0 { - c.Assert(tn.Remove(uint64(i)), IsNil) + re.Nil(tn.Remove(uint64(i))) } } - c.Assert(tn.GetTopNMin(0).(*item).id, Equals, uint64(3*(N-1))) + re.Equal(uint64(3*(N-1)), tn.GetTopNMin(0).(*item).id) { topns := make([]float64, N) for _, it := range tn.GetAllTopN(0) { it := it.(*item) topns[it.id/3] = it.values[0] - c.Assert(it.id%3, Equals, uint64(0)) + re.Equal(uint64(0), it.id%3) } for i, v := range topns { - c.Assert(v, Equals, float64(-i*3)) + re.Equal(float64(-i*3), v) } } @@ -169,10 +168,10 @@ func (s *testTopNSuite) TestRemove(c *C) { for _, it := range tn.GetAll() { it := it.(*item) all[it.id/3] = it.values[0] - c.Assert(it.id%3, Equals, uint64(0)) + re.Equal(uint64(0), it.id%3) } for i, v := range all { - c.Assert(v, Equals, float64(-i*3)) + re.Equal(float64(-i*3), v) } } @@ -190,22 +189,23 @@ func (s *testTopNSuite) TestRemove(c *C) { } sort.Sort(sort.Reverse(sort.Float64Slice(all))) - c.Assert(topn, DeepEquals, all[:N]) + re.Equal(all[:N], topn) } } for i := uint64(0); i < Total; i += 3 { it := tn.Get(i).(*item) - c.Assert(it.id, Equals, i) - c.Assert(it.values[0], Equals, -float64(i)) + re.Equal(i, it.id) + re.Equal(-float64(i), it.values[0]) } } -func (s *testTopNSuite) TestTTL(c *C) { +func TestTTL(t *testing.T) { + re := require.New(t) const Total, N = 1000, 50 tn := NewTopN(DimLen, 50, 900*time.Millisecond) - putPerm(c, tn, Total, func(x int) float64 { + putPerm(re, tn, Total, func(x int) float64 { return float64(-x) }, false /*insert*/) @@ -215,27 +215,27 @@ func (s *testTopNSuite) TestTTL(c *C) { for k := 1; k < DimLen; k++ { item.values = append(item.values, rand.NormFloat64()) } - c.Assert(tn.Put(item), IsTrue) + re.True(tn.Put(item)) } for i := 3; i < Total; i += 3 { item := &item{id: uint64(i), values: []float64{float64(-i) + 100}} for k := 1; k < DimLen; k++ { item.values = append(item.values, rand.NormFloat64()) } - c.Assert(tn.Put(item), IsFalse) + re.False(tn.Put(item)) } tn.RemoveExpired() - c.Assert(tn.Len(), Equals, Total/3+1) + re.Equal(Total/3+1, tn.Len()) items := tn.GetAllTopN(0) v := make([]float64, N) for _, it := range items { it := it.(*item) - c.Assert(it.id%3, Equals, uint64(0)) + re.Equal(uint64(0), it.id%3) v[it.id/3] = it.values[0] } for i, x := range v { - c.Assert(x, Equals, float64(-i*3)+100) + re.Equal(float64(-i*3)+100, x) } { // check all dimensions @@ -252,7 +252,7 @@ func (s *testTopNSuite) TestTTL(c *C) { } sort.Sort(sort.Reverse(sort.Float64Slice(all))) - c.Assert(topn, DeepEquals, all[:N]) + re.Equal(all[:N], topn) } } } From 5a64486d5ff9fbe860d65c3377f30a6f319b04a1 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Mon, 13 Jun 2022 15:34:32 +0800 Subject: [PATCH 42/82] server: add limiter config and reload mechanism (#4842) ref tikv/pd#4666, ref tikv/pd#4839, ref tikv/pd#4869 update limiter config when reload presist config Signed-off-by: Cabinfever_B Co-authored-by: Ryan Leung Co-authored-by: Ti Chi Robot --- 
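Editorial note (not part of the patch): the sketch below only illustrates how the rate-limit settings introduced by this patch could be updated through the new service-middleware config endpoint. The payload shape, the "GetRegions" service label, the DimensionConfig{QPS: 1} value, and the /pd/api/v1/service-middleware/config path are all taken from the test added in tests/server/config/config_test.go below; the PD leader address 127.0.0.1:2379 and the minimal error handling are assumptions for illustration only.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/tikv/pd/pkg/ratelimit"
)

func main() {
	// Same payload shape as TestRateLimitConfigReload in this patch:
	// switch the middleware on and limit the "GetRegions" service label to 1 QPS.
	payload := map[string]interface{}{
		"enable-rate-limit": "true",
		"limiter-config": map[string]ratelimit.DimensionConfig{
			"GetRegions": {QPS: 1},
		},
	}
	data, err := json.Marshal(payload)
	if err != nil {
		panic(err)
	}
	// 127.0.0.1:2379 is an assumed PD leader address; adjust for a real deployment.
	resp, err := http.Post("http://127.0.0.1:2379/pd/api/v1/service-middleware/config",
		"application/json", bytes.NewBuffer(data))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}

Per the reload test in this patch, the leader persists this config, and a newly elected leader reloads it via loadRateLimitConfig, so the limiter settings survive leader changes.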
server/api/service_middleware.go | 26 +++- server/api/service_middleware_test.go | 111 +++++++++++++++++- server/config/service_middleware_config.go | 29 ++++- .../service_middleware_persist_options.go | 23 +++- server/server.go | 41 ++++++- tests/server/config/config_test.go | 105 +++++++++++++++++ 6 files changed, 322 insertions(+), 13 deletions(-) create mode 100644 tests/server/config/config_test.go diff --git a/server/api/service_middleware.go b/server/api/service_middleware.go index c136f8fbf4e..0f41f8ae725 100644 --- a/server/api/service_middleware.go +++ b/server/api/service_middleware.go @@ -103,8 +103,11 @@ func (h *serviceMiddlewareHandler) SetServiceMiddlewareConfig(w http.ResponseWri func (h *serviceMiddlewareHandler) updateServiceMiddlewareConfig(cfg *config.ServiceMiddlewareConfig, key string, value interface{}) error { kp := strings.Split(key, ".") - if kp[0] == "audit" { + switch kp[0] { + case "audit": return h.updateAudit(cfg, kp[len(kp)-1], value) + case "rate-limit": + return h.updateRateLimit(cfg, kp[len(kp)-1], value) } return errors.Errorf("config prefix %s not found", kp[0]) } @@ -129,3 +132,24 @@ func (h *serviceMiddlewareHandler) updateAudit(config *config.ServiceMiddlewareC } return err } + +func (h *serviceMiddlewareHandler) updateRateLimit(config *config.ServiceMiddlewareConfig, key string, value interface{}) error { + data, err := json.Marshal(map[string]interface{}{key: value}) + if err != nil { + return err + } + + updated, found, err := mergeConfig(&config.RateLimitConfig, data) + if err != nil { + return err + } + + if !found { + return errors.Errorf("config item %s not found", key) + } + + if updated { + err = h.svr.SetRateLimitConfig(config.RateLimitConfig) + } + return err +} diff --git a/server/api/service_middleware_test.go b/server/api/service_middleware_test.go index 3d29b23a693..a1d4804650c 100644 --- a/server/api/service_middleware_test.go +++ b/server/api/service_middleware_test.go @@ -21,20 +21,21 @@ import ( . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/tikv/pd/pkg/ratelimit" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testServiceMiddlewareSuite{}) +var _ = Suite(&testAuditMiddlewareSuite{}) -type testServiceMiddlewareSuite struct { +type testAuditMiddlewareSuite struct { svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testServiceMiddlewareSuite) SetUpSuite(c *C) { +func (s *testAuditMiddlewareSuite) SetUpSuite(c *C) { s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) @@ -44,19 +45,24 @@ func (s *testServiceMiddlewareSuite) SetUpSuite(c *C) { s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) } -func (s *testServiceMiddlewareSuite) TearDownSuite(c *C) { +func (s *testAuditMiddlewareSuite) TearDownSuite(c *C) { s.cleanup() } -func (s *testServiceMiddlewareSuite) TestConfigAudit(c *C) { +func (s *testAuditMiddlewareSuite) TestConfigAuditSwitch(c *C) { addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) + + sc := &config.ServiceMiddlewareConfig{} + c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) + c.Assert(sc.EnableAudit, Equals, false) + ms := map[string]interface{}{ "enable-audit": "true", } postData, err := json.Marshal(ms) c.Assert(err, IsNil) c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) - sc := &config.ServiceMiddlewareConfig{} + sc = &config.ServiceMiddlewareConfig{} c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) c.Assert(sc.EnableAudit, Equals, true) ms = map[string]interface{}{ @@ -98,3 +104,96 @@ func (s *testServiceMiddlewareSuite) TestConfigAudit(c *C) { c.Assert(err, IsNil) c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item audit not found")), IsNil) } + +var _ = Suite(&testRateLimitConfigSuite{}) + +type testRateLimitConfigSuite struct { + svr *server.Server + cleanup cleanUpFunc + urlPrefix string +} + +func (s *testRateLimitConfigSuite) SetUpSuite(c *C) { + s.svr, s.cleanup = mustNewServer(c) + mustWaitLeader(c, []*server.Server{s.svr}) + mustBootstrapCluster(c, s.svr) + s.urlPrefix = fmt.Sprintf("%s%s/api/v1", s.svr.GetAddr(), apiPrefix) +} + +func (s *testRateLimitConfigSuite) TearDownSuite(c *C) { + s.cleanup() +} + +func (s *testRateLimitConfigSuite) TestConfigRateLimitSwitch(c *C) { + addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) + + sc := &config.ServiceMiddlewareConfig{} + c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) + c.Assert(sc.EnableRateLimit, Equals, false) + + ms := map[string]interface{}{ + "enable-rate-limit": "true", + } + postData, err := json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + sc = &config.ServiceMiddlewareConfig{} + c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) + c.Assert(sc.EnableRateLimit, Equals, true) + ms = map[string]interface{}{ + "enable-rate-limit": "false", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + sc = &config.ServiceMiddlewareConfig{} + c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) + c.Assert(sc.EnableRateLimit, Equals, false) + + // test empty + ms = map[string]interface{}{} + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + 
c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c), tu.StringContain(c, "The input is empty.")), IsNil) + + ms = map[string]interface{}{ + "rate-limit": "false", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item rate-limit not found")), IsNil) + + c.Assert(failpoint.Enable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail", "return(true)"), IsNil) + ms = map[string]interface{}{ + "rate-limit.enable-rate-limit": "true", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest)), IsNil) + c.Assert(failpoint.Disable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail"), IsNil) + + ms = map[string]interface{}{ + "rate-limit.rate-limit": "false", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item rate-limit not found")), IsNil) +} + +func (s *testRateLimitConfigSuite) TestConfigLimiterConifgByOriginAPI(c *C) { + // this test case is used to test updating `limiter-config` by origin API simply + addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) + dimensionConfig := ratelimit.DimensionConfig{QPS: 1} + limiterConfig := map[string]interface{}{ + "CreateOperator": dimensionConfig, + } + ms := map[string]interface{}{ + "limiter-config": limiterConfig, + } + postData, err := json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + sc := &config.ServiceMiddlewareConfig{} + c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) + c.Assert(sc.RateLimitConfig.LimiterConfig["CreateOperator"].QPS, Equals, 1.) +} diff --git a/server/config/service_middleware_config.go b/server/config/service_middleware_config.go index d1b600ccaf2..38f51fce3fd 100644 --- a/server/config/service_middleware_config.go +++ b/server/config/service_middleware_config.go @@ -14,13 +14,17 @@ package config +import "github.com/tikv/pd/pkg/ratelimit" + const ( - defaultEnableAuditMiddleware = false + defaultEnableAuditMiddleware = false + defaultEnableRateLimitMiddleware = false ) // ServiceMiddlewareConfig is is the configuration for PD Service middleware. 
type ServiceMiddlewareConfig struct { - AuditConfig `json:"audit"` + AuditConfig `json:"audit"` + RateLimitConfig `json:"rate-limit"` } // NewServiceMiddlewareConfig returns a new service middleware config @@ -28,8 +32,13 @@ func NewServiceMiddlewareConfig() *ServiceMiddlewareConfig { audit := AuditConfig{ EnableAudit: defaultEnableAuditMiddleware, } + ratelimit := RateLimitConfig{ + EnableRateLimit: defaultEnableRateLimitMiddleware, + LimiterConfig: make(map[string]ratelimit.DimensionConfig), + } cfg := &ServiceMiddlewareConfig{ - AuditConfig: audit, + AuditConfig: audit, + RateLimitConfig: ratelimit, } return cfg } @@ -51,3 +60,17 @@ func (c *AuditConfig) Clone() *AuditConfig { cfg := *c return &cfg } + +// RateLimitConfig is the configuration for rate limit +type RateLimitConfig struct { + // EnableRateLimit controls the switch of the rate limit middleware + EnableRateLimit bool `json:"enable-rate-limit,string"` + // RateLimitConfig is the config of rate limit middleware + LimiterConfig map[string]ratelimit.DimensionConfig `json:"limiter-config"` +} + +// Clone returns a cloned rate limit config. +func (c *RateLimitConfig) Clone() *RateLimitConfig { + cfg := *c + return &cfg +} diff --git a/server/config/service_middleware_persist_options.go b/server/config/service_middleware_persist_options.go index 7fde025b8c1..20f8c110a5f 100644 --- a/server/config/service_middleware_persist_options.go +++ b/server/config/service_middleware_persist_options.go @@ -25,13 +25,15 @@ import ( // ServiceMiddlewarePersistOptions wraps all service middleware configurations that need to persist to storage and // allows to access them safely. type ServiceMiddlewarePersistOptions struct { - audit atomic.Value + audit atomic.Value + rateLimit atomic.Value } // NewServiceMiddlewarePersistOptions creates a new ServiceMiddlewarePersistOptions instance. func NewServiceMiddlewarePersistOptions(cfg *ServiceMiddlewareConfig) *ServiceMiddlewarePersistOptions { o := &ServiceMiddlewarePersistOptions{} o.audit.Store(&cfg.AuditConfig) + o.rateLimit.Store(&cfg.RateLimitConfig) return o } @@ -50,10 +52,26 @@ func (o *ServiceMiddlewarePersistOptions) IsAuditEnabled() bool { return o.GetAuditConfig().EnableAudit } +// GetRateLimitConfig returns pd service middleware configurations. +func (o *ServiceMiddlewarePersistOptions) GetRateLimitConfig() *RateLimitConfig { + return o.rateLimit.Load().(*RateLimitConfig) +} + +// SetRateLimitConfig sets the PD service middleware configuration. +func (o *ServiceMiddlewarePersistOptions) SetRateLimitConfig(cfg *RateLimitConfig) { + o.rateLimit.Store(cfg) +} + +// IsRateLimitEnabled returns whether rate limit middleware is enabled +func (o *ServiceMiddlewarePersistOptions) IsRateLimitEnabled() bool { + return o.GetRateLimitConfig().EnableRateLimit +} + // Persist saves the configuration to the storage. 
func (o *ServiceMiddlewarePersistOptions) Persist(storage endpoint.ServiceMiddlewareStorage) error { cfg := &ServiceMiddlewareConfig{ - AuditConfig: *o.GetAuditConfig(), + AuditConfig: *o.GetAuditConfig(), + RateLimitConfig: *o.GetRateLimitConfig(), } err := storage.SaveServiceMiddlewareConfig(cfg) failpoint.Inject("persistServiceMiddlewareFail", func() { @@ -72,6 +90,7 @@ func (o *ServiceMiddlewarePersistOptions) Reload(storage endpoint.ServiceMiddlew } if isExist { o.audit.Store(&cfg.AuditConfig) + o.rateLimit.Store(&cfg.RateLimitConfig) } return nil } diff --git a/server/server.go b/server/server.go index a63c8c52525..b618f97aeed 100644 --- a/server/server.go +++ b/server/server.go @@ -45,6 +45,7 @@ import ( "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/logutil" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/syncutil" "github.com/tikv/pd/pkg/systimemon" "github.com/tikv/pd/pkg/typeutil" @@ -156,6 +157,7 @@ type Server struct { // the corresponding forwarding TSO channel. tsoDispatcher sync.Map /* Store as map[string]chan *tsoRequest */ + serviceRateLimiter *ratelimit.Limiter serviceLabels map[string][]apiutil.AccessPath apiServiceLabelMap map[apiutil.AccessPath]string @@ -258,6 +260,7 @@ func CreateServer(ctx context.Context, cfg *config.Config, serviceBuilders ...Ha audit.NewPrometheusHistogramBackend(serviceAuditHistogram, false), } s.serviceAuditBackendLabels = make(map[string]*audit.BackendLabels) + s.serviceRateLimiter = ratelimit.NewLimiter() s.serviceLabels = make(map[string][]apiutil.AccessPath) s.apiServiceLabelMap = make(map[apiutil.AccessPath]string) @@ -806,7 +809,8 @@ func (s *Server) GetMembers() ([]*pdpb.Member, error) { // GetServiceMiddlewareConfig gets the service middleware config information. func (s *Server) GetServiceMiddlewareConfig() *config.ServiceMiddlewareConfig { cfg := s.serviceMiddlewareCfg.Clone() - cfg.AuditConfig = *s.serviceMiddlewarePersistOptions.GetAuditConfig() + cfg.AuditConfig = *s.serviceMiddlewarePersistOptions.GetAuditConfig().Clone() + cfg.RateLimitConfig = *s.serviceMiddlewarePersistOptions.GetRateLimitConfig().Clone() return cfg } @@ -978,6 +982,27 @@ func (s *Server) SetAuditConfig(cfg config.AuditConfig) error { return nil } +// GetRateLimitConfig gets the rate limit config information. +func (s *Server) GetRateLimitConfig() *config.RateLimitConfig { + return s.serviceMiddlewarePersistOptions.GetRateLimitConfig().Clone() +} + +// SetRateLimitConfig sets the rate limit config. +func (s *Server) SetRateLimitConfig(cfg config.RateLimitConfig) error { + old := s.serviceMiddlewarePersistOptions.GetRateLimitConfig() + s.serviceMiddlewarePersistOptions.SetRateLimitConfig(&cfg) + if err := s.serviceMiddlewarePersistOptions.Persist(s.storage); err != nil { + s.serviceMiddlewarePersistOptions.SetRateLimitConfig(old) + log.Error("failed to update Rate Limit config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("Rate Limit config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + // GetPDServerConfig gets the balance config information. 
func (s *Server) GetPDServerConfig() *config.PDServerConfig { return s.persistOptions.GetPDServerConfig().Clone() @@ -1195,6 +1220,11 @@ func (s *Server) SetServiceAuditBackendLabels(serviceLabel string, labels []stri s.serviceAuditBackendLabels[serviceLabel] = &audit.BackendLabels{Labels: labels} } +// GetServiceRateLimiter is used to get rate limiter +func (s *Server) GetServiceRateLimiter() *ratelimit.Limiter { + return s.serviceRateLimiter +} + // GetClusterStatus gets cluster status. func (s *Server) GetClusterStatus() (*cluster.Status, error) { s.cluster.Lock() @@ -1450,6 +1480,7 @@ func (s *Server) reloadConfigFromKV() error { if err != nil { return err } + s.loadRateLimitConfig() switchableStorage, ok := s.storage.(interface { SwitchToRegionStorage() SwitchToDefaultStorage() @@ -1467,6 +1498,14 @@ func (s *Server) reloadConfigFromKV() error { return nil } +func (s *Server) loadRateLimitConfig() { + cfg := s.serviceMiddlewarePersistOptions.GetRateLimitConfig().LimiterConfig + for key := range cfg { + value := cfg[key] + s.serviceRateLimiter.Update(key, ratelimit.UpdateDimensionConfig(&value)) + } +} + // ReplicateFileToMember is used to synchronize state to a member. // Each member will write `data` to a local file named `name`. // For security reason, data should be in JSON format. diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go new file mode 100644 index 00000000000..f375397a9f3 --- /dev/null +++ b/tests/server/config/config_test.go @@ -0,0 +1,105 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "testing" + + . "github.com/pingcap/check" + "github.com/tikv/pd/pkg/ratelimit" + "github.com/tikv/pd/pkg/testutil" + "github.com/tikv/pd/server" + "github.com/tikv/pd/tests" +) + +// dialClient used to dial http request. 
+var dialClient = &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + }, +} + +func Test(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testConfigPresistSuite{}) + +type testConfigPresistSuite struct { + cleanup func() + cluster *tests.TestCluster +} + +func (s *testConfigPresistSuite) SetUpSuite(c *C) { + ctx, cancel := context.WithCancel(context.Background()) + s.cleanup = cancel + cluster, err := tests.NewTestCluster(ctx, 3) + c.Assert(err, IsNil) + c.Assert(cluster.RunInitialServers(), IsNil) + c.Assert(cluster.WaitLeader(), Not(HasLen), 0) + s.cluster = cluster +} + +func (s *testConfigPresistSuite) TearDownSuite(c *C) { + s.cleanup() + s.cluster.Destroy() +} + +func (s *testConfigPresistSuite) TestRateLimitConfigReload(c *C) { + leader := s.cluster.GetServer(s.cluster.GetLeader()) + + c.Assert(leader.GetServer().GetServiceMiddlewareConfig().RateLimitConfig.LimiterConfig, HasLen, 0) + limitCfg := make(map[string]ratelimit.DimensionConfig) + limitCfg["GetRegions"] = ratelimit.DimensionConfig{QPS: 1} + + input := map[string]interface{}{ + "enable-rate-limit": "true", + "limiter-config": limitCfg, + } + data, err := json.Marshal(input) + c.Assert(err, IsNil) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) + resp, err := dialClient.Do(req) + c.Assert(err, IsNil) + resp.Body.Close() + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), Equals, true) + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, HasLen, 1) + + oldLeaderName := leader.GetServer().Name() + leader.GetServer().GetMember().ResignEtcdLeader(leader.GetServer().Context(), oldLeaderName, "") + mustWaitLeader(c, s.cluster.GetServers()) + leader = s.cluster.GetServer(s.cluster.GetLeader()) + + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), Equals, true) + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, HasLen, 1) +} + +func mustWaitLeader(c *C, svrs map[string]*tests.TestServer) *server.Server { + var leader *server.Server + testutil.WaitUntil(c, func() bool { + for _, s := range svrs { + if !s.GetServer().IsClosed() && s.GetServer().GetMember().IsLeader() { + leader = s.GetServer() + return true + } + } + return false + }) + return leader +} From 6a266ed7da492d77fab50c1aa3b2ae5ee8fd91c1 Mon Sep 17 00:00:00 2001 From: Eng Zer Jun Date: Tue, 14 Jun 2022 11:34:33 +0800 Subject: [PATCH 43/82] tests: use `T.TempDir` to create temporary test directory (#5153) close tikv/pd#5152 Signed-off-by: Eng Zer Jun --- pkg/encryption/master_key_test.go | 18 ++-- pkg/etcdutil/etcdutil.go | 12 +-- pkg/etcdutil/etcdutil_test.go | 12 +-- server/election/leadership_test.go | 3 +- server/election/lease_test.go | 3 +- server/encryptionkm/key_manager_test.go | 123 +++++++++--------------- server/region_syncer/client_test.go | 9 +- server/storage/kv/kv_test.go | 17 +--- tests/client/client_tls_test.go | 18 ++-- 9 files changed, 71 insertions(+), 144 deletions(-) diff --git a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index b8d5657c1fc..79a6bb390d9 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -99,8 +99,7 @@ func TestNewFileMasterKeyMissingPath(t *testing.T) { func TestNewFileMasterKeyMissingFile(t *testing.T) { t.Parallel() re := require.New(t) - dir, err := os.MkdirTemp("", "test_key_files") - re.NoError(err) + dir 
:= t.TempDir() path := dir + "/key" config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -109,15 +108,14 @@ func TestNewFileMasterKeyMissingFile(t *testing.T) { }, }, } - _, err = NewMasterKey(config, nil) + _, err := NewMasterKey(config, nil) re.Error(err) } func TestNewFileMasterKeyNotHexString(t *testing.T) { t.Parallel() re := require.New(t) - dir, err := os.MkdirTemp("", "test_key_files") - re.NoError(err) + dir := t.TempDir() path := dir + "/key" os.WriteFile(path, []byte("not-a-hex-string"), 0600) config := &encryptionpb.MasterKey{ @@ -127,15 +125,14 @@ func TestNewFileMasterKeyNotHexString(t *testing.T) { }, }, } - _, err = NewMasterKey(config, nil) + _, err := NewMasterKey(config, nil) re.Error(err) } func TestNewFileMasterKeyLengthMismatch(t *testing.T) { t.Parallel() re := require.New(t) - dir, err := os.MkdirTemp("", "test_key_files") - re.NoError(err) + dir := t.TempDir() path := dir + "/key" os.WriteFile(path, []byte("2f07ec61e5a50284f47f2b402a962ec6"), 0600) config := &encryptionpb.MasterKey{ @@ -145,7 +142,7 @@ func TestNewFileMasterKeyLengthMismatch(t *testing.T) { }, }, } - _, err = NewMasterKey(config, nil) + _, err := NewMasterKey(config, nil) re.Error(err) } @@ -153,8 +150,7 @@ func TestNewFileMasterKey(t *testing.T) { t.Parallel() re := require.New(t) key := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" - dir, err := os.MkdirTemp("", "test_key_files") - re.NoError(err) + dir := t.TempDir() path := dir + "/key" os.WriteFile(path, []byte(key), 0600) config := &encryptionpb.MasterKey{ diff --git a/pkg/etcdutil/etcdutil.go b/pkg/etcdutil/etcdutil.go index ff5ffce9226..9f81b11aaca 100644 --- a/pkg/etcdutil/etcdutil.go +++ b/pkg/etcdutil/etcdutil.go @@ -20,7 +20,7 @@ import ( "fmt" "net/http" "net/url" - "os" + "testing" "time" "github.com/gogo/protobuf/proto" @@ -182,10 +182,10 @@ func EtcdKVPutWithTTL(ctx context.Context, c *clientv3.Client, key string, value } // NewTestSingleConfig is used to create a etcd config for the unit test purpose. -func NewTestSingleConfig() *embed.Config { +func NewTestSingleConfig(t *testing.T) *embed.Config { cfg := embed.NewConfig() cfg.Name = "test_etcd" - cfg.Dir, _ = os.MkdirTemp("/tmp", "test_etcd") + cfg.Dir = t.TempDir() cfg.WalDir = "" cfg.Logger = "zap" cfg.LogOutputs = []string{"stdout"} @@ -202,9 +202,3 @@ func NewTestSingleConfig() *embed.Config { cfg.ClusterState = embed.ClusterStateFlagNew return cfg } - -// CleanConfig is used to clean the etcd data for the unit test purpose. -func CleanConfig(cfg *embed.Config) { - // Clean data directory - os.RemoveAll(cfg.Dir) -} diff --git a/pkg/etcdutil/etcdutil_test.go b/pkg/etcdutil/etcdutil_test.go index bbb8e595c32..942e66d3239 100644 --- a/pkg/etcdutil/etcdutil_test.go +++ b/pkg/etcdutil/etcdutil_test.go @@ -30,11 +30,10 @@ import ( func TestMemberHelpers(t *testing.T) { t.Parallel() re := require.New(t) - cfg1 := NewTestSingleConfig() + cfg1 := NewTestSingleConfig(t) etcd1, err := embed.StartEtcd(cfg1) defer func() { etcd1.Close() - CleanConfig(cfg1) }() re.NoError(err) @@ -55,7 +54,7 @@ func TestMemberHelpers(t *testing.T) { // Test AddEtcdMember // Make a new etcd config. 
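// The pattern this patch applies throughout: t.TempDir hands out a per-test
// directory that the testing package removes via t.Cleanup, so the explicit
// CleanConfig/os.RemoveAll teardown above is no longer needed. Roughly:
//
//	before: cfg.Dir, _ = os.MkdirTemp("/tmp", "test_etcd") // plus manual os.RemoveAll
//	after:  cfg.Dir = t.TempDir()                          // removed automatically at test end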
- cfg2 := NewTestSingleConfig() + cfg2 := NewTestSingleConfig(t) cfg2.Name = "etcd2" cfg2.InitialCluster = cfg1.InitialCluster + fmt.Sprintf(",%s=%s", cfg2.Name, &cfg2.LPUrls[0]) cfg2.ClusterState = embed.ClusterStateFlagExisting @@ -68,7 +67,6 @@ func TestMemberHelpers(t *testing.T) { etcd2, err := embed.StartEtcd(cfg2) defer func() { etcd2.Close() - CleanConfig(cfg2) }() re.NoError(err) re.Equal(uint64(etcd2.Server.ID()), addResp.Member.ID) @@ -113,11 +111,10 @@ func TestMemberHelpers(t *testing.T) { func TestEtcdKVGet(t *testing.T) { t.Parallel() re := require.New(t) - cfg := NewTestSingleConfig() + cfg := NewTestSingleConfig(t) etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() - CleanConfig(cfg) }() re.NoError(err) @@ -165,11 +162,10 @@ func TestEtcdKVGet(t *testing.T) { func TestEtcdKVPutWithTTL(t *testing.T) { t.Parallel() re := require.New(t) - cfg := NewTestSingleConfig() + cfg := NewTestSingleConfig(t) etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() - CleanConfig(cfg) }() re.NoError(err) diff --git a/server/election/leadership_test.go b/server/election/leadership_test.go index 9a4b52f782e..4b3663a2ff6 100644 --- a/server/election/leadership_test.go +++ b/server/election/leadership_test.go @@ -29,11 +29,10 @@ const defaultLeaseTimeout = 1 func TestLeadership(t *testing.T) { re := require.New(t) - cfg := etcdutil.NewTestSingleConfig() + cfg := etcdutil.NewTestSingleConfig(t) etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() - etcdutil.CleanConfig(cfg) }() re.NoError(err) diff --git a/server/election/lease_test.go b/server/election/lease_test.go index ef8c12be2e9..6298c22f0f2 100644 --- a/server/election/lease_test.go +++ b/server/election/lease_test.go @@ -27,11 +27,10 @@ import ( func TestLease(t *testing.T) { re := require.New(t) - cfg := etcdutil.NewTestSingleConfig() + cfg := etcdutil.NewTestSingleConfig(t) etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() - etcdutil.CleanConfig(cfg) }() re.NoError(err) diff --git a/server/encryptionkm/key_manager_test.go b/server/encryptionkm/key_manager_test.go index 5e0d864942c..3ca8bb320d4 100644 --- a/server/encryptionkm/key_manager_test.go +++ b/server/encryptionkm/key_manager_test.go @@ -20,6 +20,7 @@ import ( "fmt" "net/url" "os" + "path/filepath" "sync/atomic" "testing" "time" @@ -48,10 +49,10 @@ func getTestDataKey() []byte { return key } -func newTestEtcd(re *require.Assertions) (client *clientv3.Client, cleanup func()) { +func newTestEtcd(t *testing.T, re *require.Assertions) (client *clientv3.Client) { cfg := embed.NewConfig() cfg.Name = "test_etcd" - cfg.Dir, _ = os.MkdirTemp("/tmp", "test_etcd") + cfg.Dir = t.TempDir() cfg.Logger = "zap" pu, err := url.Parse(tempurl.Alloc()) re.NoError(err) @@ -72,31 +73,25 @@ func newTestEtcd(re *require.Assertions) (client *clientv3.Client, cleanup func( }) re.NoError(err) - cleanup = func() { + t.Cleanup(func() { client.Close() server.Close() - os.RemoveAll(cfg.Dir) - } + }) - return client, cleanup + return client } -func newTestKeyFile(re *require.Assertions, key ...string) (keyFilePath string, cleanup func()) { +func newTestKeyFile(t *testing.T, re *require.Assertions, key ...string) (keyFilePath string) { testKey := testMasterKey for _, k := range key { testKey = k } - tempDir, err := os.MkdirTemp("/tmp", "test_key_file") - re.NoError(err) - keyFilePath = tempDir + "/key" - err = os.WriteFile(keyFilePath, []byte(testKey), 0600) - re.NoError(err) - cleanup = func() { - os.RemoveAll(tempDir) - } + keyFilePath = filepath.Join(t.TempDir(), 
"key") + err := os.WriteFile(keyFilePath, []byte(testKey), 0600) + re.NoError(err) - return keyFilePath, cleanup + return keyFilePath } func newTestLeader(re *require.Assertions, client *clientv3.Client) *election.Leadership { @@ -118,8 +113,7 @@ func checkMasterKeyMeta(re *require.Assertions, value []byte, meta *encryptionpb func TestNewKeyManagerBasic(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() + client := newTestEtcd(t, re) // Use default config. config := &encryption.Config{} err := config.Adjust() @@ -141,10 +135,8 @@ func TestNewKeyManagerBasic(t *testing.T) { func TestNewKeyManagerWithCustomConfig(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) // Custom config rotatePeriod, err := time.ParseDuration("100h") re.NoError(err) @@ -181,10 +173,8 @@ func TestNewKeyManagerWithCustomConfig(t *testing.T) { func TestNewKeyManagerLoadKeys(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Use default config. config := &encryption.Config{} @@ -224,8 +214,7 @@ func TestNewKeyManagerLoadKeys(t *testing.T) { func TestGetCurrentKey(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() + client := newTestEtcd(t, re) // Use default config. config := &encryption.Config{} err := config.Adjust() @@ -268,10 +257,8 @@ func TestGetCurrentKey(t *testing.T) { func TestGetKey(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Store initial keys in etcd. masterKeyMeta := newMasterKey(keyFile) @@ -324,10 +311,8 @@ func TestGetKey(t *testing.T) { func TestLoadKeyEmpty(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Store initial keys in etcd. masterKeyMeta := newMasterKey(keyFile) @@ -362,10 +347,8 @@ func TestWatcher(t *testing.T) { // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -440,8 +423,7 @@ func TestWatcher(t *testing.T) { func TestSetLeadershipWithEncryptionOff(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() + client := newTestEtcd(t, re) // Use default config. config := &encryption.Config{} err := config.Adjust() @@ -466,10 +448,8 @@ func TestSetLeadershipWithEncryptionEnabling(t *testing.T) { // Initialize. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -521,10 +501,8 @@ func TestSetLeadershipWithEncryptionMethodChanged(t *testing.T) { // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -599,10 +577,8 @@ func TestSetLeadershipWithCurrentKeyExposed(t *testing.T) { // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -672,10 +648,8 @@ func TestSetLeadershipWithCurrentKeyExpired(t *testing.T) { // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -749,12 +723,9 @@ func TestSetLeadershipWithMasterKeyChanged(t *testing.T) { // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() - keyFile2, cleanupKeyFile2 := newTestKeyFile(re, testMasterKey2) - defer cleanupKeyFile2() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) + keyFile2 := newTestKeyFile(t, re, testMasterKey2) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -817,10 +788,8 @@ func TestSetLeadershipWithMasterKeyChanged(t *testing.T) { func TestSetLeadershipMasterKeyWithCiphertextKey(t *testing.T) { re := require.New(t) // Initialize. - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -897,10 +866,8 @@ func TestSetLeadershipWithEncryptionDisabling(t *testing.T) { // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -955,10 +922,8 @@ func TestKeyRotation(t *testing.T) { // Initialize. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() @@ -1053,10 +1018,8 @@ func TestKeyRotationConflict(t *testing.T) { // Initialize. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client, cleanupEtcd := newTestEtcd(re) - defer cleanupEtcd() - keyFile, cleanupKeyFile := newTestKeyFile(re) - defer cleanupKeyFile() + client := newTestEtcd(t, re) + keyFile := newTestKeyFile(t, re) leadership := newTestLeader(re, client) // Setup helper helper := defaultKeyManagerHelper() diff --git a/server/region_syncer/client_test.go b/server/region_syncer/client_test.go index b63deaae3e0..80185e86f94 100644 --- a/server/region_syncer/client_test.go +++ b/server/region_syncer/client_test.go @@ -16,7 +16,6 @@ package syncer import ( "context" - "os" "testing" "time" @@ -34,9 +33,7 @@ import ( // For issue https://github.com/tikv/pd/issues/3936 func TestLoadRegion(t *testing.T) { re := require.New(t) - tempDir, err := os.MkdirTemp(os.TempDir(), "region_syncer_load_region") - re.NoError(err) - defer os.RemoveAll(tempDir) + tempDir := t.TempDir() rs, err := storage.NewStorageWithLevelDBBackend(context.Background(), tempDir, nil) re.NoError(err) @@ -64,9 +61,7 @@ func TestLoadRegion(t *testing.T) { func TestErrorCode(t *testing.T) { re := require.New(t) - tempDir, err := os.MkdirTemp(os.TempDir(), "region_syncer_err") - re.NoError(err) - defer os.RemoveAll(tempDir) + tempDir := t.TempDir() rs, err := storage.NewStorageWithLevelDBBackend(context.Background(), tempDir, nil) re.NoError(err) server := &mockServer{ diff --git a/server/storage/kv/kv_test.go b/server/storage/kv/kv_test.go index 88bac9b279f..ac2911036aa 100644 --- a/server/storage/kv/kv_test.go +++ b/server/storage/kv/kv_test.go @@ -17,7 +17,6 @@ package kv import ( "fmt" "net/url" - "os" "path" "sort" "strconv" @@ -31,8 +30,7 @@ import ( func TestEtcd(t *testing.T) { re := require.New(t) - cfg := newTestSingleConfig() - defer cleanConfig(cfg) + cfg := newTestSingleConfig(t) etcd, err := embed.StartEtcd(cfg) re.NoError(err) defer etcd.Close() @@ -51,9 +49,7 @@ func TestEtcd(t *testing.T) { func TestLevelDB(t *testing.T) { re := require.New(t) - dir, err := os.MkdirTemp("/tmp", "leveldb_kv") - re.NoError(err) - defer os.RemoveAll(dir) + dir := t.TempDir() kv, err := NewLevelDBKV(dir) re.NoError(err) @@ -121,10 +117,10 @@ func testRange(re *require.Assertions, kv Base) { } } -func newTestSingleConfig() *embed.Config { +func newTestSingleConfig(t *testing.T) *embed.Config { cfg := embed.NewConfig() cfg.Name = "test_etcd" - cfg.Dir, _ = os.MkdirTemp("/tmp", "test_etcd") + cfg.Dir = t.TempDir() cfg.WalDir = "" cfg.Logger = "zap" cfg.LogOutputs = []string{"stdout"} @@ -141,8 +137,3 @@ func newTestSingleConfig() *embed.Config { cfg.ClusterState = embed.ClusterStateFlagNew return cfg } - -func cleanConfig(cfg *embed.Config) { - // Clean data directory - os.RemoveAll(cfg.Dir) -} diff --git a/tests/client/client_tls_test.go b/tests/client/client_tls_test.go index 48a6fec3d2d..997abbf3f35 100644 --- a/tests/client/client_tls_test.go +++ b/tests/client/client_tls_test.go @@ -62,28 +62,22 @@ func TestTLSReloadAtomicReplace(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - 
tmpDir, err := os.MkdirTemp(os.TempDir(), "cert-tmp") - re.NoError(err) + tmpDir := t.TempDir() os.RemoveAll(tmpDir) - defer os.RemoveAll(tmpDir) - certsDir, err := os.MkdirTemp(os.TempDir(), "cert-to-load") - re.NoError(err) - defer os.RemoveAll(certsDir) + certsDir := t.TempDir() - certsDirExp, err := os.MkdirTemp(os.TempDir(), "cert-expired") - re.NoError(err) - defer os.RemoveAll(certsDirExp) + certsDirExp := t.TempDir() cloneFunc := func() transport.TLSInfo { tlsInfo, terr := copyTLSFiles(testTLSInfo, certsDir) re.NoError(terr) - _, err = copyTLSFiles(testTLSInfoExpired, certsDirExp) + _, err := copyTLSFiles(testTLSInfoExpired, certsDirExp) re.NoError(err) return tlsInfo } replaceFunc := func() { - err = os.Rename(certsDir, tmpDir) + err := os.Rename(certsDir, tmpDir) re.NoError(err) err = os.Rename(certsDirExp, certsDir) re.NoError(err) @@ -93,7 +87,7 @@ func TestTLSReloadAtomicReplace(t *testing.T) { // 'certsDirExp' does not exist } revertFunc := func() { - err = os.Rename(tmpDir, certsDirExp) + err := os.Rename(tmpDir, certsDirExp) re.NoError(err) err = os.Rename(certsDir, tmpDir) From e74e2771d5cec80c7fffc705e75a6ae54e8a5122 Mon Sep 17 00:00:00 2001 From: buffer <1045931706@qq.com> Date: Tue, 14 Jun 2022 11:44:33 +0800 Subject: [PATCH 44/82] config: the defualt value of `max-merge-region-keys` is related with `max-merge-region-size` (#5084) close tikv/pd#5083 Signed-off-by: bufferflies <1045931706@qq.com> Co-authored-by: Ti Chi Robot --- server/api/config.go | 8 ++++++-- server/api/config_test.go | 13 +++++++++++++ server/config/config.go | 13 +++++++++---- server/config/config_test.go | 7 ++++--- server/config/persist_options.go | 11 ++++++++++- tests/pdctl/config/config_test.go | 25 +++++++++++++++++++++++-- 6 files changed, 65 insertions(+), 12 deletions(-) diff --git a/server/api/config.go b/server/api/config.go index de87947b785..d4d90735289 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -54,7 +54,9 @@ func newConfHandler(svr *server.Server, rd *render.Render) *confHandler { // @Success 200 {object} config.Config // @Router /config [get] func (h *confHandler) GetConfig(w http.ResponseWriter, r *http.Request) { - h.rd.JSON(w, http.StatusOK, h.svr.GetConfig()) + cfg := h.svr.GetConfig() + cfg.Schedule.MaxMergeRegionKeys = cfg.Schedule.GetMaxMergeRegionKeys() + h.rd.JSON(w, http.StatusOK, cfg) } // @Tags config @@ -309,7 +311,9 @@ func mergeConfig(v interface{}, data []byte) (updated bool, found bool, err erro // @Success 200 {object} config.ScheduleConfig // @Router /config/schedule [get] func (h *confHandler) GetScheduleConfig(w http.ResponseWriter, r *http.Request) { - h.rd.JSON(w, http.StatusOK, h.svr.GetScheduleConfig()) + cfg := h.svr.GetScheduleConfig() + cfg.MaxMergeRegionKeys = cfg.GetMaxMergeRegionKeys() + h.rd.JSON(w, http.StatusOK, cfg) } // @Tags config diff --git a/server/api/config_test.go b/server/api/config_test.go index 271849ce223..7abfafd04a6 100644 --- a/server/api/config_test.go +++ b/server/api/config_test.go @@ -370,6 +370,19 @@ func (s *testConfigSuite) TestConfigTTL(c *C) { err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 1), postData, tu.StatusNotOK(c), tu.StringEqual(c, "\"unsupported ttl config schedule.invalid-ttl-config\"\n")) c.Assert(err, IsNil) + + // only set max-merge-region-size + mergeConfig := map[string]interface{}{ + "schedule.max-merge-region-size": 999, + } + postData, err = json.Marshal(mergeConfig) + c.Assert(err, IsNil) + + err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 1), 
postData, tu.StatusOK(c)) + c.Assert(err, IsNil) + c.Assert(s.svr.GetPersistOptions().GetMaxMergeRegionSize(), Equals, uint64(999)) + // max-merge-region-keys should keep consistence with max-merge-region-size. + c.Assert(s.svr.GetPersistOptions().GetMaxMergeRegionKeys(), Equals, uint64(999*10000)) } func (s *testConfigSuite) TestTTLConflict(c *C) { diff --git a/server/config/config.go b/server/config/config.go index df833594f74..98d3ddbe66d 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -785,7 +785,6 @@ const ( defaultMaxSnapshotCount = 64 defaultMaxPendingPeerCount = 64 defaultMaxMergeRegionSize = 20 - defaultMaxMergeRegionKeys = 200000 defaultSplitMergeInterval = 1 * time.Hour defaultPatrolRegionInterval = 10 * time.Millisecond defaultMaxStoreDownTime = 30 * time.Minute @@ -822,9 +821,6 @@ func (c *ScheduleConfig) adjust(meta *configMetaData, reloading bool) error { if !meta.IsDefined("max-merge-region-size") { adjustUint64(&c.MaxMergeRegionSize, defaultMaxMergeRegionSize) } - if !meta.IsDefined("max-merge-region-keys") { - adjustUint64(&c.MaxMergeRegionKeys, defaultMaxMergeRegionKeys) - } adjustDuration(&c.SplitMergeInterval, defaultSplitMergeInterval) adjustDuration(&c.PatrolRegionInterval, defaultPatrolRegionInterval) adjustDuration(&c.MaxStoreDownTime, defaultMaxStoreDownTime) @@ -910,6 +906,15 @@ func (c *ScheduleConfig) migrateConfigurationMap() map[string][2]*bool { } } +// GetMaxMergeRegionKeys returns the max merge keys. +// it should keep consistent with tikv: https://github.com/tikv/tikv/pull/12484 +func (c *ScheduleConfig) GetMaxMergeRegionKeys() uint64 { + if keys := c.MaxMergeRegionKeys; keys != 0 { + return keys + } + return c.MaxMergeRegionSize * 10000 +} + func (c *ScheduleConfig) parseDeprecatedFlag(meta *configMetaData, name string, old, new bool) (bool, error) { oldName, newName := "disable-"+name, "enable-"+name defineOld, defineNew := meta.IsDefined(oldName), meta.IsDefined(newName) diff --git a/server/config/config_test.go b/server/config/config_test.go index 032c0526739..fe44ee619ee 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -189,13 +189,14 @@ leader-schedule-limit = 0 re.Equal(defaultLeaderLease, cfg.LeaderLease) re.Equal(uint(20000000), cfg.MaxRequestBytes) // When defined, use values from config file. + re.Equal(0*10000, int(cfg.Schedule.GetMaxMergeRegionKeys())) re.Equal(uint64(0), cfg.Schedule.MaxMergeRegionSize) re.True(cfg.Schedule.EnableOneWayMerge) re.Equal(uint64(0), cfg.Schedule.LeaderScheduleLimit) // When undefined, use default values. 
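// Sketch of the new defaulting rule checked above: with max-merge-region-keys
// left at 0, GetMaxMergeRegionKeys() falls back to max-merge-region-size * 10000
// (size 20 -> 200000 keys, size 40 -> 400000 keys), while an explicitly set
// max-merge-region-keys (such as the 400000 in the TOML below) is returned as-is.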
re.True(cfg.PreVote) re.Equal("info", cfg.Log.Level) - re.Equal(uint64(defaultMaxMergeRegionKeys), cfg.Schedule.MaxMergeRegionKeys) + re.Equal(uint64(0), cfg.Schedule.MaxMergeRegionKeys) re.Equal("http://127.0.0.1:9090", cfg.PDServerCfg.MetricStorage) re.Equal(DefaultTSOUpdatePhysicalInterval, cfg.TSOUpdatePhysicalInterval.Duration) @@ -208,6 +209,7 @@ lease = 0 [schedule] type = "random-merge" +max-merge-region-keys = 400000 ` cfg = NewConfig() meta, err = toml.Decode(cfgData, &cfg) @@ -215,7 +217,7 @@ type = "random-merge" err = cfg.Adjust(&meta, false) re.NoError(err) re.Contains(cfg.WarningMsgs[0], "Config contains undefined item") - + re.Equal(40*10000, int(cfg.Schedule.GetMaxMergeRegionKeys())) // Check misspelled schedulers name cfgData = ` name = "" @@ -229,7 +231,6 @@ type = "random-merge-schedulers" re.NoError(err) err = cfg.Adjust(&meta, false) re.Error(err) - // Check correct schedulers name cfgData = ` name = "" diff --git a/server/config/persist_options.go b/server/config/persist_options.go index fe7203722c2..643e20a3087 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -244,8 +244,17 @@ func (o *PersistOptions) GetMaxMergeRegionSize() uint64 { } // GetMaxMergeRegionKeys returns the max number of keys. +// It returns size * 10000 if the key of max-merge-region-Keys doesn't exist. func (o *PersistOptions) GetMaxMergeRegionKeys() uint64 { - return o.getTTLUintOr(maxMergeRegionKeysKey, o.GetScheduleConfig().MaxMergeRegionKeys) + keys, exist, err := o.getTTLUint(maxMergeRegionKeysKey) + if exist && err == nil { + return keys + } + size, exist, err := o.getTTLUint(maxMergeRegionSizeKey) + if exist && err == nil { + return size * 10000 + } + return o.GetScheduleConfig().GetMaxMergeRegionKeys() } // GetSplitMergeInterval returns the interval between finishing split and starting to merge. diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index cb564699b53..311a0e7db99 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -97,7 +97,9 @@ func (s *configTestSuite) TestConfig(c *C) { scheduleConfig.EnableRemoveExtraReplica = false scheduleConfig.EnableLocationReplacement = false scheduleConfig.StoreLimitMode = "" - + c.Assert(scheduleConfig.MaxMergeRegionKeys, Equals, uint64(0)) + // The result of config show doesn't be 0. + scheduleConfig.MaxMergeRegionKeys = scheduleConfig.GetMaxMergeRegionKeys() c.Assert(&cfg.Schedule, DeepEquals, scheduleConfig) c.Assert(&cfg.Replication, DeepEquals, svr.GetReplicationConfig()) @@ -122,7 +124,26 @@ func (s *configTestSuite) TestConfig(c *C) { c.Assert(err, IsNil) scheduleCfg := config.ScheduleConfig{} c.Assert(json.Unmarshal(output, &scheduleCfg), IsNil) - c.Assert(&scheduleCfg, DeepEquals, svr.GetScheduleConfig()) + scheduleConfig = svr.GetScheduleConfig() + scheduleConfig.MaxMergeRegionKeys = scheduleConfig.GetMaxMergeRegionKeys() + c.Assert(&scheduleCfg, DeepEquals, scheduleConfig) + + c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionSize), Equals, 20) + c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionKeys), Equals, 0) + c.Assert(int(svr.GetScheduleConfig().GetMaxMergeRegionKeys()), Equals, 20*10000) + + // set max-merge-region-size to 40MB + args = []string{"-u", pdAddr, "config", "set", "max-merge-region-size", "40"} + _, err = pdctl.ExecuteCommand(cmd, args...) 
+ c.Assert(err, IsNil) + c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionSize), Equals, 40) + c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionKeys), Equals, 0) + c.Assert(int(svr.GetScheduleConfig().GetMaxMergeRegionKeys()), Equals, 40*10000) + args = []string{"-u", pdAddr, "config", "set", "max-merge-region-keys", "200000"} + _, err = pdctl.ExecuteCommand(cmd, args...) + c.Assert(err, IsNil) + c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionKeys), Equals, 20*10000) + c.Assert(int(svr.GetScheduleConfig().GetMaxMergeRegionKeys()), Equals, 20*10000) // config show replication args = []string{"-u", pdAddr, "config", "show", "replication"} From 74661fad2e90a312622cff4368bff5e45f24d632 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 14 Jun 2022 14:40:33 +0800 Subject: [PATCH 45/82] tools: migrate test framework to testify (#5149) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- .gitignore | 1 + tools/pd-analysis/analysis/parse_log_test.go | 43 ++++++++--------- .../analysis/transfer_counter_test.go | 35 +++++++------- tools/pd-ctl/pdctl/ctl_test.go | 19 +++----- .../simulator/simutil/key_test.go | 46 ++++++++----------- 5 files changed, 64 insertions(+), 80 deletions(-) diff --git a/.gitignore b/.gitignore index fbe6a8595a8..93e6189a687 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ package.list report.xml coverage.xml coverage +*.txt diff --git a/tools/pd-analysis/analysis/parse_log_test.go b/tools/pd-analysis/analysis/parse_log_test.go index 475e3ae7797..ffdcb2137c0 100644 --- a/tools/pd-analysis/analysis/parse_log_test.go +++ b/tools/pd-analysis/analysis/parse_log_test.go @@ -18,17 +18,9 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testParseLog{}) - -type testParseLog struct{} - func transferCounterParseLog(operator, content string, expect []uint64) bool { r, _ := GetTransferCounter().CompileRegex(operator) results, _ := GetTransferCounter().parseLine(content, r) @@ -43,73 +35,76 @@ func transferCounterParseLog(operator, content string, expect []uint64) bool { return true } -func (t *testParseLog) TestTransferCounterParseLog(c *C) { +func TestTransferCounterParseLog(t *testing.T) { + re := require.New(t) { operator := "balance-leader" content := "[2019/09/05 04:15:52.404 +00:00] [INFO] [operator_controller.go:119] [\"operator finish\"] [region-id=54252] [operator=\"\"balance-leader {transfer leader: store 4 to 6} (kind:leader,balance, region:54252(8243,398), createAt:2019-09-05 04:15:52.400290023 +0000 UTC m=+91268.739649520, startAt:2019-09-05 04:15:52.400489629 +0000 UTC m=+91268.739849120, currentStep:1, steps:[transfer leader from store 4 to store 6]) finished\"\"]" var expect = []uint64{54252, 4, 6} - c.Assert(transferCounterParseLog(operator, content, expect), IsTrue) + re.True(transferCounterParseLog(operator, content, expect)) } { operator := "balance-region" content := "[2019/09/03 17:42:07.898 +08:00] [INFO] [operator_controller.go:119] [\"operator finish\"] [region-id=24622] [operator=\"\"balance-region {mv peer: store [6] to [1]} (kind:region,balance, region:24622(1,1), createAt:2019-09-03 17:42:06.602589701 +0800 CST m=+737.457773921, startAt:2019-09-03 17:42:06.602849306 +0800 CST m=+737.458033475, currentStep:3, steps:[add learner peer 64064 on store 1, promote learner peer 64064 on store 1 to voter, remove peer on store 6]) finished\"\"]\"" var expect = []uint64{24622, 6, 1} - 
c.Assert(transferCounterParseLog(operator, content, expect), IsTrue) + re.True(transferCounterParseLog(operator, content, expect)) } { operator := "transfer-hot-write-leader" content := "[2019/09/05 14:05:42.811 +08:00] [INFO] [operator_controller.go:119] [\"operator finish\"] [region-id=94] [operator=\"\"transfer-hot-write-leader {transfer leader: store 2 to 1} (kind:leader,hot-region, region:94(1,1), createAt:2019-09-05 14:05:42.676394689 +0800 CST m=+14.955640307, startAt:2019-09-05 14:05:42.676589507 +0800 CST m=+14.955835051, currentStep:1, steps:[transfer leader from store 2 to store 1]) finished\"\"]" var expect = []uint64{94, 2, 1} - c.Assert(transferCounterParseLog(operator, content, expect), IsTrue) + re.True(transferCounterParseLog(operator, content, expect)) } { operator := "move-hot-write-region" content := "[2019/09/05 14:05:54.311 +08:00] [INFO] [operator_controller.go:119] [\"operator finish\"] [region-id=98] [operator=\"\"move-hot-write-region {mv peer: store [2] to [10]} (kind:region,hot-region, region:98(1,1), createAt:2019-09-05 14:05:49.718201432 +0800 CST m=+21.997446945, startAt:2019-09-05 14:05:49.718336308 +0800 CST m=+21.997581822, currentStep:3, steps:[add learner peer 2048 on store 10, promote learner peer 2048 on store 10 to voter, remove peer on store 2]) finished\"\"]" var expect = []uint64{98, 2, 10} - c.Assert(transferCounterParseLog(operator, content, expect), IsTrue) + re.True(transferCounterParseLog(operator, content, expect)) } { operator := "transfer-hot-read-leader" content := "[2019/09/05 14:16:38.758 +08:00] [INFO] [operator_controller.go:119] [\"operator finish\"] [region-id=85] [operator=\"\"transfer-hot-read-leader {transfer leader: store 1 to 5} (kind:leader,hot-region, region:85(1,1), createAt:2019-09-05 14:16:38.567463945 +0800 CST m=+29.117453011, startAt:2019-09-05 14:16:38.567603515 +0800 CST m=+29.117592496, currentStep:1, steps:[transfer leader from store 1 to store 5]) finished\"\"]" var expect = []uint64{85, 1, 5} - c.Assert(transferCounterParseLog(operator, content, expect), IsTrue) + re.True(transferCounterParseLog(operator, content, expect)) } { operator := "move-hot-read-region" content := "[2019/09/05 14:19:15.066 +08:00] [INFO] [operator_controller.go:119] [\"operator finish\"] [region-id=389] [operator=\"\"move-hot-read-region {mv peer: store [5] to [4]} (kind:leader,region,hot-region, region:389(1,1), createAt:2019-09-05 14:19:13.576359364 +0800 CST m=+25.855737101, startAt:2019-09-05 14:19:13.576556556 +0800 CST m=+25.855934288, currentStep:4, steps:[add learner peer 2014 on store 4, promote learner peer 2014 on store 4 to voter, transfer leader from store 5 to store 3, remove peer on store 5]) finished\"\"]" var expect = []uint64{389, 5, 4} - c.Assert(transferCounterParseLog(operator, content, expect), IsTrue) + re.True(transferCounterParseLog(operator, content, expect)) } } -func (t *testParseLog) TestIsExpectTime(c *C) { +func TestIsExpectTime(t *testing.T) { + re := require.New(t) { testFunction := isExpectTime("2019/09/05 14:19:15", DefaultLayout, true) current, _ := time.Parse(DefaultLayout, "2019/09/05 14:19:14") - c.Assert(testFunction(current), IsTrue) + re.True(testFunction(current)) } { testFunction := isExpectTime("2019/09/05 14:19:15", DefaultLayout, false) current, _ := time.Parse(DefaultLayout, "2019/09/05 14:19:16") - c.Assert(testFunction(current), IsTrue) + re.True(testFunction(current)) } { testFunction := isExpectTime("", DefaultLayout, true) current, _ := time.Parse(DefaultLayout, "2019/09/05 14:19:14") - 
c.Assert(testFunction(current), IsTrue) + re.True(testFunction(current)) } { testFunction := isExpectTime("", DefaultLayout, false) current, _ := time.Parse(DefaultLayout, "2019/09/05 14:19:16") - c.Assert(testFunction(current), IsTrue) + re.True(testFunction(current)) } } -func (t *testParseLog) TestCurrentTime(c *C) { +func TestCurrentTime(t *testing.T) { + re := require.New(t) getCurrentTime := currentTime(DefaultLayout) content := "[2019/09/05 14:19:15.066 +08:00] [INFO] [operator_controller.go:119] [\"operator finish\"] [region-id=389] [operator=\"\"move-hot-read-region {mv peer: store 5 to 4} (kind:leader,region,hot-region, region:389(1,1), createAt:2019-09-05 14:19:13.576359364 +0800 CST m=+25.855737101, startAt:2019-09-05 14:19:13.576556556 +0800 CST m=+25.855934288, currentStep:4, steps:[add learner peer 2014 on store 4, promote learner peer 2014 on store 4 to voter, transfer leader from store 5 to store 3, remove peer on store 5]) finished\"\"]" current, err := getCurrentTime(content) - c.Assert(err, Equals, nil) + re.NoError(err) expect, _ := time.Parse(DefaultLayout, "2019/09/05 14:19:15") - c.Assert(current, Equals, expect) + re.Equal(expect, current) } diff --git a/tools/pd-analysis/analysis/transfer_counter_test.go b/tools/pd-analysis/analysis/transfer_counter_test.go index 796f0652345..092767cd49d 100644 --- a/tools/pd-analysis/analysis/transfer_counter_test.go +++ b/tools/pd-analysis/analysis/transfer_counter_test.go @@ -15,12 +15,10 @@ package analysis import ( - . "github.com/pingcap/check" -) - -var _ = Suite(&testTransferRegionCounter{}) + "testing" -type testTransferRegionCounter struct{} + "github.com/stretchr/testify/require" +) func addData(test [][]uint64) { for i, row := range test { @@ -33,7 +31,8 @@ func addData(test [][]uint64) { } } -func (t *testTransferRegionCounter) TestCounterRedundant(c *C) { +func TestCounterRedundant(t *testing.T) { + re := require.New(t) { test := [][]uint64{ {0, 0, 0, 0, 0, 0, 0}, @@ -44,12 +43,12 @@ func (t *testTransferRegionCounter) TestCounterRedundant(c *C) { {0, 5, 9, 0, 0, 0, 0}, {0, 0, 8, 0, 0, 0, 0}} GetTransferCounter().Init(6, 3000) - c.Assert(GetTransferCounter().Redundant, Equals, uint64(0)) - c.Assert(GetTransferCounter().Necessary, Equals, uint64(0)) + re.Equal(uint64(0), GetTransferCounter().Redundant) + re.Equal(uint64(0), GetTransferCounter().Necessary) addData(test) GetTransferCounter().Result() - c.Assert(GetTransferCounter().Redundant, Equals, uint64(64)) - c.Assert(GetTransferCounter().Necessary, Equals, uint64(5)) + re.Equal(uint64(64), GetTransferCounter().Redundant) + re.Equal(uint64(5), GetTransferCounter().Necessary) } { test := [][]uint64{ @@ -61,12 +60,12 @@ func (t *testTransferRegionCounter) TestCounterRedundant(c *C) { {0, 1, 0, 0, 0, 0, 0}, {0, 0, 1, 0, 0, 0, 0}} GetTransferCounter().Init(6, 3000) - c.Assert(GetTransferCounter().Redundant, Equals, uint64(0)) - c.Assert(GetTransferCounter().Necessary, Equals, uint64(0)) + re.Equal(uint64(0), GetTransferCounter().Redundant) + re.Equal(uint64(0), GetTransferCounter().Necessary) addData(test) GetTransferCounter().Result() - c.Assert(GetTransferCounter().Redundant, Equals, uint64(0)) - c.Assert(GetTransferCounter().Necessary, Equals, uint64(5)) + re.Equal(uint64(0), GetTransferCounter().Redundant) + re.Equal(uint64(5), GetTransferCounter().Necessary) } { test := [][]uint64{ @@ -80,12 +79,12 @@ func (t *testTransferRegionCounter) TestCounterRedundant(c *C) { {0, 0, 48, 0, 84, 1, 48, 0, 20}, {0, 61, 2, 57, 7, 122, 1, 21, 0}} GetTransferCounter().Init(8, 
3000) - c.Assert(GetTransferCounter().Redundant, Equals, uint64(0)) - c.Assert(GetTransferCounter().Necessary, Equals, uint64(0)) + re.Equal(uint64(0), GetTransferCounter().Redundant) + re.Equal(uint64(0), GetTransferCounter().Necessary) addData(test) GetTransferCounter().Result() - c.Assert(GetTransferCounter().Redundant, Equals, uint64(1778)) - c.Assert(GetTransferCounter().Necessary, Equals, uint64(938)) + re.Equal(uint64(1778), GetTransferCounter().Redundant) + re.Equal(uint64(938), GetTransferCounter().Necessary) GetTransferCounter().PrintResult() } } diff --git a/tools/pd-ctl/pdctl/ctl_test.go b/tools/pd-ctl/pdctl/ctl_test.go index 90369bab46c..6dc29058e34 100644 --- a/tools/pd-ctl/pdctl/ctl_test.go +++ b/tools/pd-ctl/pdctl/ctl_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/spf13/cobra" + "github.com/stretchr/testify/require" ) func newCommand(usage, short string) *cobra.Command { @@ -31,6 +32,7 @@ func newCommand(usage, short string) *cobra.Command { } func TestGenCompleter(t *testing.T) { + re := require.New(t) var subCommand = []string{"testa", "testb", "testc", "testdef"} rootCmd := &cobra.Command{ @@ -65,13 +67,12 @@ func TestGenCompleter(t *testing.T) { } } - if inPrefixArray == false { - t.Errorf("%s not in prefix array", cmd) - } + re.True(inPrefixArray) } } func TestReadStdin(t *testing.T) { + re := require.New(t) s := []struct { in io.Reader targets []string @@ -84,16 +85,10 @@ func TestReadStdin(t *testing.T) { }} for _, v := range s { in, err := ReadStdin(v.in) - if err != nil { - t.Errorf("ReadStdin err:%v", err) - } - if len(v.targets) != len(in) { - t.Errorf("ReadStdin = %v, want %s, nil", in, v.targets) - } + re.NoError(err) + re.Equal(len(v.targets), len(in)) for i, target := range v.targets { - if target != in[i] { - t.Errorf("ReadStdin = %v, want %s, nil", in, v.targets) - } + re.Equal(target, in[i]) } } } diff --git a/tools/pd-simulator/simulator/simutil/key_test.go b/tools/pd-simulator/simulator/simutil/key_test.go index d7821c475fc..6174ec35381 100644 --- a/tools/pd-simulator/simulator/simutil/key_test.go +++ b/tools/pd-simulator/simulator/simutil/key_test.go @@ -17,46 +17,40 @@ package simutil import ( "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/codec" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testTableKeySuite{}) - -type testTableKeySuite struct{} - -func (t *testTableKeySuite) TestGenerateTableKeys(c *C) { +func TestGenerateTableKeys(t *testing.T) { + re := require.New(t) tableCount := 3 size := 10 keys := GenerateTableKeys(tableCount, size) - c.Assert(keys, HasLen, size) + re.Len(keys, size) for i := 1; i < len(keys); i++ { - c.Assert(keys[i-1], Less, keys[i]) + re.Less(keys[i-1], keys[i]) s := []byte(keys[i-1]) e := []byte(keys[i]) for j := 0; j < 1000; j++ { split, err := GenerateTiDBEncodedSplitKey(s, e) - c.Assert(err, IsNil) - c.Assert(s, Less, split) - c.Assert(split, Less, e) + re.NoError(err) + re.Less(string(s), string(split)) + re.Less(string(split), string(e)) e = split } } } -func (t *testTableKeySuite) TestGenerateSplitKey(c *C) { +func TestGenerateSplitKey(t *testing.T) { + re := require.New(t) s := []byte(codec.EncodeBytes([]byte("a"))) e := []byte(codec.EncodeBytes([]byte("ab"))) for i := 0; i <= 1000; i++ { cc, err := GenerateTiDBEncodedSplitKey(s, e) - c.Assert(err, IsNil) - c.Assert(s, Less, cc) - c.Assert(cc, Less, e) + re.NoError(err) + re.Less(string(s), string(cc)) + re.Less(string(cc), string(e)) e = cc } @@ -64,19 +58,19 @@ func (t *testTableKeySuite) TestGenerateSplitKey(c *C) { s = []byte("") e = []byte{116, 128, 0, 0, 0, 0, 0, 0, 255, 1, 0, 0, 0, 0, 0, 0, 0, 248} splitKey, err := GenerateTiDBEncodedSplitKey(s, e) - c.Assert(err, IsNil) - c.Assert(s, Less, splitKey) - c.Assert(splitKey, Less, e) + re.NoError(err) + re.Less(string(s), string(splitKey)) + re.Less(string(splitKey), string(e)) // split equal key s = codec.EncodeBytes([]byte{116, 128, 0, 0, 0, 0, 0, 0, 1}) e = codec.EncodeBytes([]byte{116, 128, 0, 0, 0, 0, 0, 0, 1, 1}) for i := 0; i <= 1000; i++ { - c.Assert(s, Less, e) + re.Less(string(s), string(e)) splitKey, err = GenerateTiDBEncodedSplitKey(s, e) - c.Assert(err, IsNil) - c.Assert(s, Less, splitKey) - c.Assert(splitKey, Less, e) + re.NoError(err) + re.Less(string(s), string(splitKey)) + re.Less(string(splitKey), string(e)) e = splitKey } } From d263b8586123387d219bdccf9a1733632ab0d9c2 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 14 Jun 2022 16:04:33 +0800 Subject: [PATCH 46/82] tests: testify the pd-ctl tests (#5154) ref tikv/pd#4813 Testify the pd-ctl tests. 
Signed-off-by: JmPotato --- tests/pdctl/cluster/cluster_test.go | 47 +- tests/pdctl/completion/completion_test.go | 17 +- tests/pdctl/config/config_test.go | 529 +++++++++--------- tests/pdctl/global_test.go | 25 +- tests/pdctl/health/health_test.go | 28 +- tests/pdctl/helper.go | 66 ++- tests/pdctl/hot/hot_test.go | 173 +++--- tests/pdctl/label/label_test.go | 39 +- tests/pdctl/log/log_test.go | 78 ++- tests/pdctl/member/member_test.go | 67 +-- tests/pdctl/operator/operator_test.go | 115 ++-- tests/pdctl/region/region_test.go | 75 ++- tests/pdctl/scheduler/scheduler_test.go | 141 ++--- tests/pdctl/store/store_test.go | 314 ++++++----- tests/pdctl/tso/tso_test.go | 26 +- tests/pdctl/unsafe/unsafe_operation_test.go | 29 +- tests/server/api/api_test.go | 22 +- .../server/storage/hot_region_storage_test.go | 22 +- 18 files changed, 877 insertions(+), 936 deletions(-) diff --git a/tests/pdctl/cluster/cluster_test.go b/tests/pdctl/cluster/cluster_test.go index 4b8cceb3bc5..9d69f89dcef 100644 --- a/tests/pdctl/cluster/cluster_test.go +++ b/tests/pdctl/cluster/cluster_test.go @@ -21,32 +21,25 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" clusterpkg "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/tests" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&clusterTestSuite{}) - -type clusterTestSuite struct{} - -func (s *clusterTestSuite) TestClusterAndPing(c *C) { +func TestClusterAndPing(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() err = cluster.GetServer(cluster.GetLeader()).BootstrapCluster() - c.Assert(err, IsNil) + re.NoError(err) pdAddr := cluster.GetConfig().GetClientURL() i := strings.Index(pdAddr, "//") pdAddr = pdAddr[i+2:] @@ -56,42 +49,42 @@ func (s *clusterTestSuite) TestClusterAndPing(c *C) { // cluster args := []string{"-u", pdAddr, "cluster"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) ci := &metapb.Cluster{} - c.Assert(json.Unmarshal(output, ci), IsNil) - c.Assert(ci, DeepEquals, cluster.GetCluster()) + re.NoError(json.Unmarshal(output, ci)) + re.Equal(cluster.GetCluster(), ci) // cluster info args = []string{"-u", pdAddr, "cluster"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) ci = &metapb.Cluster{} - c.Assert(json.Unmarshal(output, ci), IsNil) - c.Assert(ci, DeepEquals, cluster.GetCluster()) + re.NoError(json.Unmarshal(output, ci)) + re.Equal(cluster.GetCluster(), ci) // cluster status args = []string{"-u", pdAddr, "cluster", "status"} output, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) cs := &clusterpkg.Status{} - c.Assert(json.Unmarshal(output, cs), IsNil) + re.NoError(json.Unmarshal(output, cs)) clusterStatus, err := cluster.GetClusterStatus() - c.Assert(err, IsNil) - c.Assert(clusterStatus.RaftBootstrapTime.Equal(cs.RaftBootstrapTime), IsTrue) + re.NoError(err) + re.True(clusterStatus.RaftBootstrapTime.Equal(cs.RaftBootstrapTime)) // ref: https://github.com/onsi/gomega/issues/264 clusterStatus.RaftBootstrapTime = time.Time{} cs.RaftBootstrapTime = time.Time{} - c.Assert(cs, DeepEquals, clusterStatus) + re.Equal(clusterStatus, cs) // ping args = []string{"-u", pdAddr, "ping"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(output, NotNil) + re.NoError(err) + re.NotNil(output) // does not exist args = []string{"-u", pdAddr, "--cacert=ca.pem", "cluster"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, ErrorMatches, ".*no such file or directory.*") + re.Contains(err.Error(), "no such file or directory") } diff --git a/tests/pdctl/completion/completion_test.go b/tests/pdctl/completion/completion_test.go index f7cc30bbe05..c64615df9e1 100644 --- a/tests/pdctl/completion/completion_test.go +++ b/tests/pdctl/completion/completion_test.go @@ -17,29 +17,22 @@ package completion_test import ( "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&completionTestSuite{}) - -type completionTestSuite struct{} - -func (s *completionTestSuite) TestCompletion(c *C) { +func TestCompletion(t *testing.T) { + re := require.New(t) cmd := pdctlCmd.GetRootCmd() // completion command args := []string{"completion", "bash"} _, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) // completion command args = []string{"completion", "zsh"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) } diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index 311a0e7db99..f5acd3fd3ff 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -15,18 +15,16 @@ package config_test import ( - "bytes" "context" "encoding/json" "os" "reflect" - "strings" "testing" "time" "github.com/coreos/go-semver/semver" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/schedule/placement" @@ -35,35 +33,28 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&configTestSuite{}) - -type configTestSuite struct{} - type testItem struct { name string value interface{} read func(scheduleConfig *config.ScheduleConfig) interface{} } -func (t *testItem) judge(c *C, scheduleConfigs ...*config.ScheduleConfig) { +func (t *testItem) judge(re *require.Assertions, scheduleConfigs ...*config.ScheduleConfig) { value := t.value for _, scheduleConfig := range scheduleConfigs { - c.Assert(scheduleConfig, NotNil) - c.Assert(reflect.TypeOf(t.read(scheduleConfig)), Equals, reflect.TypeOf(value)) + re.NotNil(scheduleConfig) + re.IsType(value, t.read(scheduleConfig)) } } -func (s *configTestSuite) TestConfig(c *C) { +func TestConfig(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -73,17 +64,17 @@ func (s *configTestSuite) TestConfig(c *C) { State: metapb.StoreState_Up, } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(c, svr, store) + pdctl.MustPutStore(re, svr, store) defer cluster.Destroy() // config show args := []string{"-u", pdAddr, "config", "show"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) cfg := config.Config{} - c.Assert(json.Unmarshal(output, &cfg), IsNil) + re.NoError(json.Unmarshal(output, &cfg)) scheduleConfig := svr.GetScheduleConfig() // hidden config @@ -97,127 +88,127 @@ func (s *configTestSuite) TestConfig(c *C) { scheduleConfig.EnableRemoveExtraReplica = false scheduleConfig.EnableLocationReplacement = false scheduleConfig.StoreLimitMode = "" - c.Assert(scheduleConfig.MaxMergeRegionKeys, Equals, uint64(0)) + re.Equal(uint64(0), scheduleConfig.MaxMergeRegionKeys) // The result of config show doesn't be 0. scheduleConfig.MaxMergeRegionKeys = scheduleConfig.GetMaxMergeRegionKeys() - c.Assert(&cfg.Schedule, DeepEquals, scheduleConfig) - c.Assert(&cfg.Replication, DeepEquals, svr.GetReplicationConfig()) + re.Equal(scheduleConfig, &cfg.Schedule) + re.Equal(svr.GetReplicationConfig(), &cfg.Replication) // config set trace-region-flow args = []string{"-u", pdAddr, "config", "set", "trace-region-flow", "false"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(svr.GetPDServerConfig().TraceRegionFlow, IsFalse) + re.NoError(err) + re.False(svr.GetPDServerConfig().TraceRegionFlow) args = []string{"-u", pdAddr, "config", "set", "flow-round-by-digit", "10"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(svr.GetPDServerConfig().FlowRoundByDigit, Equals, 10) + re.NoError(err) + re.Equal(10, svr.GetPDServerConfig().FlowRoundByDigit) args = []string{"-u", pdAddr, "config", "set", "flow-round-by-digit", "-10"} _, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, NotNil) + re.Error(err) // config show schedule args = []string{"-u", pdAddr, "config", "show", "schedule"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) scheduleCfg := config.ScheduleConfig{} - c.Assert(json.Unmarshal(output, &scheduleCfg), IsNil) + re.NoError(json.Unmarshal(output, &scheduleCfg)) scheduleConfig = svr.GetScheduleConfig() scheduleConfig.MaxMergeRegionKeys = scheduleConfig.GetMaxMergeRegionKeys() - c.Assert(&scheduleCfg, DeepEquals, scheduleConfig) + re.Equal(scheduleConfig, &scheduleCfg) - c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionSize), Equals, 20) - c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionKeys), Equals, 0) - c.Assert(int(svr.GetScheduleConfig().GetMaxMergeRegionKeys()), Equals, 20*10000) + re.Equal(20, int(svr.GetScheduleConfig().MaxMergeRegionSize)) + re.Equal(0, int(svr.GetScheduleConfig().MaxMergeRegionKeys)) + re.Equal(20*10000, int(svr.GetScheduleConfig().GetMaxMergeRegionKeys())) // set max-merge-region-size to 40MB args = []string{"-u", pdAddr, "config", "set", "max-merge-region-size", "40"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionSize), Equals, 40) - c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionKeys), Equals, 0) - c.Assert(int(svr.GetScheduleConfig().GetMaxMergeRegionKeys()), Equals, 40*10000) + re.NoError(err) + re.Equal(40, int(svr.GetScheduleConfig().MaxMergeRegionSize)) + re.Equal(0, int(svr.GetScheduleConfig().MaxMergeRegionKeys)) + re.Equal(40*10000, int(svr.GetScheduleConfig().GetMaxMergeRegionKeys())) args = []string{"-u", pdAddr, "config", "set", "max-merge-region-keys", "200000"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(int(svr.GetScheduleConfig().MaxMergeRegionKeys), Equals, 20*10000) - c.Assert(int(svr.GetScheduleConfig().GetMaxMergeRegionKeys()), Equals, 20*10000) + re.NoError(err) + re.Equal(20*10000, int(svr.GetScheduleConfig().MaxMergeRegionKeys)) + re.Equal(20*10000, int(svr.GetScheduleConfig().GetMaxMergeRegionKeys())) // config show replication args = []string{"-u", pdAddr, "config", "show", "replication"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) replicationCfg := config.ReplicationConfig{} - c.Assert(json.Unmarshal(output, &replicationCfg), IsNil) - c.Assert(&replicationCfg, DeepEquals, svr.GetReplicationConfig()) + re.NoError(json.Unmarshal(output, &replicationCfg)) + re.Equal(svr.GetReplicationConfig(), &replicationCfg) // config show cluster-version args1 := []string{"-u", pdAddr, "config", "show", "cluster-version"} output, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) + re.NoError(err) clusterVersion := semver.Version{} - c.Assert(json.Unmarshal(output, &clusterVersion), IsNil) - c.Assert(clusterVersion, DeepEquals, svr.GetClusterVersion()) + re.NoError(json.Unmarshal(output, &clusterVersion)) + re.Equal(svr.GetClusterVersion(), clusterVersion) // config set cluster-version args2 := []string{"-u", pdAddr, "config", "set", "cluster-version", "2.1.0-rc.5"} _, err = pdctl.ExecuteCommand(cmd, args2...) - c.Assert(err, IsNil) - c.Assert(clusterVersion, Not(DeepEquals), svr.GetClusterVersion()) + re.NoError(err) + re.NotEqual(svr.GetClusterVersion(), clusterVersion) output, err = pdctl.ExecuteCommand(cmd, args1...) 
- c.Assert(err, IsNil) + re.NoError(err) clusterVersion = semver.Version{} - c.Assert(json.Unmarshal(output, &clusterVersion), IsNil) - c.Assert(clusterVersion, DeepEquals, svr.GetClusterVersion()) + re.NoError(json.Unmarshal(output, &clusterVersion)) + re.Equal(svr.GetClusterVersion(), clusterVersion) // config show label-property args1 = []string{"-u", pdAddr, "config", "show", "label-property"} output, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) + re.NoError(err) labelPropertyCfg := config.LabelPropertyConfig{} - c.Assert(json.Unmarshal(output, &labelPropertyCfg), IsNil) - c.Assert(labelPropertyCfg, DeepEquals, svr.GetLabelProperty()) + re.NoError(json.Unmarshal(output, &labelPropertyCfg)) + re.Equal(svr.GetLabelProperty(), labelPropertyCfg) // config set label-property args2 = []string{"-u", pdAddr, "config", "set", "label-property", "reject-leader", "zone", "cn"} _, err = pdctl.ExecuteCommand(cmd, args2...) - c.Assert(err, IsNil) - c.Assert(labelPropertyCfg, Not(DeepEquals), svr.GetLabelProperty()) + re.NoError(err) + re.NotEqual(svr.GetLabelProperty(), labelPropertyCfg) output, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) + re.NoError(err) labelPropertyCfg = config.LabelPropertyConfig{} - c.Assert(json.Unmarshal(output, &labelPropertyCfg), IsNil) - c.Assert(labelPropertyCfg, DeepEquals, svr.GetLabelProperty()) + re.NoError(json.Unmarshal(output, &labelPropertyCfg)) + re.Equal(svr.GetLabelProperty(), labelPropertyCfg) // config delete label-property args3 := []string{"-u", pdAddr, "config", "delete", "label-property", "reject-leader", "zone", "cn"} _, err = pdctl.ExecuteCommand(cmd, args3...) - c.Assert(err, IsNil) - c.Assert(labelPropertyCfg, Not(DeepEquals), svr.GetLabelProperty()) + re.NoError(err) + re.NotEqual(svr.GetLabelProperty(), labelPropertyCfg) output, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) + re.NoError(err) labelPropertyCfg = config.LabelPropertyConfig{} - c.Assert(json.Unmarshal(output, &labelPropertyCfg), IsNil) - c.Assert(labelPropertyCfg, DeepEquals, svr.GetLabelProperty()) + re.NoError(json.Unmarshal(output, &labelPropertyCfg)) + re.Equal(svr.GetLabelProperty(), labelPropertyCfg) // config set min-resolved-ts-persistence-interval args = []string{"-u", pdAddr, "config", "set", "min-resolved-ts-persistence-interval", "1s"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(svr.GetPDServerConfig().MinResolvedTSPersistenceInterval, Equals, typeutil.NewDuration(time.Second)) + re.NoError(err) + re.Equal(typeutil.NewDuration(time.Second), svr.GetPDServerConfig().MinResolvedTSPersistenceInterval) // config set max-store-preparing-time 10m args = []string{"-u", pdAddr, "config", "set", "max-store-preparing-time", "10m"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(svr.GetScheduleConfig().MaxStorePreparingTime, Equals, typeutil.NewDuration(10*time.Minute)) + re.NoError(err) + re.Equal(typeutil.NewDuration(10*time.Minute), svr.GetScheduleConfig().MaxStorePreparingTime) args = []string{"-u", pdAddr, "config", "set", "max-store-preparing-time", "0s"} _, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) - c.Assert(svr.GetScheduleConfig().MaxStorePreparingTime, Equals, typeutil.NewDuration(0)) + re.NoError(err) + re.Equal(typeutil.NewDuration(0), svr.GetScheduleConfig().MaxStorePreparingTime) // test config read and write testItems := []testItem{ @@ -242,43 +233,44 @@ func (s *configTestSuite) TestConfig(c *C) { // write args1 = []string{"-u", pdAddr, "config", "set", item.name, reflect.TypeOf(item.value).String()} _, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) + re.NoError(err) // read args2 = []string{"-u", pdAddr, "config", "show"} output, err = pdctl.ExecuteCommand(cmd, args2...) - c.Assert(err, IsNil) + re.NoError(err) cfg = config.Config{} - c.Assert(json.Unmarshal(output, &cfg), IsNil) + re.NoError(json.Unmarshal(output, &cfg)) // judge - item.judge(c, &cfg.Schedule, svr.GetScheduleConfig()) + item.judge(re, &cfg.Schedule, svr.GetScheduleConfig()) } // test error or deprecated config name args1 = []string{"-u", pdAddr, "config", "set", "foo-bar", "1"} output, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "not found"), IsTrue) + re.NoError(err) + re.Contains(string(output), "not found") args1 = []string{"-u", pdAddr, "config", "set", "disable-remove-down-replica", "true"} output, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "already been deprecated"), IsTrue) + re.NoError(err) + re.Contains(string(output), "already been deprecated") // set enable-placement-rules twice, make sure it does not return error. args1 = []string{"-u", pdAddr, "config", "set", "enable-placement-rules", "true"} _, err = pdctl.ExecuteCommand(cmd, args1...) - c.Assert(err, IsNil) + re.NoError(err) args1 = []string{"-u", pdAddr, "config", "set", "enable-placement-rules", "true"} _, err = pdctl.ExecuteCommand(cmd, args1...) 
- c.Assert(err, IsNil) + re.NoError(err) } -func (s *configTestSuite) TestPlacementRules(c *C) { +func TestPlacementRules(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -289,23 +281,22 @@ func (s *configTestSuite) TestPlacementRules(c *C) { LastHeartbeat: time.Now().UnixNano(), } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(c, svr, store) + pdctl.MustPutStore(re, svr, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") // test show var rules []placement.Rule output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "show") - c.Assert(err, IsNil) - err = json.Unmarshal(output, &rules) - c.Assert(err, IsNil) - c.Assert(rules, HasLen, 1) - c.Assert(rules[0].Key(), Equals, [2]string{"pd", "default"}) + re.NoError(err) + re.NoError(json.Unmarshal(output, &rules)) + re.Len(rules, 1) + re.Equal([2]string{"pd", "default"}, rules[0].Key()) f, _ := os.CreateTemp("/tmp", "pd_tests") fname := f.Name() @@ -313,11 +304,11 @@ func (s *configTestSuite) TestPlacementRules(c *C) { // test load _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "load", "--out="+fname) - c.Assert(err, IsNil) + re.NoError(err) b, _ := os.ReadFile(fname) - c.Assert(json.Unmarshal(b, &rules), IsNil) - c.Assert(rules, HasLen, 1) - c.Assert(rules[0].Key(), Equals, [2]string{"pd", "default"}) + re.NoError(json.Unmarshal(b, &rules)) + re.Len(rules, 1) + re.Equal([2]string{"pd", "default"}, rules[0].Key()) // test save rules = append(rules, placement.Rule{ @@ -334,39 +325,38 @@ func (s *configTestSuite) TestPlacementRules(c *C) { b, _ = json.Marshal(rules) os.WriteFile(fname, b, 0600) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "save", "--in="+fname) - c.Assert(err, IsNil) + re.NoError(err) // test show group var rules2 []placement.Rule output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "show", "--group=pd") - c.Assert(err, IsNil) - err = json.Unmarshal(output, &rules2) - c.Assert(err, IsNil) - c.Assert(rules2, HasLen, 2) - c.Assert(rules2[0].Key(), Equals, [2]string{"pd", "default"}) - c.Assert(rules2[1].Key(), Equals, [2]string{"pd", "test1"}) + re.NoError(err) + re.NoError(json.Unmarshal(output, &rules2)) + re.Len(rules2, 2) + re.Equal([2]string{"pd", "default"}, rules2[0].Key()) + re.Equal([2]string{"pd", "test1"}, rules2[1].Key()) // test delete rules[0].Count = 0 b, _ = json.Marshal(rules) os.WriteFile(fname, b, 0600) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "save", "--in="+fname) - c.Assert(err, IsNil) + re.NoError(err) output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "show", "--group=pd") - c.Assert(err, IsNil) - err = json.Unmarshal(output, &rules) - c.Assert(err, IsNil) - c.Assert(rules, HasLen, 1) - c.Assert(rules[0].Key(), Equals, [2]string{"pd", 
"test1"}) + re.NoError(err) + re.NoError(json.Unmarshal(output, &rules)) + re.Len(rules, 1) + re.Equal([2]string{"pd", "test1"}, rules[0].Key()) } -func (s *configTestSuite) TestPlacementRuleGroups(c *C) { +func TestPlacementRuleGroups(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -377,60 +367,59 @@ func (s *configTestSuite) TestPlacementRuleGroups(c *C) { LastHeartbeat: time.Now().UnixNano(), } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(c, svr, store) + pdctl.MustPutStore(re, svr, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") // test show var group placement.RuleGroup output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "show", "pd") - c.Assert(err, IsNil) - err = json.Unmarshal(output, &group) - c.Assert(err, IsNil) - c.Assert(group, DeepEquals, placement.RuleGroup{ID: "pd"}) + re.NoError(err) + re.NoError(json.Unmarshal(output, &group)) + re.Equal(placement.RuleGroup{ID: "pd"}, group) // test set output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "set", "pd", "42", "true") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "set", "group2", "100", "false") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") // show all var groups []placement.RuleGroup output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "show") - c.Assert(err, IsNil) - err = json.Unmarshal(output, &groups) - c.Assert(err, IsNil) - c.Assert(groups, DeepEquals, []placement.RuleGroup{ + re.NoError(err) + re.NoError(json.Unmarshal(output, &groups)) + re.Equal([]placement.RuleGroup{ {ID: "pd", Index: 42, Override: true}, {ID: "group2", Index: 100, Override: false}, - }) + }, groups) // delete output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "delete", "group2") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") // show again output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "show", "group2") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "404"), IsTrue) + re.NoError(err) + re.Contains(string(output), "404") } -func (s *configTestSuite) TestPlacementRuleBundle(c *C) { +func TestPlacementRuleBundle(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = 
cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -441,25 +430,24 @@ func (s *configTestSuite) TestPlacementRuleBundle(c *C) { LastHeartbeat: time.Now().UnixNano(), } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(c, svr, store) + pdctl.MustPutStore(re, svr, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") // test get var bundle placement.GroupBundle output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "get", "pd") - c.Assert(err, IsNil) - err = json.Unmarshal(output, &bundle) - c.Assert(err, IsNil) - c.Assert(bundle, DeepEquals, placement.GroupBundle{ID: "pd", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pd", ID: "default", Role: "voter", Count: 3}}}) + re.NoError(err) + re.NoError(json.Unmarshal(output, &bundle)) + re.Equal(placement.GroupBundle{ID: "pd", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pd", ID: "default", Role: "voter", Count: 3}}}, bundle) f, err := os.CreateTemp("/tmp", "pd_tests") - c.Assert(err, IsNil) + re.NoError(err) fname := f.Name() f.Close() defer func() { @@ -469,106 +457,107 @@ func (s *configTestSuite) TestPlacementRuleBundle(c *C) { // test load var bundles []placement.GroupBundle _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "load", "--out="+fname) - c.Assert(err, IsNil) + re.NoError(err) b, _ := os.ReadFile(fname) - c.Assert(json.Unmarshal(b, &bundles), IsNil) - c.Assert(bundles, HasLen, 1) - c.Assert(bundles[0], DeepEquals, placement.GroupBundle{ID: "pd", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pd", ID: "default", Role: "voter", Count: 3}}}) + re.NoError(json.Unmarshal(b, &bundles)) + re.Len(bundles, 1) + re.Equal(placement.GroupBundle{ID: "pd", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pd", ID: "default", Role: "voter", Count: 3}}}, bundles[0]) // test set bundle.ID = "pe" bundle.Rules[0].GroupID = "pe" b, err = json.Marshal(bundle) - c.Assert(err, IsNil) - c.Assert(os.WriteFile(fname, b, 0600), IsNil) + re.NoError(err) + re.NoError(os.WriteFile(fname, b, 0600)) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "set", "--in="+fname) - c.Assert(err, IsNil) + re.NoError(err) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "load", "--out="+fname) - c.Assert(err, IsNil) + re.NoError(err) b, _ = os.ReadFile(fname) - c.Assert(json.Unmarshal(b, &bundles), IsNil) - assertBundles(bundles, []placement.GroupBundle{ + re.NoError(json.Unmarshal(b, &bundles)) + assertBundles(re, bundles, []placement.GroupBundle{ {ID: "pd", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pd", ID: "default", Role: "voter", Count: 3}}}, {ID: "pe", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pe", ID: "default", Role: "voter", Count: 3}}}, - }, c) + }) // test delete _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "delete", "pd") - c.Assert(err, IsNil) + re.NoError(err) 
_, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "load", "--out="+fname) - c.Assert(err, IsNil) + re.NoError(err) b, _ = os.ReadFile(fname) - c.Assert(json.Unmarshal(b, &bundles), IsNil) - assertBundles(bundles, []placement.GroupBundle{ + re.NoError(json.Unmarshal(b, &bundles)) + assertBundles(re, bundles, []placement.GroupBundle{ {ID: "pe", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pe", ID: "default", Role: "voter", Count: 3}}}, - }, c) + }) // test delete regexp bundle.ID = "pf" bundle.Rules = []*placement.Rule{{GroupID: "pf", ID: "default", Role: "voter", Count: 3}} b, err = json.Marshal(bundle) - c.Assert(err, IsNil) - c.Assert(os.WriteFile(fname, b, 0600), IsNil) + re.NoError(err) + re.NoError(os.WriteFile(fname, b, 0600)) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "set", "--in="+fname) - c.Assert(err, IsNil) + re.NoError(err) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "delete", "--regexp", ".*f") - c.Assert(err, IsNil) + re.NoError(err) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "load", "--out="+fname) - c.Assert(err, IsNil) + re.NoError(err) b, _ = os.ReadFile(fname) - c.Assert(json.Unmarshal(b, &bundles), IsNil) - assertBundles(bundles, []placement.GroupBundle{ + re.NoError(json.Unmarshal(b, &bundles)) + assertBundles(re, bundles, []placement.GroupBundle{ {ID: "pe", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pe", ID: "default", Role: "voter", Count: 3}}}, - }, c) + }) // test save bundle.Rules = []*placement.Rule{{GroupID: "pf", ID: "default", Role: "voter", Count: 3}} bundles = append(bundles, bundle) b, err = json.Marshal(bundles) - c.Assert(err, IsNil) - c.Assert(os.WriteFile(fname, b, 0600), IsNil) + re.NoError(err) + re.NoError(os.WriteFile(fname, b, 0600)) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "save", "--in="+fname) - c.Assert(err, IsNil) + re.NoError(err) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "load", "--out="+fname) - c.Assert(err, IsNil) + re.NoError(err) b, err = os.ReadFile(fname) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(b, &bundles), IsNil) - assertBundles(bundles, []placement.GroupBundle{ + re.NoError(err) + re.NoError(json.Unmarshal(b, &bundles)) + assertBundles(re, bundles, []placement.GroupBundle{ {ID: "pe", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pe", ID: "default", Role: "voter", Count: 3}}}, {ID: "pf", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pf", ID: "default", Role: "voter", Count: 3}}}, - }, c) + }) // partial update, so still one group is left, no error bundles = []placement.GroupBundle{{ID: "pe", Rules: []*placement.Rule{}}} b, err = json.Marshal(bundles) - c.Assert(err, IsNil) - c.Assert(os.WriteFile(fname, b, 0600), IsNil) + re.NoError(err) + re.NoError(os.WriteFile(fname, b, 0600)) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "save", "--in="+fname, "--partial") - c.Assert(err, IsNil) + re.NoError(err) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-bundle", "load", "--out="+fname) - c.Assert(err, IsNil) + re.NoError(err) b, err = os.ReadFile(fname) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(b, &bundles), IsNil) - assertBundles(bundles, []placement.GroupBundle{ + re.NoError(err) 
+ re.NoError(json.Unmarshal(b, &bundles)) + assertBundles(re, bundles, []placement.GroupBundle{ {ID: "pf", Index: 0, Override: false, Rules: []*placement.Rule{{GroupID: "pf", ID: "default", Role: "voter", Count: 3}}}, - }, c) + }) } -func (s *configTestSuite) TestReplicationMode(c *C) { +func TestReplicationMode(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -579,9 +568,9 @@ func (s *configTestSuite) TestReplicationMode(c *C) { LastHeartbeat: time.Now().UnixNano(), } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(c, svr, store) + pdctl.MustPutStore(re, svr, store) defer cluster.Destroy() conf := config.ReplicationModeConfig{ @@ -593,42 +582,43 @@ func (s *configTestSuite) TestReplicationMode(c *C) { } check := func() { output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "show", "replication-mode") - c.Assert(err, IsNil) + re.NoError(err) var conf2 config.ReplicationModeConfig - json.Unmarshal(output, &conf2) - c.Assert(conf2, DeepEquals, conf) + re.NoError(json.Unmarshal(output, &conf2)) + re.Equal(conf, conf2) } check() _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "replication-mode", "dr-auto-sync") - c.Assert(err, IsNil) + re.NoError(err) conf.ReplicationMode = "dr-auto-sync" check() _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "replication-mode", "dr-auto-sync", "label-key", "foobar") - c.Assert(err, IsNil) + re.NoError(err) conf.DRAutoSync.LabelKey = "foobar" check() _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "replication-mode", "dr-auto-sync", "primary-replicas", "5") - c.Assert(err, IsNil) + re.NoError(err) conf.DRAutoSync.PrimaryReplicas = 5 check() _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "replication-mode", "dr-auto-sync", "wait-store-timeout", "10m") - c.Assert(err, IsNil) + re.NoError(err) conf.DRAutoSync.WaitStoreTimeout = typeutil.NewDuration(time.Minute * 10) check() } -func (s *configTestSuite) TestUpdateDefaultReplicaConfig(c *C) { +func TestUpdateDefaultReplicaConfig(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -638,83 +628,77 @@ func (s *configTestSuite) TestUpdateDefaultReplicaConfig(c *C) { State: metapb.StoreState_Up, } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(c, svr, store) + pdctl.MustPutStore(re, svr, store) defer cluster.Destroy() checkMaxReplicas := func(expect uint64) { args := []string{"-u", pdAddr, "config", "show", "replication"} output, err := pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) replicationCfg := config.ReplicationConfig{} - c.Assert(json.Unmarshal(output, &replicationCfg), IsNil) - c.Assert(replicationCfg.MaxReplicas, Equals, expect) + re.NoError(json.Unmarshal(output, &replicationCfg)) + re.Equal(expect, replicationCfg.MaxReplicas) } checkLocaltionLabels := func(expect int) { args := []string{"-u", pdAddr, "config", "show", "replication"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) replicationCfg := config.ReplicationConfig{} - c.Assert(json.Unmarshal(output, &replicationCfg), IsNil) - c.Assert(replicationCfg.LocationLabels, HasLen, expect) + re.NoError(json.Unmarshal(output, &replicationCfg)) + re.Len(replicationCfg.LocationLabels, expect) } checkRuleCount := func(expect int) { args := []string{"-u", pdAddr, "config", "placement-rules", "show", "--group", "pd", "--id", "default"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) rule := placement.Rule{} - c.Assert(json.Unmarshal(output, &rule), IsNil) - c.Assert(rule.Count, Equals, expect) + re.NoError(json.Unmarshal(output, &rule)) + re.Equal(expect, rule.Count) } checkRuleLocationLabels := func(expect int) { args := []string{"-u", pdAddr, "config", "placement-rules", "show", "--group", "pd", "--id", "default"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) rule := placement.Rule{} - c.Assert(json.Unmarshal(output, &rule), IsNil) - c.Assert(rule.LocationLabels, HasLen, expect) + re.NoError(json.Unmarshal(output, &rule)) + re.Len(rule.LocationLabels, expect) } // update successfully when placement rules is not enabled. output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "max-replicas", "2") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") checkMaxReplicas(2) output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "location-labels", "zone,host") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") checkLocaltionLabels(2) checkRuleLocationLabels(2) // update successfully when only one default rule exists. output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "max-replicas", "3") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") checkMaxReplicas(3) checkRuleCount(3) output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "location-labels", "host") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") checkLocaltionLabels(1) checkRuleLocationLabels(1) // update unsuccessfully when many rule exists. 
- f, _ := os.CreateTemp("/tmp", "pd_tests") - fname := f.Name() - f.Close() - defer func() { - os.RemoveAll(fname) - }() - + fname := t.TempDir() rules := []placement.Rule{ { GroupID: "pd", @@ -724,28 +708,29 @@ func (s *configTestSuite) TestUpdateDefaultReplicaConfig(c *C) { }, } b, err := json.Marshal(rules) - c.Assert(err, IsNil) + re.NoError(err) os.WriteFile(fname, b, 0600) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "save", "--in="+fname) - c.Assert(err, IsNil) + re.NoError(err) checkMaxReplicas(3) checkRuleCount(3) _, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "max-replicas", "4") - c.Assert(err, IsNil) + re.NoError(err) checkMaxReplicas(4) checkRuleCount(4) checkLocaltionLabels(1) checkRuleLocationLabels(1) } -func (s *configTestSuite) TestPDServerConfig(c *C) { +func TestPDServerConfig(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -756,52 +741,52 @@ func (s *configTestSuite) TestPDServerConfig(c *C) { LastHeartbeat: time.Now().UnixNano(), } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(c, svr, store) + pdctl.MustPutStore(re, svr, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "show", "server") - c.Assert(err, IsNil) + re.NoError(err) var conf config.PDServerConfig - json.Unmarshal(output, &conf) - - c.Assert(conf.UseRegionStorage, Equals, bool(true)) - c.Assert(conf.MaxResetTSGap.Duration, Equals, 24*time.Hour) - c.Assert(conf.KeyType, Equals, "table") - c.Assert(conf.RuntimeServices, DeepEquals, typeutil.StringSlice([]string{})) - c.Assert(conf.MetricStorage, Equals, "") - c.Assert(conf.DashboardAddress, Equals, "auto") - c.Assert(conf.FlowRoundByDigit, Equals, int(3)) + re.NoError(json.Unmarshal(output, &conf)) + + re.True(conf.UseRegionStorage) + re.Equal(24*time.Hour, conf.MaxResetTSGap.Duration) + re.Equal("table", conf.KeyType) + re.Equal(typeutil.StringSlice([]string{}), conf.RuntimeServices) + re.Equal("", conf.MetricStorage) + re.Equal("auto", conf.DashboardAddress) + re.Equal(int(3), conf.FlowRoundByDigit) } -func assertBundles(a, b []placement.GroupBundle, c *C) { - c.Assert(len(a), Equals, len(b)) +func assertBundles(re *require.Assertions, a, b []placement.GroupBundle) { + re.Equal(len(a), len(b)) for i := 0; i < len(a); i++ { - assertBundle(a[i], b[i], c) + assertBundle(re, a[i], b[i]) } } -func assertBundle(a, b placement.GroupBundle, c *C) { - c.Assert(a.ID, Equals, b.ID) - c.Assert(a.Index, Equals, b.Index) - c.Assert(a.Override, Equals, b.Override) - c.Assert(len(a.Rules), Equals, len(b.Rules)) +func assertBundle(re *require.Assertions, a, b placement.GroupBundle) { + re.Equal(a.ID, b.ID) + re.Equal(a.Index, b.Index) + re.Equal(a.Override, b.Override) + re.Equal(len(a.Rules), len(b.Rules)) for i := 0; i < len(a.Rules); i++ { - assertRule(a.Rules[i], b.Rules[i], c) + assertRule(re, a.Rules[i], b.Rules[i]) } } -func assertRule(a, b *placement.Rule, c *C) { - c.Assert(a.GroupID, Equals, b.GroupID) - c.Assert(a.ID, Equals, b.ID) - c.Assert(a.Index, Equals, b.Index) - c.Assert(a.Override, 
Equals, b.Override) - c.Assert(bytes.Equal(a.StartKey, b.StartKey), IsTrue) - c.Assert(bytes.Equal(a.EndKey, b.EndKey), IsTrue) - c.Assert(a.Role, Equals, b.Role) - c.Assert(a.Count, Equals, b.Count) - c.Assert(a.LabelConstraints, DeepEquals, b.LabelConstraints) - c.Assert(a.LocationLabels, DeepEquals, b.LocationLabels) - c.Assert(a.IsolationLevel, Equals, b.IsolationLevel) +func assertRule(re *require.Assertions, a, b *placement.Rule) { + re.Equal(a.GroupID, b.GroupID) + re.Equal(a.ID, b.ID) + re.Equal(a.Index, b.Index) + re.Equal(a.Override, b.Override) + re.Equal(a.StartKey, b.StartKey) + re.Equal(a.EndKey, b.EndKey) + re.Equal(a.Role, b.Role) + re.Equal(a.Count, b.Count) + re.Equal(a.LabelConstraints, b.LabelConstraints) + re.Equal(a.LocationLabels, b.LocationLabels) + re.Equal(a.IsolationLevel, b.IsolationLevel) } diff --git a/tests/pdctl/global_test.go b/tests/pdctl/global_test.go index de165eea600..c182c739403 100644 --- a/tests/pdctl/global_test.go +++ b/tests/pdctl/global_test.go @@ -20,8 +20,8 @@ import ( "net/http" "testing" - . "github.com/pingcap/check" "github.com/pingcap/log" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/apiutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" @@ -29,15 +29,8 @@ import ( "go.uber.org/zap" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&globalTestSuite{}) - -type globalTestSuite struct{} - -func (s *globalTestSuite) TestSendAndGetComponent(c *C) { +func TestSendAndGetComponent(t *testing.T) { + re := require.New(t) handler := func(ctx context.Context, s *server.Server) (http.Handler, server.ServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) { @@ -46,7 +39,7 @@ func (s *globalTestSuite) TestSendAndGetComponent(c *C) { log.Info("header", zap.String("key", k)) } log.Info("component", zap.String("component", component)) - c.Assert(component, Equals, "pdctl") + re.Equal("pdctl", component) fmt.Fprint(w, component) }) info := server.ServiceGroup{ @@ -54,12 +47,12 @@ func (s *globalTestSuite) TestSendAndGetComponent(c *C) { } return mux, info, nil } - cfg := server.NewTestSingleConfig(checkerWithNilAssert(c)) + cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) ctx, cancel := context.WithCancel(context.Background()) svr, err := server.CreateServer(ctx, cfg, handler) - c.Assert(err, IsNil) + re.NoError(err) err = svr.Run() - c.Assert(err, IsNil) + re.NoError(err) pdAddr := svr.GetAddr() defer func() { cancel() @@ -70,6 +63,6 @@ func (s *globalTestSuite) TestSendAndGetComponent(c *C) { cmd := cmd.GetRootCmd() args := []string{"-u", pdAddr, "health"} output, err := ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(string(output), Equals, "pdctl\n") + re.NoError(err) + re.Equal("pdctl\n", string(output)) } diff --git a/tests/pdctl/health/health_test.go b/tests/pdctl/health/health_test.go index 06e287dcb36..bc808a36750 100644 --- a/tests/pdctl/health/health_test.go +++ b/tests/pdctl/health/health_test.go @@ -19,7 +19,7 @@ import ( "encoding/json" "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/tests" @@ -27,31 +27,24 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&healthTestSuite{}) - -type healthTestSuite struct{} - -func (s *healthTestSuite) TestHealth(c *C) { +func TestHealth(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() tc, err := tests.NewTestCluster(ctx, 3) - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) pdAddr := tc.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() defer tc.Destroy() client := tc.GetEtcdClient() members, err := cluster.GetMembers(client) - c.Assert(err, IsNil) + re.NoError(err) healthMembers := cluster.CheckHealth(tc.GetHTTPClient(), members) healths := []api.Health{} for _, member := range members { @@ -70,9 +63,8 @@ func (s *healthTestSuite) TestHealth(c *C) { // health command args := []string{"-u", pdAddr, "health"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) h := make([]api.Health, len(healths)) - c.Assert(json.Unmarshal(output, &h), IsNil) - c.Assert(err, IsNil) - c.Assert(h, DeepEquals, healths) + re.NoError(json.Unmarshal(output, &h)) + re.Equal(healths, h) } diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go index b5160a83f30..c5aaf948aa2 100644 --- a/tests/pdctl/helper.go +++ b/tests/pdctl/helper.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/spf13/cobra" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" @@ -43,9 +44,9 @@ func ExecuteCommand(root *cobra.Command, args ...string) (output []byte, err err } // CheckStoresInfo is used to check the test results. -// CheckStoresInfo will not check Store.State because this feild has been omitted pdctl output -func CheckStoresInfo(c *check.C, stores []*api.StoreInfo, want []*api.StoreInfo) { - c.Assert(len(stores), check.Equals, len(want)) +// CheckStoresInfo will not check Store.State because this field has been omitted pdctl output +func CheckStoresInfo(re *require.Assertions, stores []*api.StoreInfo, want []*api.StoreInfo) { + re.Equal(len(want), len(stores)) mapWant := make(map[uint64]*api.StoreInfo) for _, s := range want { if _, ok := mapWant[s.Store.Id]; !ok { @@ -60,24 +61,24 @@ func CheckStoresInfo(c *check.C, stores []*api.StoreInfo, want []*api.StoreInfo) obtained.NodeState, expected.NodeState = 0, 0 // Ignore lastHeartbeat obtained.LastHeartbeat, expected.LastHeartbeat = 0, 0 - c.Assert(obtained, check.DeepEquals, expected) + re.Equal(expected, obtained) obtainedStateName := s.Store.StateName expectedStateName := mapWant[obtained.Id].Store.StateName - c.Assert(obtainedStateName, check.Equals, expectedStateName) + re.Equal(expectedStateName, obtainedStateName) } } // CheckRegionInfo is used to check the test results. 
-func CheckRegionInfo(c *check.C, output *api.RegionInfo, expected *core.RegionInfo) { +func CheckRegionInfo(re *require.Assertions, output *api.RegionInfo, expected *core.RegionInfo) { region := api.NewRegionInfo(expected) output.Adjust() - c.Assert(output, check.DeepEquals, region) + re.Equal(region, output) } // CheckRegionsInfo is used to check the test results. -func CheckRegionsInfo(c *check.C, output *api.RegionsInfo, expected []*core.RegionInfo) { - c.Assert(output.Count, check.Equals, len(expected)) +func CheckRegionsInfo(re *require.Assertions, output *api.RegionsInfo, expected []*core.RegionInfo) { + re.Len(expected, output.Count) got := output.Regions sort.Slice(got, func(i, j int) bool { return got[i].ID < got[j].ID @@ -86,12 +87,26 @@ func CheckRegionsInfo(c *check.C, output *api.RegionsInfo, expected []*core.Regi return expected[i].GetID() < expected[j].GetID() }) for i, region := range expected { - CheckRegionInfo(c, &got[i], region) + CheckRegionInfo(re, &got[i], region) } } // MustPutStore is used for test purpose. -func MustPutStore(c *check.C, svr *server.Server, store *metapb.Store) { +func MustPutStore(re *require.Assertions, svr *server.Server, store *metapb.Store) { + store.Address = fmt.Sprintf("tikv%d", store.GetId()) + if len(store.Version) == 0 { + store.Version = versioninfo.MinSupportedVersion(versioninfo.Version2_0).String() + } + grpcServer := &server.GrpcServer{Server: svr} + _, err := grpcServer.PutStore(context.Background(), &pdpb.PutStoreRequest{ + Header: &pdpb.RequestHeader{ClusterId: svr.ClusterID()}, + Store: store, + }) + re.NoError(err) +} + +// MustPutStoreWithCheck is a temporary function for test purpose. +func MustPutStoreWithCheck(c *check.C, svr *server.Server, store *metapb.Store) { store.Address = fmt.Sprintf("tikv%d", store.GetId()) if len(store.Version) == 0 { store.Version = versioninfo.MinSupportedVersion(versioninfo.Version2_0).String() @@ -105,7 +120,26 @@ func MustPutStore(c *check.C, svr *server.Server, store *metapb.Store) { } // MustPutRegion is used for test purpose. -func MustPutRegion(c *check.C, cluster *tests.TestCluster, regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { +func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { + leader := &metapb.Peer{ + Id: regionID, + StoreId: storeID, + } + metaRegion := &metapb.Region{ + Id: regionID, + StartKey: start, + EndKey: end, + Peers: []*metapb.Peer{leader}, + RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, + } + r := core.NewRegionInfo(metaRegion, leader, opts...) + err := cluster.HandleRegionHeartbeat(r) + re.NoError(err) + return r +} + +// MustPutRegionWithCheck is a temporary function for test purpose. 
+func MustPutRegionWithCheck(c *check.C, cluster *tests.TestCluster, regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { leader := &metapb.Peer{ Id: regionID, StoreId: storeID, @@ -123,10 +157,12 @@ func MustPutRegion(c *check.C, cluster *tests.TestCluster, regionID, storeID uin return r } -func checkerWithNilAssert(c *check.C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) +func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { + checker := assertutil.NewChecker(func() { + re.FailNow("should be nil") + }) checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, check.IsNil) + re.Nil(obtained) } return checker } diff --git a/tests/pdctl/hot/hot_test.go b/tests/pdctl/hot/hot_test.go index 06a657df7d7..74148b40955 100644 --- a/tests/pdctl/hot/hot_test.go +++ b/tests/pdctl/hot/hot_test.go @@ -22,9 +22,9 @@ import ( "time" "github.com/gogo/protobuf/proto" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" @@ -35,22 +35,15 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&hotTestSuite{}) - -type hotTestSuite struct{} - -func (s *hotTestSuite) TestHot(c *C) { +func TestHot(t *testing.T) { + re := require.New(t) statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -68,9 +61,9 @@ func (s *hotTestSuite) TestHot(c *C) { } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) - pdctl.MustPutStore(c, leaderServer.GetServer(), store1) - pdctl.MustPutStore(c, leaderServer.GetServer(), store2) + re.NoError(leaderServer.BootstrapCluster()) + pdctl.MustPutStore(re, leaderServer.GetServer(), store1) + pdctl.MustPutStore(re, leaderServer.GetServer(), store2) defer cluster.Destroy() // test hot store @@ -99,33 +92,33 @@ func (s *hotTestSuite) TestHot(c *C) { args := []string{"-u", pdAddr, "hot", "store"} output, err := pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) hotStores := api.HotStoreStats{} - c.Assert(json.Unmarshal(output, &hotStores), IsNil) - c.Assert(hotStores.BytesWriteStats[1], Equals, float64(bytesWritten)/statistics.StoreHeartBeatReportInterval) - c.Assert(hotStores.BytesReadStats[1], Equals, float64(bytesRead)/statistics.StoreHeartBeatReportInterval) - c.Assert(hotStores.KeysWriteStats[1], Equals, float64(keysWritten)/statistics.StoreHeartBeatReportInterval) - c.Assert(hotStores.KeysReadStats[1], Equals, float64(keysRead)/statistics.StoreHeartBeatReportInterval) - c.Assert(hotStores.BytesWriteStats[2], Equals, float64(bytesWritten)) - c.Assert(hotStores.KeysWriteStats[2], Equals, float64(keysWritten)) + re.NoError(json.Unmarshal(output, &hotStores)) + re.Equal(float64(bytesWritten)/statistics.StoreHeartBeatReportInterval, hotStores.BytesWriteStats[1]) + re.Equal(float64(bytesRead)/statistics.StoreHeartBeatReportInterval, hotStores.BytesReadStats[1]) + re.Equal(float64(keysWritten)/statistics.StoreHeartBeatReportInterval, hotStores.KeysWriteStats[1]) + re.Equal(float64(keysRead)/statistics.StoreHeartBeatReportInterval, hotStores.KeysReadStats[1]) + re.Equal(float64(bytesWritten), hotStores.BytesWriteStats[2]) + re.Equal(float64(keysWritten), hotStores.KeysWriteStats[2]) // test hot region args = []string{"-u", pdAddr, "config", "set", "hot-region-cache-hits-threshold", "0"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) hotStoreID := store1.Id count := 0 testHot := func(hotRegionID, hotStoreID uint64, hotType string) { args = []string{"-u", pdAddr, "hot", hotType} - output, e := pdctl.ExecuteCommand(cmd, args...) + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) hotRegion := statistics.StoreHotPeersInfos{} - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegion), IsNil) - c.Assert(hotRegion.AsLeader, HasKey, hotStoreID) - c.Assert(hotRegion.AsLeader[hotStoreID].Count, Equals, count) + re.NoError(json.Unmarshal(output, &hotRegion)) + re.Contains(hotRegion.AsLeader, hotStoreID) + re.Equal(count, hotRegion.AsLeader[hotStoreID].Count) if count > 0 { - c.Assert(hotRegion.AsLeader[hotStoreID].Stats[count-1].RegionID, Equals, hotRegionID) + re.Equal(hotRegionID, hotRegion.AsLeader[hotStoreID].Stats[count-1].RegionID) } } @@ -159,7 +152,11 @@ func (s *hotTestSuite) TestHot(c *C) { } testHot(hotRegionID, hotStoreID, "read") case "write": - pdctl.MustPutRegion(c, cluster, hotRegionID, hotStoreID, []byte("c"), []byte("d"), core.SetWrittenBytes(1000000000*reportInterval), core.SetReportInterval(reportInterval)) + pdctl.MustPutRegion( + re, cluster, + hotRegionID, hotStoreID, + []byte("c"), []byte("d"), + core.SetWrittenBytes(1000000000*reportInterval), core.SetReportInterval(reportInterval)) time.Sleep(5000 * time.Millisecond) if reportInterval >= statistics.WriteReportInterval { count++ @@ -189,14 +186,15 @@ func (s *hotTestSuite) TestHot(c *C) { testCommand(reportIntervals, "read") } -func (s *hotTestSuite) TestHotWithStoreID(c *C) { +func TestHotWithStoreID(t *testing.T) { + re := require.New(t) statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1, func(cfg *config.Config, serverName string) { cfg.Schedule.HotRegionCacheHitsThreshold = 0 }) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() 
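
Note on the conversion idiom used throughout these hunks: each test builds one *require.Assertions via require.New(t), passes the expected value first to Equal, and swaps the old Equals/IsNil/IsTrue/HasLen checkers for NoError, Equal, Contains, and Len. A minimal, self-contained sketch of that idiom follows; it is illustrative only and not part of the patch, and the compute helper is a hypothetical stand-in for calls such as pdctl.ExecuteCommand.

    package example_test

    import (
    	"testing"

    	"github.com/stretchr/testify/require"
    )

    // compute is a hypothetical stand-in for a call like pdctl.ExecuteCommand.
    func compute() (int, error) { return 42, nil }

    func TestConversionIdiom(t *testing.T) {
    	re := require.New(t) // replaces the check.C receiver of the old suite methods

    	got, err := compute()
    	re.NoError(err)                          // was: c.Assert(err, IsNil)
    	re.Equal(42, got)                        // was: c.Assert(got, Equals, 42); expected value first
    	re.Contains("Success! done", "Success!") // was: c.Assert(strings.Contains(...), IsTrue)
    	re.Len([]int{1, 2, 3}, 3)                // was: c.Assert(..., HasLen, 3)
    }
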
@@ -215,39 +213,40 @@ func (s *hotTestSuite) TestHotWithStoreID(c *C) { } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) } defer cluster.Destroy() - pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), core.SetReportInterval(statistics.WriteReportInterval)) // wait hot scheduler starts time.Sleep(5000 * time.Millisecond) args := []string{"-u", pdAddr, "hot", "write", "1"} - output, e := pdctl.ExecuteCommand(cmd, args...) + output, err := pdctl.ExecuteCommand(cmd, args...) hotRegion := statistics.StoreHotPeersInfos{} - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegion), IsNil) - c.Assert(hotRegion.AsLeader, HasLen, 1) - c.Assert(hotRegion.AsLeader[1].Count, Equals, 2) - c.Assert(hotRegion.AsLeader[1].TotalBytesRate, Equals, float64(200000000)) + re.NoError(err) + re.NoError(json.Unmarshal(output, &hotRegion)) + re.Len(hotRegion.AsLeader, 1) + re.Equal(2, hotRegion.AsLeader[1].Count) + re.Equal(float64(200000000), hotRegion.AsLeader[1].TotalBytesRate) args = []string{"-u", pdAddr, "hot", "write", "1", "2"} - output, e = pdctl.ExecuteCommand(cmd, args...) + output, err = pdctl.ExecuteCommand(cmd, args...) 
+ re.NoError(err) hotRegion = statistics.StoreHotPeersInfos{} - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegion), IsNil) - c.Assert(hotRegion.AsLeader, HasLen, 2) - c.Assert(hotRegion.AsLeader[1].Count, Equals, 2) - c.Assert(hotRegion.AsLeader[2].Count, Equals, 1) - c.Assert(hotRegion.AsLeader[1].TotalBytesRate, Equals, float64(200000000)) - c.Assert(hotRegion.AsLeader[2].TotalBytesRate, Equals, float64(100000000)) + re.NoError(json.Unmarshal(output, &hotRegion)) + re.Len(hotRegion.AsLeader, 2) + re.Equal(2, hotRegion.AsLeader[1].Count) + re.Equal(1, hotRegion.AsLeader[2].Count) + re.Equal(float64(200000000), hotRegion.AsLeader[1].TotalBytesRate) + re.Equal(float64(100000000), hotRegion.AsLeader[2].TotalBytesRate) } -func (s *hotTestSuite) TestHistoryHotRegions(c *C) { +func TestHistoryHotRegions(t *testing.T) { + re := require.New(t) statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -258,9 +257,9 @@ func (s *hotTestSuite) TestHistoryHotRegions(c *C) { cfg.Schedule.HotRegionsReservedDays = 1 }, ) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -284,16 +283,16 @@ func (s *hotTestSuite) TestHistoryHotRegions(c *C) { } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) } defer cluster.Destroy() startTime := time.Now().UnixNano() / int64(time.Millisecond) - pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 4, 3, []byte("g"), []byte("h"), core.SetWrittenBytes(9000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 4, 3, []byte("g"), []byte("h"), core.SetWrittenBytes(9000000000), core.SetReportInterval(statistics.WriteReportInterval)) // wait hot scheduler starts time.Sleep(5000 * time.Millisecond) endTime := time.Now().UnixNano() / int64(time.Millisecond) @@ -306,54 +305,54 @@ func (s *hotTestSuite) TestHistoryHotRegions(c *C) { "store_id", "1,4", "is_learner", "false", } - output, e := pdctl.ExecuteCommand(cmd, args...) + output, err := pdctl.ExecuteCommand(cmd, args...) 
hotRegions := storage.HistoryHotRegions{} - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegions), IsNil) + re.NoError(err) + re.NoError(json.Unmarshal(output, &hotRegions)) regions := hotRegions.HistoryHotRegion - c.Assert(len(regions), Equals, 1) - c.Assert(regions[0].RegionID, Equals, uint64(1)) - c.Assert(regions[0].StoreID, Equals, uint64(1)) - c.Assert(regions[0].HotRegionType, Equals, "write") + re.Len(regions, 1) + re.Equal(uint64(1), regions[0].RegionID) + re.Equal(uint64(1), regions[0].StoreID) + re.Equal("write", regions[0].HotRegionType) args = []string{"-u", pdAddr, "hot", "history", start, end, "hot_region_type", "write", "region_id", "1,2", "store_id", "1,2", } - output, e = pdctl.ExecuteCommand(cmd, args...) - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegions), IsNil) + output, err = pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.NoError(json.Unmarshal(output, &hotRegions)) regions = hotRegions.HistoryHotRegion - c.Assert(len(regions), Equals, 2) + re.Len(regions, 2) isSort := regions[0].UpdateTime > regions[1].UpdateTime || regions[0].RegionID < regions[1].RegionID - c.Assert(isSort, Equals, true) + re.True(isSort) args = []string{"-u", pdAddr, "hot", "history", start, end, "hot_region_type", "read", "is_leader", "false", "peer_id", "12", } - output, e = pdctl.ExecuteCommand(cmd, args...) - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegions), IsNil) - c.Assert(len(hotRegions.HistoryHotRegion), Equals, 0) + output, err = pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.NoError(json.Unmarshal(output, &hotRegions)) + re.Len(hotRegions.HistoryHotRegion, 0) args = []string{"-u", pdAddr, "hot", "history"} - output, e = pdctl.ExecuteCommand(cmd, args...) - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegions), NotNil) + output, err = pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Error(json.Unmarshal(output, &hotRegions)) args = []string{"-u", pdAddr, "hot", "history", start, end, "region_id", "dada", } - output, e = pdctl.ExecuteCommand(cmd, args...) - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegions), NotNil) + output, err = pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Error(json.Unmarshal(output, &hotRegions)) args = []string{"-u", pdAddr, "hot", "history", start, end, "region_ids", "12323", } - output, e = pdctl.ExecuteCommand(cmd, args...) - c.Assert(e, IsNil) - c.Assert(json.Unmarshal(output, &hotRegions), NotNil) + output, err = pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Error(json.Unmarshal(output, &hotRegions)) } diff --git a/tests/pdctl/label/label_test.go b/tests/pdctl/label/label_test.go index 50a52413e82..ba31b1fb1d1 100644 --- a/tests/pdctl/label/label_test.go +++ b/tests/pdctl/label/label_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" @@ -30,21 +30,14 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&labelTestSuite{}) - -type labelTestSuite struct{} - -func (s *labelTestSuite) TestLabel(c *C) { +func TestLabel(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1, func(cfg *config.Config, serverName string) { cfg.Replication.StrictlyMatchLabel = false }) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -100,19 +93,19 @@ func (s *labelTestSuite) TestLabel(c *C) { }, } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store.Store.Store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store.Store.Store) } defer cluster.Destroy() // label command args := []string{"-u", pdAddr, "label"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) labels := make([]*metapb.StoreLabel, 0, len(stores)) - c.Assert(json.Unmarshal(output, &labels), IsNil) + re.NoError(json.Unmarshal(output, &labels)) got := make(map[string]struct{}) for _, l := range labels { if _, ok := got[strings.ToLower(l.Key+l.Value)]; !ok { @@ -129,21 +122,21 @@ func (s *labelTestSuite) TestLabel(c *C) { } } } - c.Assert(got, DeepEquals, expected) + re.Equal(expected, got) // label store command args = []string{"-u", pdAddr, "label", "store", "zone", "us-west"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storesInfo := new(api.StoresInfo) - c.Assert(json.Unmarshal(output, &storesInfo), IsNil) + re.NoError(json.Unmarshal(output, &storesInfo)) sss := []*api.StoreInfo{stores[0], stores[2]} - pdctl.CheckStoresInfo(c, storesInfo.Stores, sss) + pdctl.CheckStoresInfo(re, storesInfo.Stores, sss) // label isolation [label] args = []string{"-u", pdAddr, "label", "isolation"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "none"), IsTrue) - c.Assert(strings.Contains(string(output), "2"), IsTrue) + re.NoError(err) + re.Contains(string(output), "none") + re.Contains(string(output), "2") } diff --git a/tests/pdctl/log/log_test.go b/tests/pdctl/log/log_test.go index 6499b2694c7..7f2e4f20584 100644 --- a/tests/pdctl/log/log_test.go +++ b/tests/pdctl/log/log_test.go @@ -19,21 +19,16 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/server" "github.com/tikv/pd/tests" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&logTestSuite{}) - type logTestSuite struct { + suite.Suite ctx context.Context cancel context.CancelFunc cluster *tests.TestCluster @@ -41,33 +36,36 @@ type logTestSuite struct { pdAddrs []string } -func (s *logTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) +func TestLogTestSuite(t *testing.T) { + suite.Run(t, new(logTestSuite)) +} + +func (suite *logTestSuite) SetupSuite() { + suite.ctx, suite.cancel = context.WithCancel(context.Background()) var err error - s.cluster, err = tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) - err = s.cluster.RunInitialServers() - c.Assert(err, IsNil) - s.cluster.WaitLeader() - s.pdAddrs = s.cluster.GetConfig().GetClientURLs() + suite.cluster, err = tests.NewTestCluster(suite.ctx, 3) + suite.NoError(err) + suite.NoError(suite.cluster.RunInitialServers()) + suite.cluster.WaitLeader() + suite.pdAddrs = suite.cluster.GetConfig().GetClientURLs() store := &metapb.Store{ Id: 1, State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := s.cluster.GetServer(s.cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) - s.svr = leaderServer.GetServer() - pdctl.MustPutStore(c, s.svr, store) + leaderServer := suite.cluster.GetServer(suite.cluster.GetLeader()) + suite.NoError(leaderServer.BootstrapCluster()) + suite.svr = leaderServer.GetServer() + pdctl.MustPutStore(suite.Require(), suite.svr, store) } -func (s *logTestSuite) TearDownSuite(c *C) { - s.cluster.Destroy() - s.cancel() +func (suite *logTestSuite) TearDownSuite() { + suite.cancel() + suite.cluster.Destroy() } -func (s *logTestSuite) TestLog(c *C) { +func (suite *logTestSuite) TestLog() { cmd := pdctlCmd.GetRootCmd() var testCases = []struct { cmd []string @@ -75,35 +73,35 @@ func (s *logTestSuite) TestLog(c *C) { }{ // log [fatal|error|warn|info|debug] { - cmd: []string{"-u", s.pdAddrs[0], "log", "fatal"}, + cmd: []string{"-u", suite.pdAddrs[0], "log", "fatal"}, expect: "fatal", }, { - cmd: []string{"-u", s.pdAddrs[0], "log", "error"}, + cmd: []string{"-u", suite.pdAddrs[0], "log", "error"}, expect: "error", }, { - cmd: []string{"-u", s.pdAddrs[0], "log", "warn"}, + cmd: []string{"-u", suite.pdAddrs[0], "log", "warn"}, expect: "warn", }, { - cmd: []string{"-u", s.pdAddrs[0], "log", "info"}, + cmd: []string{"-u", suite.pdAddrs[0], "log", "info"}, expect: "info", }, { - cmd: []string{"-u", s.pdAddrs[0], "log", "debug"}, + cmd: []string{"-u", suite.pdAddrs[0], "log", "debug"}, expect: "debug", }, } for _, testCase := range testCases { _, err := pdctl.ExecuteCommand(cmd, testCase.cmd...) 
- c.Assert(err, IsNil) - c.Assert(s.svr.GetConfig().Log.Level, Equals, testCase.expect) + suite.NoError(err) + suite.Equal(testCase.expect, suite.svr.GetConfig().Log.Level) } } -func (s *logTestSuite) TestInstanceLog(c *C) { +func (suite *logTestSuite) TestInstanceLog() { cmd := pdctlCmd.GetRootCmd() var testCases = []struct { cmd []string @@ -112,29 +110,29 @@ func (s *logTestSuite) TestInstanceLog(c *C) { }{ // log [fatal|error|warn|info|debug] [address] { - cmd: []string{"-u", s.pdAddrs[0], "log", "debug", s.pdAddrs[0]}, - instance: s.pdAddrs[0], + cmd: []string{"-u", suite.pdAddrs[0], "log", "debug", suite.pdAddrs[0]}, + instance: suite.pdAddrs[0], expect: "debug", }, { - cmd: []string{"-u", s.pdAddrs[0], "log", "error", s.pdAddrs[1]}, - instance: s.pdAddrs[1], + cmd: []string{"-u", suite.pdAddrs[0], "log", "error", suite.pdAddrs[1]}, + instance: suite.pdAddrs[1], expect: "error", }, { - cmd: []string{"-u", s.pdAddrs[0], "log", "warn", s.pdAddrs[2]}, - instance: s.pdAddrs[2], + cmd: []string{"-u", suite.pdAddrs[0], "log", "warn", suite.pdAddrs[2]}, + instance: suite.pdAddrs[2], expect: "warn", }, } for _, testCase := range testCases { _, err := pdctl.ExecuteCommand(cmd, testCase.cmd...) - c.Assert(err, IsNil) - svrs := s.cluster.GetServers() + suite.NoError(err) + svrs := suite.cluster.GetServers() for _, svr := range svrs { if svr.GetAddr() == testCase.instance { - c.Assert(svr.GetConfig().Log.Level, Equals, testCase.expect) + suite.Equal(testCase.expect, svr.GetConfig().Log.Level) } } } diff --git a/tests/pdctl/member/member_test.go b/tests/pdctl/member/member_test.go index f85f2d946df..2c93a9c6c53 100644 --- a/tests/pdctl/member/member_test.go +++ b/tests/pdctl/member/member_test.go @@ -18,11 +18,10 @@ import ( "context" "encoding/json" "fmt" - "strings" "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/tests" @@ -30,26 +29,19 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&memberTestSuite{}) - -type memberTestSuite struct{} - -func (s *memberTestSuite) TestMember(c *C) { +func TestMember(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 3) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) pdAddr := cluster.GetConfig().GetClientURL() - c.Assert(err, IsNil) + re.NoError(err) cmd := pdctlCmd.GetRootCmd() svr := cluster.GetServer("pd2") id := svr.GetServerID() @@ -60,57 +52,56 @@ func (s *memberTestSuite) TestMember(c *C) { // member leader show args := []string{"-u", pdAddr, "member", "leader", "show"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) leader := pdpb.Member{} - c.Assert(json.Unmarshal(output, &leader), IsNil) - c.Assert(&leader, DeepEquals, svr.GetLeader()) + re.NoError(json.Unmarshal(output, &leader)) + re.Equal(svr.GetLeader(), &leader) // member leader transfer args = []string{"-u", pdAddr, "member", "leader", "transfer", "pd2"} _, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { - return c.Check("pd2", Equals, svr.GetLeader().GetName()) + re.NoError(err) + testutil.Eventually(re, func() bool { + return svr.GetLeader().GetName() == "pd2" }) // member leader resign cluster.WaitLeader() args = []string{"-u", pdAddr, "member", "leader", "resign"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(strings.Contains(string(output), "Success"), IsTrue) - c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { - return c.Check("pd2", Not(Equals), svr.GetLeader().GetName()) + re.Contains(string(output), "Success") + re.NoError(err) + testutil.Eventually(re, func() bool { + return svr.GetLeader().GetName() != "pd2" }) // member leader_priority cluster.WaitLeader() args = []string{"-u", pdAddr, "member", "leader_priority", name, "100"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) priority, err := svr.GetServer().GetMember().GetMemberLeaderPriority(id) - c.Assert(err, IsNil) - c.Assert(priority, Equals, 100) + re.NoError(err) + re.Equal(100, priority) // member delete name err = svr.Destroy() - c.Assert(err, IsNil) + re.NoError(err) members, err := etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 3) + re.NoError(err) + re.Len(members.Members, 3) args = []string{"-u", pdAddr, "member", "delete", "name", name} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) members, err = etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 2) + re.NoError(err) + re.Len(members.Members, 2) // member delete id args = []string{"-u", pdAddr, "member", "delete", "id", fmt.Sprint(id)} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) members, err = etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 2) - c.Succeed() + re.NoError(err) + re.Len(members.Members, 2) } diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index 73ae2687c80..b8433520381 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/tests" @@ -30,31 +30,26 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&operatorTestSuite{}) - -type operatorTestSuite struct{} - -func (s *operatorTestSuite) TestOperator(c *C) { +func TestOperator(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() var err error - var t time.Time - t = t.Add(time.Hour) + var start time.Time + start = start.Add(time.Hour) cluster, err := tests.NewTestCluster(ctx, 1, // TODO: enable placementrules func(conf *config.Config, serverName string) { conf.Replication.MaxReplicas = 2 conf.Replication.EnablePlacementRules = false }, - func(conf *config.Config, serverName string) { conf.Schedule.MaxStoreDownTime.Duration = time.Since(t) }, + func(conf *config.Config, serverName string) { + conf.Schedule.MaxStoreDownTime.Duration = time.Since(start) + }, ) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -83,16 +78,16 @@ func (s *operatorTestSuite) TestOperator(c *C) { } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) } - pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetPeers([]*metapb.Peer{ + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetPeers([]*metapb.Peer{ {Id: 1, StoreId: 1}, {Id: 2, StoreId: 2}, })) - pdctl.MustPutRegion(c, cluster, 3, 2, []byte("b"), []byte("c"), core.SetPeers([]*metapb.Peer{ + pdctl.MustPutRegion(re, cluster, 3, 2, []byte("b"), []byte("c"), core.SetPeers([]*metapb.Peer{ {Id: 3, StoreId: 1}, {Id: 4, StoreId: 2}, })) @@ -170,78 +165,80 @@ func (s *operatorTestSuite) TestOperator(c *C) { } for _, testCase := range testCases { - _, e := pdctl.ExecuteCommand(cmd, testCase.cmd...) - c.Assert(e, IsNil) - output, e := pdctl.ExecuteCommand(cmd, testCase.show...) - c.Assert(e, IsNil) - c.Assert(strings.Contains(string(output), testCase.expect), IsTrue) - t := time.Now() - _, e = pdctl.ExecuteCommand(cmd, testCase.reset...) - c.Assert(e, IsNil) - historyCmd := []string{"-u", pdAddr, "operator", "history", strconv.FormatInt(t.Unix(), 10)} - records, e := pdctl.ExecuteCommand(cmd, historyCmd...) - c.Assert(e, IsNil) - c.Assert(strings.Contains(string(records), "admin"), IsTrue) + _, err := pdctl.ExecuteCommand(cmd, testCase.cmd...) + re.NoError(err) + output, err := pdctl.ExecuteCommand(cmd, testCase.show...) + re.NoError(err) + re.Contains(string(output), testCase.expect) + start := time.Now() + _, err = pdctl.ExecuteCommand(cmd, testCase.reset...) + re.NoError(err) + historyCmd := []string{"-u", pdAddr, "operator", "history", strconv.FormatInt(start.Unix(), 10)} + records, err := pdctl.ExecuteCommand(cmd, historyCmd...) + re.NoError(err) + re.Contains(string(records), "admin") } // operator add merge-region args := []string{"-u", pdAddr, "operator", "add", "merge-region", "1", "3"} _, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "operator", "show"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "merge region 1 into region 3"), IsTrue) + re.NoError(err) + re.Contains(string(output), "merge region 1 into region 3") args = []string{"-u", pdAddr, "operator", "remove", "1"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "operator", "remove", "3"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) _, err = pdctl.ExecuteCommand(cmd, "config", "set", "enable-placement-rules", "true") - c.Assert(err, IsNil) + re.NoError(err) output, err = pdctl.ExecuteCommand(cmd, "operator", "add", "transfer-region", "1", "2", "3") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "not supported"), IsTrue) + re.NoError(err) + re.Contains(string(output), "not supported") output, err = pdctl.ExecuteCommand(cmd, "operator", "add", "transfer-region", "1", "2", "follower", "3") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "not match"), IsTrue) + re.NoError(err) + re.Contains(string(output), "not match") output, err = pdctl.ExecuteCommand(cmd, "operator", "add", "transfer-peer", "1", "2", "4") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "is unhealthy"), IsTrue) + re.NoError(err) + re.Contains(string(output), "is unhealthy") output, err = pdctl.ExecuteCommand(cmd, "operator", "add", "transfer-region", "1", "2", "leader", "4", "follower") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "is unhealthy"), IsTrue) + re.NoError(err) + re.Contains(string(output), "is unhealthy") output, err = pdctl.ExecuteCommand(cmd, "operator", "add", "transfer-region", "1", "2", "follower", "leader", "3", "follower") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "invalid"), IsTrue) + re.NoError(err) + re.Contains(string(output), "invalid") output, err = pdctl.ExecuteCommand(cmd, "operator", "add", "transfer-region", "1", "leader", "2", "follower", "3") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "invalid"), IsTrue) + re.NoError(err) + re.Contains(string(output), "invalid") output, err = pdctl.ExecuteCommand(cmd, "operator", "add", "transfer-region", "1", "2", "leader", "3", "follower") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "operator", "remove", "1") - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Success!"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Success!") _, err = pdctl.ExecuteCommand(cmd, "config", "set", "enable-placement-rules", "false") - c.Assert(err, IsNil) + re.NoError(err) // operator add scatter-region args = []string{"-u", pdAddr, "operator", "add", "scatter-region", "3"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "operator", "add", "scatter-region", "1"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "operator", "show", "region"} output, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "scatter-region"), IsTrue) + re.NoError(err) + re.Contains(string(output), "scatter-region") // test echo, as the scatter region result is random, both region 1 and region 3 can be the region to be scattered output1, _ := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "operator", "remove", "1") output2, _ := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "operator", "remove", "3") - c.Assert(strings.Contains(string(output1), "Success!") || strings.Contains(string(output2), "Success!"), IsTrue) + re.Condition(func() bool { + return strings.Contains(string(output1), "Success!") || strings.Contains(string(output2), "Success!") + }) } diff --git a/tests/pdctl/region/region_test.go b/tests/pdctl/region/region_test.go index dd83accea55..951433bd432 100644 --- a/tests/pdctl/region/region_test.go +++ b/tests/pdctl/region/region_test.go @@ -17,13 +17,12 @@ package region_test import ( "context" "encoding/json" - "strings" "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" "github.com/tikv/pd/tests" @@ -31,21 +30,14 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(®ionTestSuite{}) - -type regionTestSuite struct{} - -func (s *regionTestSuite) TestRegionKeyFormat(c *C) { +func TestRegionKeyFormat(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() url := cluster.GetConfig().GetClientURL() store := &metapb.Store{ @@ -54,22 +46,23 @@ func (s *regionTestSuite) TestRegionKeyFormat(c *C) { LastHeartbeat: time.Now().UnixNano(), } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + re.NoError(leaderServer.BootstrapCluster()) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) cmd := pdctlCmd.GetRootCmd() - output, e := pdctl.ExecuteCommand(cmd, "-u", url, "region", "key", "--format=raw", " ") - c.Assert(e, IsNil) - c.Assert(strings.Contains(string(output), "unknown flag"), IsFalse) + output, err := pdctl.ExecuteCommand(cmd, "-u", url, "region", "key", "--format=raw", " ") + re.NoError(err) + re.NotContains(string(output), "unknown flag") } -func (s *regionTestSuite) TestRegion(c *C) { +func TestRegion(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -80,11 +73,11 @@ func (s *regionTestSuite) TestRegion(c *C) { LastHeartbeat: time.Now().UnixNano(), } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + re.NoError(leaderServer.BootstrapCluster()) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) downPeer := &metapb.Peer{Id: 8, StoreId: 3} - r1 := pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b"), + r1 := 
pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(1000), core.SetReadBytes(1000), core.SetRegionConfVer(1), core.SetRegionVersion(1), core.SetApproximateSize(1), core.SetApproximateKeys(100), core.SetPeers([]*metapb.Peer{ @@ -93,16 +86,16 @@ func (s *regionTestSuite) TestRegion(c *C) { {Id: 6, StoreId: 3}, {Id: 7, StoreId: 4}, })) - r2 := pdctl.MustPutRegion(c, cluster, 2, 1, []byte("b"), []byte("c"), + r2 := pdctl.MustPutRegion(re, cluster, 2, 1, []byte("b"), []byte("c"), core.SetWrittenBytes(2000), core.SetReadBytes(0), core.SetRegionConfVer(2), core.SetRegionVersion(3), core.SetApproximateSize(144), core.SetApproximateKeys(14400), ) - r3 := pdctl.MustPutRegion(c, cluster, 3, 1, []byte("c"), []byte("d"), + r3 := pdctl.MustPutRegion(re, cluster, 3, 1, []byte("c"), []byte("d"), core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2), core.SetApproximateSize(30), core.SetApproximateKeys(3000), core.WithDownPeers([]*pdpb.PeerStats{{Peer: downPeer, DownSeconds: 3600}}), core.WithPendingPeers([]*metapb.Peer{downPeer}), core.WithLearners([]*metapb.Peer{{Id: 3, StoreId: 1}})) - r4 := pdctl.MustPutRegion(c, cluster, 4, 1, []byte("d"), []byte("e"), + r4 := pdctl.MustPutRegion(re, cluster, 4, 1, []byte("d"), []byte("e"), core.SetWrittenBytes(100), core.SetReadBytes(100), core.SetRegionConfVer(1), core.SetRegionVersion(1), core.SetApproximateSize(10), core.SetApproximateKeys(1000), ) @@ -173,11 +166,11 @@ func (s *regionTestSuite) TestRegion(c *C) { for _, testCase := range testRegionsCases { args := append([]string{"-u", pdAddr}, testCase.args...) - output, e := pdctl.ExecuteCommand(cmd, args...) - c.Assert(e, IsNil) + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) regions := &api.RegionsInfo{} - c.Assert(json.Unmarshal(output, regions), IsNil) - pdctl.CheckRegionsInfo(c, regions, testCase.expect) + re.NoError(json.Unmarshal(output, regions)) + pdctl.CheckRegionsInfo(re, regions, testCase.expect) } var testRegionCases = []struct { @@ -196,22 +189,22 @@ func (s *regionTestSuite) TestRegion(c *C) { for _, testCase := range testRegionCases { args := append([]string{"-u", pdAddr}, testCase.args...) - output, e := pdctl.ExecuteCommand(cmd, args...) - c.Assert(e, IsNil) + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) region := &api.RegionInfo{} - c.Assert(json.Unmarshal(output, region), IsNil) - pdctl.CheckRegionInfo(c, region, testCase.expect) + re.NoError(json.Unmarshal(output, region)) + pdctl.CheckRegionInfo(re, region, testCase.expect) } // Test region range-holes. - r5 := pdctl.MustPutRegion(c, cluster, 5, 1, []byte("x"), []byte("z")) - output, e := pdctl.ExecuteCommand(cmd, []string{"-u", pdAddr, "region", "range-holes"}...) - c.Assert(e, IsNil) + r5 := pdctl.MustPutRegion(re, cluster, 5, 1, []byte("x"), []byte("z")) + output, err := pdctl.ExecuteCommand(cmd, []string{"-u", pdAddr, "region", "range-holes"}...) 
+ re.NoError(err) rangeHoles := new([][]string) - c.Assert(json.Unmarshal(output, rangeHoles), IsNil) - c.Assert(*rangeHoles, DeepEquals, [][]string{ + re.NoError(json.Unmarshal(output, rangeHoles)) + re.Equal([][]string{ {"", core.HexRegionKeyStr(r1.GetStartKey())}, {core.HexRegionKeyStr(r4.GetEndKey()), core.HexRegionKeyStr(r5.GetStartKey())}, {core.HexRegionKeyStr(r5.GetEndKey()), ""}, - }) + }, *rangeHoles) } diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index 53ed808f410..3a3846603dd 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -17,12 +17,11 @@ package scheduler_test import ( "context" "encoding/json" - "strings" "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/versioninfo" "github.com/tikv/pd/tests" @@ -30,30 +29,14 @@ import ( pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&schedulerTestSuite{}) - -type schedulerTestSuite struct { - context context.Context - cancel context.CancelFunc -} - -func (s *schedulerTestSuite) SetUpSuite(c *C) { - s.context, s.cancel = context.WithCancel(context.Background()) -} - -func (s *schedulerTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *schedulerTestSuite) TestScheduler(c *C) { - cluster, err := tests.NewTestCluster(s.context, 1) - c.Assert(err, IsNil) +func TestScheduler(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -83,18 +66,18 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { mustExec := func(args []string, v interface{}) string { output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) if v == nil { return string(output) } - c.Assert(json.Unmarshal(output, v), IsNil) + re.NoError(json.Unmarshal(output, v)) return "" } mustUsage := func(args []string) { output, err := pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Usage"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Usage") } checkSchedulerCommand := func(args []string, expected map[string]bool) { @@ -104,7 +87,7 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { var schedulers []string mustExec([]string{"-u", pdAddr, "scheduler", "show"}, &schedulers) for _, scheduler := range schedulers { - c.Assert(expected[scheduler], IsTrue) + re.True(expected[scheduler]) } } @@ -114,16 +97,16 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { } configInfo := make(map[string]interface{}) mustExec([]string{"-u", pdAddr, "scheduler", "config", schedulerName}, &configInfo) - c.Assert(expectedConfig, DeepEquals, configInfo) + re.Equal(expectedConfig, configInfo) } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) } - pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b")) + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b")) defer cluster.Destroy() time.Sleep(3 * time.Second) @@ -245,12 +228,12 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { }) var roles []string mustExec([]string{"-u", pdAddr, "scheduler", "config", "shuffle-region-scheduler", "show-roles"}, &roles) - c.Assert(roles, DeepEquals, []string{"leader", "follower", "learner"}) + re.Equal([]string{"leader", "follower", "learner"}, roles) mustExec([]string{"-u", pdAddr, "scheduler", "config", "shuffle-region-scheduler", "set-roles", "learner"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "shuffle-region-scheduler", "show-roles"}, &roles) - c.Assert(roles, DeepEquals, []string{"learner"}) + re.Equal([]string{"learner"}, roles) mustExec([]string{"-u", pdAddr, "scheduler", "config", "shuffle-region-scheduler"}, &roles) - c.Assert(roles, DeepEquals, []string{"learner"}) + re.Equal([]string{"learner"}, roles) // test grant hot region scheduler config checkSchedulerCommand([]string{"-u", pdAddr, "scheduler", "add", "grant-hot-region-scheduler", "1", "1,2,3"}, map[string]bool{ @@ -266,30 +249,30 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { "store-leader-id": float64(1), } mustExec([]string{"-u", pdAddr, "scheduler", "config", "grant-hot-region-scheduler"}, &conf3) - c.Assert(expected3, DeepEquals, conf3) + re.Equal(expected3, conf3) mustExec([]string{"-u", pdAddr, "scheduler", "config", "grant-hot-region-scheduler", "set", "2", "1,2,3"}, nil) expected3["store-leader-id"] = float64(2) mustExec([]string{"-u", pdAddr, "scheduler", "config", "grant-hot-region-scheduler"}, &conf3) - c.Assert(expected3, DeepEquals, conf3) + re.Equal(expected3, conf3) // test balance region config echo := mustExec([]string{"-u", pdAddr, "scheduler", "add", "balance-region-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") echo = mustExec([]string{"-u", pdAddr, "scheduler", "remove", "balance-region-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") echo = mustExec([]string{"-u", pdAddr, "scheduler", "remove", "balance-region-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsFalse) + re.NotContains(echo, "Success!") echo = mustExec([]string{"-u", pdAddr, "scheduler", "add", "evict-leader-scheduler", "1"}, nil) - 
c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") echo = mustExec([]string{"-u", pdAddr, "scheduler", "remove", "evict-leader-scheduler-1"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") echo = mustExec([]string{"-u", pdAddr, "scheduler", "remove", "evict-leader-scheduler-1"}, nil) - c.Assert(strings.Contains(echo, "404"), IsTrue) + re.Contains(echo, "404") // test hot region config echo = mustExec([]string{"-u", pdAddr, "scheduler", "config", "evict-leader-scheduler"}, nil) - c.Assert(strings.Contains(echo, "[404] scheduler not found"), IsTrue) + re.Contains(echo, "[404] scheduler not found") expected1 := map[string]interface{}{ "min-hot-byte-rate": float64(100), "min-hot-key-rate": float64(10), @@ -312,55 +295,55 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { } var conf map[string]interface{} mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "list"}, &conf) - c.Assert(conf, DeepEquals, expected1) + re.Equal(expected1, conf) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "show"}, &conf) - c.Assert(conf, DeepEquals, expected1) + re.Equal(expected1, conf) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "src-tolerance-ratio", "1.02"}, nil) expected1["src-tolerance-ratio"] = 1.02 var conf1 map[string]interface{} mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "read-priorities", "byte,key"}, nil) expected1["read-priorities"] = []interface{}{"byte", "key"} mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "read-priorities", "key"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "read-priorities", "key,byte"}, nil) expected1["read-priorities"] = []interface{}{"key", "byte"} mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "read-priorities", "foo,bar"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "read-priorities", ""}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "read-priorities", "key,key"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", 
"balance-hot-region-scheduler", "set", "read-priorities", "byte,byte"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "read-priorities", "key,key,byte"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) // write-priorities is divided into write-leader-priorities and write-peer-priorities mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "write-priorities", "key,byte"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "forbid-rw-type", "read"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) // test compatibility for _, store := range stores { version := versioninfo.HotScheduleWithQuery store.Version = versioninfo.MinSupportedVersion(version).String() - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) } conf["read-priorities"] = []interface{}{"query", "byte"} @@ -368,32 +351,32 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { // cannot set qps as write-peer-priorities mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler", "set", "write-peer-priorities", "query,byte"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-hot-region-scheduler"}, &conf1) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) // test remove and add mustExec([]string{"-u", pdAddr, "scheduler", "remove", "balance-hot-region-scheduler"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "add", "balance-hot-region-scheduler"}, nil) - c.Assert(conf1, DeepEquals, expected1) + re.Equal(expected1, conf1) // test balance leader config conf = make(map[string]interface{}) conf1 = make(map[string]interface{}) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler", "show"}, &conf) - c.Assert(conf["batch"], Equals, 4.) + re.Equal(4., conf["batch"]) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler", "set", "batch", "3"}, nil) mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler"}, &conf1) - c.Assert(conf1["batch"], Equals, 3.) 
+ re.Equal(3., conf1["batch"]) echo = mustExec([]string{"-u", pdAddr, "scheduler", "add", "balance-leader-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsFalse) + re.NotContains(echo, "Success!") echo = mustExec([]string{"-u", pdAddr, "scheduler", "remove", "balance-leader-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") echo = mustExec([]string{"-u", pdAddr, "scheduler", "remove", "balance-leader-scheduler"}, nil) - c.Assert(strings.Contains(echo, "404"), IsTrue) - c.Assert(strings.Contains(echo, "PD:scheduler:ErrSchedulerNotFound]scheduler not found"), IsTrue) + re.Contains(echo, "404") + re.Contains(echo, "PD:scheduler:ErrSchedulerNotFound]scheduler not found") echo = mustExec([]string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler"}, nil) - c.Assert(strings.Contains(echo, "404"), IsTrue) - c.Assert(strings.Contains(echo, "scheduler not found"), IsTrue) + re.Contains(echo, "404") + re.Contains(echo, "scheduler not found") echo = mustExec([]string{"-u", pdAddr, "scheduler", "add", "balance-leader-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") // test show scheduler with paused and disabled status. checkSchedulerWithStatusCommand := func(args []string, status string, expected []string) { @@ -402,7 +385,7 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { } var schedulers []string mustExec([]string{"-u", pdAddr, "scheduler", "show", "--status", status}, &schedulers) - c.Assert(schedulers, DeepEquals, expected) + re.Equal(expected, schedulers) } mustUsage([]string{"-u", pdAddr, "scheduler", "pause", "balance-leader-scheduler"}) @@ -417,26 +400,26 @@ func (s *schedulerTestSuite) TestScheduler(c *C) { // set label scheduler to disabled manually. 
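The scheduler hunks here, like the rest of the patch, apply one mechanical mapping from gocheck checkers to require assertions, with the expected value moving to the first argument. A minimal, self-contained sketch of that mapping (the checked values are hypothetical, not PD state):

package example_test

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestAssertionMapping(t *testing.T) {
	// One require.New(t) per test replaces the old *check.C receiver.
	re := require.New(t)

	var err error
	re.NoError(err)                         // was: c.Assert(err, IsNil)
	re.Equal(3., 1.+2.)                     // was: c.Assert(1.+2., Equals, 3.) -- expected value first
	re.Contains("Success!\n", "Success")    // was: c.Assert(strings.Contains(out, "Success"), IsTrue)
	re.True(strings.HasPrefix("pd2", "pd")) // was: c.Assert(..., IsTrue)
	re.Len([]int{1, 2, 3}, 3)               // was: c.Assert(..., HasLen, 3)
}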
echo = mustExec([]string{"-u", pdAddr, "scheduler", "add", "label-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") cfg := leaderServer.GetServer().GetScheduleConfig() origin := cfg.Schedulers cfg.Schedulers = config.SchedulerConfigs{{Type: "label", Disable: true}} err = leaderServer.GetServer().SetScheduleConfig(*cfg) - c.Assert(err, IsNil) + re.NoError(err) checkSchedulerWithStatusCommand(nil, "disabled", []string{"label-scheduler"}) // reset Schedulers in ScheduleConfig cfg.Schedulers = origin err = leaderServer.GetServer().SetScheduleConfig(*cfg) - c.Assert(err, IsNil) + re.NoError(err) checkSchedulerWithStatusCommand(nil, "disabled", nil) // test split bucket scheduler echo = mustExec([]string{"-u", pdAddr, "scheduler", "config", "split-bucket-scheduler"}, nil) - c.Assert(strings.Contains(echo, "\"degree\": 3"), IsTrue) + re.Contains(echo, "\"degree\": 3") echo = mustExec([]string{"-u", pdAddr, "scheduler", "config", "split-bucket-scheduler", "set", "degree", "10"}, nil) - c.Assert(strings.Contains(echo, "Success"), IsTrue) + re.Contains(echo, "Success") echo = mustExec([]string{"-u", pdAddr, "scheduler", "config", "split-bucket-scheduler"}, nil) - c.Assert(strings.Contains(echo, "\"degree\": 10"), IsTrue) + re.Contains(echo, "\"degree\": 10") echo = mustExec([]string{"-u", pdAddr, "scheduler", "remove", "split-bucket-scheduler"}, nil) - c.Assert(strings.Contains(echo, "Success!"), IsTrue) + re.Contains(echo, "Success!") } diff --git a/tests/pdctl/store/store_test.go b/tests/pdctl/store/store_test.go index a43d70722e8..c2c9420d01a 100644 --- a/tests/pdctl/store/store_test.go +++ b/tests/pdctl/store/store_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/core/storelimit" @@ -32,21 +32,14 @@ import ( cmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&storeTestSuite{}) - -type storeTestSuite struct{} - -func (s *storeTestSuite) TestStore(c *C) { +func TestStore(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := cmd.GetRootCmd() @@ -88,180 +81,189 @@ func (s *storeTestSuite) TestStore(c *C) { } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store.Store.Store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store.Store.Store) } defer cluster.Destroy() // store command args := []string{"-u", pdAddr, "store"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storesInfo := new(api.StoresInfo) - c.Assert(json.Unmarshal(output, &storesInfo), IsNil) - pdctl.CheckStoresInfo(c, storesInfo.Stores, stores[:2]) + re.NoError(json.Unmarshal(output, &storesInfo)) + + pdctl.CheckStoresInfo(re, storesInfo.Stores, stores[:2]) // store --state= command args = []string{"-u", pdAddr, "store", "--state", "Up,Tombstone"} output, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "\"state\":"), Equals, false) + re.NoError(err) + re.Equal(false, strings.Contains(string(output), "\"state\":")) storesInfo = new(api.StoresInfo) - c.Assert(json.Unmarshal(output, &storesInfo), IsNil) - pdctl.CheckStoresInfo(c, storesInfo.Stores, stores) + re.NoError(json.Unmarshal(output, &storesInfo)) + + pdctl.CheckStoresInfo(re, storesInfo.Stores, stores) // store command args = []string{"-u", pdAddr, "store", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storeInfo := new(api.StoreInfo) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) - pdctl.CheckStoresInfo(c, []*api.StoreInfo{storeInfo}, stores[:1]) + re.NoError(json.Unmarshal(output, &storeInfo)) + + pdctl.CheckStoresInfo(re, []*api.StoreInfo{storeInfo}, stores[:1]) // store label [ ]... [flags] command - c.Assert(storeInfo.Store.Labels, IsNil) + re.Nil(storeInfo.Store.Labels) + args = []string{"-u", pdAddr, "store", "label", "1", "zone", "cn"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) + re.NoError(err) + re.NoError(json.Unmarshal(output, &storeInfo)) + label := storeInfo.Store.Labels[0] - c.Assert(label.Key, Equals, "zone") - c.Assert(label.Value, Equals, "cn") + re.Equal("zone", label.Key) + re.Equal("cn", label.Value) // store label ... command args = []string{"-u", pdAddr, "store", "label", "1", "zone", "us", "language", "English"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) + re.NoError(err) + re.NoError(json.Unmarshal(output, &storeInfo)) + label0 := storeInfo.Store.Labels[0] - c.Assert(label0.Key, Equals, "zone") - c.Assert(label0.Value, Equals, "us") + re.Equal("zone", label0.Key) + re.Equal("us", label0.Value) label1 := storeInfo.Store.Labels[1] - c.Assert(label1.Key, Equals, "language") - c.Assert(label1.Value, Equals, "English") + re.Equal("language", label1.Key) + re.Equal("English", label1.Value) // store label ... -f command args = []string{"-u", pdAddr, "store", "label", "1", "zone", "uk", "-f"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) + re.NoError(err) + re.NoError(json.Unmarshal(output, &storeInfo)) + label0 = storeInfo.Store.Labels[0] - c.Assert(label0.Key, Equals, "zone") - c.Assert(label0.Value, Equals, "uk") - c.Assert(storeInfo.Store.Labels, HasLen, 1) + re.Equal("zone", label0.Key) + re.Equal("uk", label0.Value) + re.Len(storeInfo.Store.Labels, 1) // store weight command - c.Assert(storeInfo.Status.LeaderWeight, Equals, float64(1)) - c.Assert(storeInfo.Status.RegionWeight, Equals, float64(1)) + re.Equal(float64(1), storeInfo.Status.LeaderWeight) + re.Equal(float64(1), storeInfo.Status.RegionWeight) args = []string{"-u", pdAddr, "store", "weight", "1", "5", "10"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) - c.Assert(storeInfo.Status.LeaderWeight, Equals, float64(5)) - c.Assert(storeInfo.Status.RegionWeight, Equals, float64(10)) + re.NoError(err) + re.NoError(json.Unmarshal(output, &storeInfo)) + + re.Equal(float64(5), storeInfo.Status.LeaderWeight) + re.Equal(float64(10), storeInfo.Status.RegionWeight) // store limit args = []string{"-u", pdAddr, "store", "limit", "1", "10"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) limit := leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.AddPeer) - c.Assert(limit, Equals, float64(10)) + re.Equal(float64(10), limit) limit = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.RemovePeer) - c.Assert(limit, Equals, float64(10)) + re.Equal(float64(10), limit) // store limit args = []string{"-u", pdAddr, "store", "limit", "1", "5", "remove-peer"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) limit = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.RemovePeer) - c.Assert(limit, Equals, float64(5)) + re.Equal(float64(5), limit) limit = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.AddPeer) - c.Assert(limit, Equals, float64(10)) + re.Equal(float64(10), limit) // store limit all args = []string{"-u", pdAddr, "store", "limit", "all", "20"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) limit1 := leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.AddPeer) limit2 := leaderServer.GetRaftCluster().GetStoreLimitByType(2, storelimit.AddPeer) limit3 := leaderServer.GetRaftCluster().GetStoreLimitByType(3, storelimit.AddPeer) - c.Assert(limit1, Equals, float64(20)) - c.Assert(limit2, Equals, float64(20)) - c.Assert(limit3, Equals, float64(20)) + re.Equal(float64(20), limit1) + re.Equal(float64(20), limit2) + re.Equal(float64(20), limit3) limit1 = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.RemovePeer) limit2 = leaderServer.GetRaftCluster().GetStoreLimitByType(2, storelimit.RemovePeer) limit3 = leaderServer.GetRaftCluster().GetStoreLimitByType(3, storelimit.RemovePeer) - c.Assert(limit1, Equals, float64(20)) - c.Assert(limit2, Equals, float64(20)) - c.Assert(limit3, Equals, float64(20)) + re.Equal(float64(20), limit1) + re.Equal(float64(20), limit2) + re.Equal(float64(20), limit3) + + re.NoError(leaderServer.Stop()) + re.NoError(leaderServer.Run()) - c.Assert(leaderServer.Stop(), IsNil) - c.Assert(leaderServer.Run(), IsNil) cluster.WaitLeader() storesLimit := leaderServer.GetPersistOptions().GetAllStoresLimit() - c.Assert(storesLimit[1].AddPeer, Equals, float64(20)) - c.Assert(storesLimit[1].RemovePeer, Equals, float64(20)) + re.Equal(float64(20), storesLimit[1].AddPeer) + re.Equal(float64(20), storesLimit[1].RemovePeer) // store limit all args = []string{"-u", pdAddr, "store", "limit", "all", "25", "remove-peer"} _, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) limit1 = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.RemovePeer) limit3 = leaderServer.GetRaftCluster().GetStoreLimitByType(3, storelimit.RemovePeer) - c.Assert(limit1, Equals, float64(25)) - c.Assert(limit3, Equals, float64(25)) + re.Equal(float64(25), limit1) + re.Equal(float64(25), limit3) limit2 = leaderServer.GetRaftCluster().GetStoreLimitByType(2, storelimit.RemovePeer) - c.Assert(limit2, Equals, float64(25)) + re.Equal(float64(25), limit2) // store limit all args = []string{"-u", pdAddr, "store", "limit", "all", "zone", "uk", "20", "remove-peer"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) limit1 = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.RemovePeer) - c.Assert(limit1, Equals, float64(20)) + re.Equal(float64(20), limit1) // store limit all 0 is invalid args = []string{"-u", pdAddr, "store", "limit", "all", "0"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "rate should be a number that > 0"), IsTrue) + re.NoError(err) + re.Contains(string(output), "rate should be a number that > 0") // store limit args = []string{"-u", pdAddr, "store", "limit"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) allAddPeerLimit := make(map[string]map[string]interface{}) json.Unmarshal(output, &allAddPeerLimit) - c.Assert(allAddPeerLimit["1"]["add-peer"].(float64), Equals, float64(20)) - c.Assert(allAddPeerLimit["3"]["add-peer"].(float64), Equals, float64(20)) + re.Equal(float64(20), allAddPeerLimit["1"]["add-peer"].(float64)) + re.Equal(float64(20), allAddPeerLimit["3"]["add-peer"].(float64)) _, ok := allAddPeerLimit["2"]["add-peer"] - c.Assert(ok, IsFalse) + re.False(ok) args = []string{"-u", pdAddr, "store", "limit", "remove-peer"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) allRemovePeerLimit := make(map[string]map[string]interface{}) json.Unmarshal(output, &allRemovePeerLimit) - c.Assert(allRemovePeerLimit["1"]["remove-peer"].(float64), Equals, float64(20)) - c.Assert(allRemovePeerLimit["3"]["remove-peer"].(float64), Equals, float64(25)) + re.Equal(float64(20), allRemovePeerLimit["1"]["remove-peer"].(float64)) + re.Equal(float64(25), allRemovePeerLimit["3"]["remove-peer"].(float64)) _, ok = allRemovePeerLimit["2"]["add-peer"] - c.Assert(ok, IsFalse) + re.False(ok) // put enough stores for replica. for id := 1000; id <= 1005; id++ { @@ -271,172 +273,179 @@ func (s *storeTestSuite) TestStore(c *C) { NodeState: metapb.NodeState_Serving, LastHeartbeat: time.Now().UnixNano(), } - pdctl.MustPutStore(c, leaderServer.GetServer(), store2) + pdctl.MustPutStore(re, leaderServer.GetServer(), store2) } // store delete command storeInfo.Store.State = metapb.StoreState(metapb.StoreState_value[storeInfo.Store.StateName]) - c.Assert(storeInfo.Store.State, Equals, metapb.StoreState_Up) + re.Equal(metapb.StoreState_Up, storeInfo.Store.State) args = []string{"-u", pdAddr, "store", "delete", "1"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) storeInfo = new(api.StoreInfo) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) + re.NoError(json.Unmarshal(output, &storeInfo)) + storeInfo.Store.State = metapb.StoreState(metapb.StoreState_value[storeInfo.Store.StateName]) - c.Assert(storeInfo.Store.State, Equals, metapb.StoreState_Offline) + re.Equal(metapb.StoreState_Offline, storeInfo.Store.State) // store check status args = []string{"-u", pdAddr, "store", "check", "Offline"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "\"id\": 1,"), IsTrue) + re.NoError(err) + re.Contains(string(output), "\"id\": 1,") args = []string{"-u", pdAddr, "store", "check", "Tombstone"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "\"id\": 2,"), IsTrue) + re.NoError(err) + re.Contains(string(output), "\"id\": 2,") args = []string{"-u", pdAddr, "store", "check", "Up"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "\"id\": 3,"), IsTrue) + re.NoError(err) + re.Contains(string(output), "\"id\": 3,") args = []string{"-u", pdAddr, "store", "check", "Invalid_State"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "Unknown state: Invalid_state"), IsTrue) + re.NoError(err) + re.Contains(string(output), "Unknown state: Invalid_state") // store cancel-delete command limit = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.RemovePeer) - c.Assert(limit, Equals, storelimit.Unlimited) + re.Equal(storelimit.Unlimited, limit) args = []string{"-u", pdAddr, "store", "cancel-delete", "1"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storeInfo = new(api.StoreInfo) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) - c.Assert(storeInfo.Store.State, Equals, metapb.StoreState_Up) + re.NoError(json.Unmarshal(output, &storeInfo)) + + re.Equal(metapb.StoreState_Up, storeInfo.Store.State) limit = leaderServer.GetRaftCluster().GetStoreLimitByType(1, storelimit.RemovePeer) - c.Assert(limit, Equals, 20.0) + re.Equal(20.0, limit) // store delete addr
args = []string{"-u", pdAddr, "store", "delete", "addr", "tikv3"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(string(output), Equals, "Success!\n") - c.Assert(err, IsNil) + re.Equal("Success!\n", string(output)) + re.NoError(err) args = []string{"-u", pdAddr, "store", "3"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storeInfo = new(api.StoreInfo) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) + re.NoError(json.Unmarshal(output, &storeInfo)) + storeInfo.Store.State = metapb.StoreState(metapb.StoreState_value[storeInfo.Store.StateName]) - c.Assert(storeInfo.Store.State, Equals, metapb.StoreState_Offline) + re.Equal(metapb.StoreState_Offline, storeInfo.Store.State) // store cancel-delete addr
limit = leaderServer.GetRaftCluster().GetStoreLimitByType(3, storelimit.RemovePeer) - c.Assert(limit, Equals, storelimit.Unlimited) + re.Equal(storelimit.Unlimited, limit) args = []string{"-u", pdAddr, "store", "cancel-delete", "addr", "tikv3"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(string(output), Equals, "Success!\n") - c.Assert(err, IsNil) + re.Equal("Success!\n", string(output)) + re.NoError(err) args = []string{"-u", pdAddr, "store", "3"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storeInfo = new(api.StoreInfo) - c.Assert(json.Unmarshal(output, &storeInfo), IsNil) - c.Assert(storeInfo.Store.State, Equals, metapb.StoreState_Up) + re.NoError(json.Unmarshal(output, &storeInfo)) + + re.Equal(metapb.StoreState_Up, storeInfo.Store.State) limit = leaderServer.GetRaftCluster().GetStoreLimitByType(3, storelimit.RemovePeer) - c.Assert(limit, Equals, 25.0) + re.Equal(25.0, limit) // store remove-tombstone args = []string{"-u", pdAddr, "store", "check", "Tombstone"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storesInfo = new(api.StoresInfo) - c.Assert(json.Unmarshal(output, &storesInfo), IsNil) - c.Assert(storesInfo.Count, Equals, 1) + re.NoError(json.Unmarshal(output, &storesInfo)) + + re.Equal(1, storesInfo.Count) args = []string{"-u", pdAddr, "store", "remove-tombstone"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "check", "Tombstone"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) storesInfo = new(api.StoresInfo) - c.Assert(json.Unmarshal(output, &storesInfo), IsNil) - c.Assert(storesInfo.Count, Equals, 0) + re.NoError(json.Unmarshal(output, &storesInfo)) + + re.Equal(0, storesInfo.Count) // It should be called after stores remove-tombstone. args = []string{"-u", pdAddr, "stores", "show", "limit"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "PANIC"), IsFalse) + re.NoError(err) + re.NotContains(string(output), "PANIC") args = []string{"-u", pdAddr, "stores", "show", "limit", "remove-peer"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "PANIC"), IsFalse) + re.NoError(err) + re.NotContains(string(output), "PANIC") args = []string{"-u", pdAddr, "stores", "show", "limit", "add-peer"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "PANIC"), IsFalse) + re.NoError(err) + re.NotContains(string(output), "PANIC") // store limit-scene args = []string{"-u", pdAddr, "store", "limit-scene"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) scene := &storelimit.Scene{} err = json.Unmarshal(output, scene) - c.Assert(err, IsNil) - c.Assert(scene, DeepEquals, storelimit.DefaultScene(storelimit.AddPeer)) + re.NoError(err) + re.Equal(storelimit.DefaultScene(storelimit.AddPeer), scene) // store limit-scene args = []string{"-u", pdAddr, "store", "limit-scene", "idle", "200"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "limit-scene"} scene = &storelimit.Scene{} output, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) err = json.Unmarshal(output, scene) - c.Assert(err, IsNil) - c.Assert(scene.Idle, Equals, 200) + re.NoError(err) + re.Equal(200, scene.Idle) // store limit-scene args = []string{"-u", pdAddr, "store", "limit-scene", "idle", "100", "remove-peer"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "store", "limit-scene", "remove-peer"} scene = &storelimit.Scene{} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) err = json.Unmarshal(output, scene) - c.Assert(err, IsNil) - c.Assert(scene.Idle, Equals, 100) + re.NoError(err) + re.Equal(100, scene.Idle) // store limit all 201 is invalid for all args = []string{"-u", pdAddr, "store", "limit", "all", "201"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "rate should less than"), IsTrue) + re.NoError(err) + re.Contains(string(output), "rate should less than") // store limit all 201 is invalid for label args = []string{"-u", pdAddr, "store", "limit", "all", "engine", "key", "201", "add-peer"} output, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - c.Assert(strings.Contains(string(output), "rate should less than"), IsTrue) + re.NoError(err) + re.Contains(string(output), "rate should less than") } // https://github.com/tikv/pd/issues/5024 -func (s *storeTestSuite) TestTombstoneStore(c *C) { +func TestTombstoneStore(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pdAddr := cluster.GetConfig().GetClientURL() cmd := cmd.GetRootCmd() @@ -478,18 +487,19 @@ func (s *storeTestSuite) TestTombstoneStore(c *C) { } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store.Store.Store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store.Store.Store) } defer cluster.Destroy() - pdctl.MustPutRegion(c, cluster, 1, 2, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 2, 3, []byte("b"), []byte("c"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 1, 2, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 2, 3, []byte("b"), []byte("c"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) // store remove-tombstone args := []string{"-u", pdAddr, "store", "remove-tombstone"} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) message := string(output) - c.Assert(strings.Contains(message, "2") && strings.Contains(message, "3"), IsTrue) + re.Contains(message, "2") + re.Contains(message, "3") } diff --git a/tests/pdctl/tso/tso_test.go b/tests/pdctl/tso/tso_test.go index 1d2cdb77dc0..f6295424ddc 100644 --- a/tests/pdctl/tso/tso_test.go +++ b/tests/pdctl/tso/tso_test.go @@ -20,20 +20,13 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&tsoTestSuite{}) - -type tsoTestSuite struct{} - -func (s *tsoTestSuite) TestTSO(c *C) { +func TestTSO(t *testing.T) { + re := require.New(t) cmd := pdctlCmd.GetRootCmd() const ( @@ -45,13 +38,12 @@ func (s *tsoTestSuite) TestTSO(c *C) { ts := "395181938313123110" args := []string{"-u", "127.0.0.1", "tso", ts} output, err := pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) - t, e := strconv.ParseUint(ts, 10, 64) - c.Assert(e, IsNil) - c.Assert(err, IsNil) - logicalTime := t & logicalBits - physical := t >> physicalShiftBits + re.NoError(err) + tsTime, err := strconv.ParseUint(ts, 10, 64) + re.NoError(err) + logicalTime := tsTime & logicalBits + physical := tsTime >> physicalShiftBits physicalTime := time.Unix(int64(physical/1000), int64(physical%1000)*time.Millisecond.Nanoseconds()) str := fmt.Sprintln("system: ", physicalTime) + fmt.Sprintln("logic: ", logicalTime) - c.Assert(str, Equals, string(output)) + re.Equal(string(output), str) } diff --git a/tests/pdctl/unsafe/unsafe_operation_test.go b/tests/pdctl/unsafe/unsafe_operation_test.go index 4bbe2309dc3..1e4e3468225 100644 --- a/tests/pdctl/unsafe/unsafe_operation_test.go +++ b/tests/pdctl/unsafe/unsafe_operation_test.go @@ -18,47 +18,40 @@ import ( "context" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/tests" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&unsafeOperationTestSuite{}) - -type unsafeOperationTestSuite struct{} - -func (s *unsafeOperationTestSuite) TestRemoveFailedStores(c *C) { +func TestRemoveFailedStores(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() err = cluster.GetServer(cluster.GetLeader()).BootstrapCluster() - c.Assert(err, IsNil) + re.NoError(err) pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() defer cluster.Destroy() args := []string{"-u", pdAddr, "unsafe", "remove-failed-stores", "1,2,3"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "unsafe", "remove-failed-stores", "1,2,3", "--timeout", "3600"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "unsafe", "remove-failed-stores", "1,2,3", "--timeout", "abc"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, Not(IsNil)) + re.Error(err) args = []string{"-u", pdAddr, "unsafe", "remove-failed-stores", "show"} _, err = pdctl.ExecuteCommand(cmd, args...) - c.Assert(err, IsNil) + re.NoError(err) args = []string{"-u", pdAddr, "unsafe", "remove-failed-stores", "history"} _, err = pdctl.ExecuteCommand(cmd, args...) 
- c.Assert(err, IsNil) + re.NoError(err) } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 9e6248f9cec..e462adae2ba 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -542,12 +542,12 @@ func (s *testProgressSuite) TestRemovingProgress(c *C) { } for _, store := range stores { - pdctl.MustPutStore(c, leader.GetServer(), store) + pdctl.MustPutStoreWithCheck(c, leader.GetServer(), store) } - pdctl.MustPutRegion(c, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) - pdctl.MustPutRegion(c, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(30)) - pdctl.MustPutRegion(c, cluster, 1002, 1, []byte("e"), []byte("f"), core.SetApproximateSize(50)) - pdctl.MustPutRegion(c, cluster, 1003, 2, []byte("g"), []byte("h"), core.SetApproximateSize(40)) + pdctl.MustPutRegionWithCheck(c, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) + pdctl.MustPutRegionWithCheck(c, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(30)) + pdctl.MustPutRegionWithCheck(c, cluster, 1002, 1, []byte("e"), []byte("f"), core.SetApproximateSize(50)) + pdctl.MustPutRegionWithCheck(c, cluster, 1003, 2, []byte("g"), []byte("h"), core.SetApproximateSize(40)) // no store removing output := sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusNotFound) @@ -569,8 +569,8 @@ func (s *testProgressSuite) TestRemovingProgress(c *C) { c.Assert(p.LeftSeconds, Equals, math.MaxFloat64) // update size - pdctl.MustPutRegion(c, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(20)) - pdctl.MustPutRegion(c, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(10)) + pdctl.MustPutRegionWithCheck(c, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(20)) + pdctl.MustPutRegionWithCheck(c, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(10)) // is not prepared time.Sleep(2 * time.Second) @@ -675,10 +675,10 @@ func (s *testProgressSuite) TestPreparingProgress(c *C) { } for _, store := range stores { - pdctl.MustPutStore(c, leader.GetServer(), store) + pdctl.MustPutStoreWithCheck(c, leader.GetServer(), store) } for i := 0; i < 100; i++ { - pdctl.MustPutRegion(c, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("p%d", i)), []byte(fmt.Sprintf("%d", i+1)), core.SetApproximateSize(10)) + pdctl.MustPutRegionWithCheck(c, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("p%d", i)), []byte(fmt.Sprintf("%d", i+1)), core.SetApproximateSize(10)) } // no store preparing output := sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) @@ -705,8 +705,8 @@ func (s *testProgressSuite) TestPreparingProgress(c *C) { c.Assert(p.LeftSeconds, Equals, math.MaxFloat64) // update size - pdctl.MustPutRegion(c, cluster, 1000, 4, []byte(fmt.Sprintf("%d", 1000)), []byte(fmt.Sprintf("%d", 1001)), core.SetApproximateSize(10)) - pdctl.MustPutRegion(c, cluster, 1001, 5, []byte(fmt.Sprintf("%d", 1001)), []byte(fmt.Sprintf("%d", 1002)), core.SetApproximateSize(40)) + pdctl.MustPutRegionWithCheck(c, cluster, 1000, 4, []byte(fmt.Sprintf("%d", 1000)), []byte(fmt.Sprintf("%d", 1001)), core.SetApproximateSize(10)) + pdctl.MustPutRegionWithCheck(c, cluster, 1001, 5, []byte(fmt.Sprintf("%d", 1001)), []byte(fmt.Sprintf("%d", 1002)), core.SetApproximateSize(40)) time.Sleep(2 * time.Second) output = sendRequest(c, 
leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusOK) c.Assert(json.Unmarshal(output, &p), IsNil) diff --git a/tests/server/storage/hot_region_storage_test.go b/tests/server/storage/hot_region_storage_test.go index 5a11f8c23c4..662f128dd1b 100644 --- a/tests/server/storage/hot_region_storage_test.go +++ b/tests/server/storage/hot_region_storage_test.go @@ -69,14 +69,14 @@ func (s *hotRegionHistorySuite) TestHotRegionStorage(c *C) { leaderServer := cluster.GetServer(cluster.GetLeader()) c.Assert(leaderServer.BootstrapCluster(), IsNil) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStoreWithCheck(c, leaderServer.GetServer(), store) } defer cluster.Destroy() startTime := time.Now().UnixNano() / int64(time.Millisecond) - pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegion(c, cluster, 3, 1, []byte("e"), []byte("f")) - pdctl.MustPutRegion(c, cluster, 4, 2, []byte("g"), []byte("h")) + pdctl.MustPutRegionWithCheck(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegionWithCheck(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegionWithCheck(c, cluster, 3, 1, []byte("e"), []byte("f")) + pdctl.MustPutRegionWithCheck(c, cluster, 4, 2, []byte("g"), []byte("h")) storeStats := []*pdpb.StoreStats{ { StoreId: 1, @@ -172,11 +172,11 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageReservedDayConfigChange(c *C leaderServer := cluster.GetServer(cluster.GetLeader()) c.Assert(leaderServer.BootstrapCluster(), IsNil) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStoreWithCheck(c, leaderServer.GetServer(), store) } defer cluster.Destroy() startTime := time.Now().UnixNano() / int64(time.Millisecond) - pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegionWithCheck(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) // wait hot scheduler starts time.Sleep(5000 * time.Millisecond) endTime := time.Now().UnixNano() / int64(time.Millisecond) @@ -196,7 +196,7 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageReservedDayConfigChange(c *C schedule.HotRegionsReservedDays = 0 leaderServer.GetServer().SetScheduleConfig(schedule) time.Sleep(3 * interval) - pdctl.MustPutRegion(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegionWithCheck(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) time.Sleep(10 * interval) endTime = time.Now().UnixNano() / int64(time.Millisecond) hotRegionStorage = leaderServer.GetServer().GetHistoryHotRegionStorage() @@ -263,11 +263,11 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageWriteIntervalConfigChange(c leaderServer := cluster.GetServer(cluster.GetLeader()) 
c.Assert(leaderServer.BootstrapCluster(), IsNil) for _, store := range stores { - pdctl.MustPutStore(c, leaderServer.GetServer(), store) + pdctl.MustPutStoreWithCheck(c, leaderServer.GetServer(), store) } defer cluster.Destroy() startTime := time.Now().UnixNano() / int64(time.Millisecond) - pdctl.MustPutRegion(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegionWithCheck(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) // wait hot scheduler starts time.Sleep(5000 * time.Millisecond) endTime := time.Now().UnixNano() / int64(time.Millisecond) @@ -287,7 +287,7 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageWriteIntervalConfigChange(c schedule.HotRegionsWriteInterval.Duration = 20 * interval leaderServer.GetServer().SetScheduleConfig(schedule) time.Sleep(3 * interval) - pdctl.MustPutRegion(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegionWithCheck(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) time.Sleep(10 * interval) endTime = time.Now().UnixNano() / int64(time.Millisecond) // it cant get new hot region because wait time smaller than hot region write interval From 5d744d356a040a6bd21ce4eac0c0d5a586862b28 Mon Sep 17 00:00:00 2001 From: Shirly Date: Wed, 15 Jun 2022 15:52:33 +0800 Subject: [PATCH 47/82] server/grpc_service: make the lock for `UpdateServiceGCSafePoint` smaller (#5128) close tikv/pd#5019 Signed-off-by: shirly Co-authored-by: Ti Chi Robot --- server/gc/safepoint.go | 63 +++++++++++++++++--------- server/gc/safepoint_test.go | 88 ++++++++++++++++++++++++++++++++++++- server/grpc_service.go | 34 +++----------- server/server.go | 6 +-- 4 files changed, 135 insertions(+), 56 deletions(-) diff --git a/server/gc/safepoint.go b/server/gc/safepoint.go index 3cec08d8951..533ae338580 100644 --- a/server/gc/safepoint.go +++ b/server/gc/safepoint.go @@ -15,42 +15,35 @@ package gc import ( + "math" + "time" + "github.com/tikv/pd/pkg/syncutil" "github.com/tikv/pd/server/storage/endpoint" ) -// SafePointManager is the manager for safePoint of GC and services +// SafePointManager is the manager for safePoint of GC and services. type SafePointManager struct { - *gcSafePointManager - // TODO add ServiceSafepointManager -} - -// NewSafepointManager creates a SafePointManager of GC and services -func NewSafepointManager(store endpoint.GCSafePointStorage) *SafePointManager { - return &SafePointManager{ - newGCSafePointManager(store), - } + gcLock syncutil.Mutex + serviceGCLock syncutil.Mutex + store endpoint.GCSafePointStorage } -type gcSafePointManager struct { - syncutil.Mutex - store endpoint.GCSafePointStorage -} - -func newGCSafePointManager(store endpoint.GCSafePointStorage) *gcSafePointManager { - return &gcSafePointManager{store: store} +// NewSafePointManager creates a SafePointManager of GC and services. +func NewSafePointManager(store endpoint.GCSafePointStorage) *SafePointManager { + return &SafePointManager{store: store} } // LoadGCSafePoint loads current GC safe point from storage. 
-func (manager *gcSafePointManager) LoadGCSafePoint() (uint64, error) { +func (manager *SafePointManager) LoadGCSafePoint() (uint64, error) { return manager.store.LoadGCSafePoint() } // UpdateGCSafePoint updates the safepoint if it is greater than the previous one // it returns the old safepoint in the storage. -func (manager *gcSafePointManager) UpdateGCSafePoint(newSafePoint uint64) (oldSafePoint uint64, err error) { - manager.Lock() - defer manager.Unlock() +func (manager *SafePointManager) UpdateGCSafePoint(newSafePoint uint64) (oldSafePoint uint64, err error) { + manager.gcLock.Lock() + defer manager.gcLock.Unlock() // TODO: cache the safepoint in the storage. oldSafePoint, err = manager.store.LoadGCSafePoint() if err != nil { @@ -62,3 +55,31 @@ func (manager *gcSafePointManager) UpdateGCSafePoint(newSafePoint uint64) (oldSa err = manager.store.SaveGCSafePoint(newSafePoint) return } + +// UpdateServiceGCSafePoint update the safepoint for a specific service. +func (manager *SafePointManager) UpdateServiceGCSafePoint(serviceID string, newSafePoint uint64, ttl int64, now time.Time) (minServiceSafePoint *endpoint.ServiceSafePoint, updated bool, err error) { + manager.serviceGCLock.Lock() + defer manager.serviceGCLock.Unlock() + minServiceSafePoint, err = manager.store.LoadMinServiceGCSafePoint(now) + if err != nil || ttl <= 0 || newSafePoint < minServiceSafePoint.SafePoint { + return minServiceSafePoint, false, err + } + + ssp := &endpoint.ServiceSafePoint{ + ServiceID: serviceID, + ExpiredAt: now.Unix() + ttl, + SafePoint: newSafePoint, + } + if math.MaxInt64-now.Unix() <= ttl { + ssp.ExpiredAt = math.MaxInt64 + } + if err := manager.store.SaveServiceGCSafePoint(ssp); err != nil { + return nil, false, err + } + + // If the min safePoint is updated, load the next one. + if serviceID == minServiceSafePoint.ServiceID { + minServiceSafePoint, err = manager.store.LoadMinServiceGCSafePoint(now) + } + return minServiceSafePoint, true, err +} diff --git a/server/gc/safepoint_test.go b/server/gc/safepoint_test.go index 2af82ba7145..aebf8033dea 100644 --- a/server/gc/safepoint_test.go +++ b/server/gc/safepoint_test.go @@ -15,8 +15,10 @@ package gc import ( + "math" "sync" "testing" + "time" "github.com/stretchr/testify/require" "github.com/tikv/pd/server/storage/endpoint" @@ -28,7 +30,7 @@ func newGCStorage() endpoint.GCSafePointStorage { } func TestGCSafePointUpdateSequentially(t *testing.T) { - gcSafePointManager := newGCSafePointManager(newGCStorage()) + gcSafePointManager := NewSafePointManager(newGCStorage()) re := require.New(t) curSafePoint := uint64(0) // update gc safePoint with asc value. 
@@ -57,7 +59,7 @@ func TestGCSafePointUpdateSequentially(t *testing.T) { } func TestGCSafePointUpdateCurrently(t *testing.T) { - gcSafePointManager := newGCSafePointManager(newGCStorage()) + gcSafePointManager := NewSafePointManager(newGCStorage()) maxSafePoint := uint64(1000) wg := sync.WaitGroup{} re := require.New(t) @@ -78,3 +80,85 @@ func TestGCSafePointUpdateCurrently(t *testing.T) { re.NoError(err) re.Equal(maxSafePoint, safePoint) } + +func TestServiceGCSafePointUpdate(t *testing.T) { + re := require.New(t) + manager := NewSafePointManager(newGCStorage()) + gcworkerServiceID := "gc_worker" + cdcServiceID := "cdc" + brServiceID := "br" + cdcServiceSafePoint := uint64(10) + gcWorkerSafePoint := uint64(8) + brSafePoint := uint64(15) + + wg := sync.WaitGroup{} + wg.Add(5) + // update the safepoint for cdc to 10 should success + go func() { + defer wg.Done() + min, updated, err := manager.UpdateServiceGCSafePoint(cdcServiceID, cdcServiceSafePoint, 10000, time.Now()) + re.NoError(err) + re.True(updated) + // the service will init the service safepoint to 0(<10 for cdc) for gc_worker. + re.Equal(gcworkerServiceID, min.ServiceID) + }() + + // update the safepoint for br to 15 should success + go func() { + defer wg.Done() + min, updated, err := manager.UpdateServiceGCSafePoint(brServiceID, brSafePoint, 10000, time.Now()) + re.NoError(err) + re.True(updated) + // the service will init the service safepoint to 0(<10 for cdc) for gc_worker. + re.Equal(gcworkerServiceID, min.ServiceID) + }() + + // update safepoint to 8 for gc_woker should be success + go func() { + defer wg.Done() + // update with valid ttl for gc_worker should be success. + min, updated, _ := manager.UpdateServiceGCSafePoint(gcworkerServiceID, gcWorkerSafePoint, math.MaxInt64, time.Now()) + re.True(updated) + // the current min safepoint should be 8 for gc_worker(cdc 10) + re.Equal(gcWorkerSafePoint, min.SafePoint) + re.Equal(gcworkerServiceID, min.ServiceID) + }() + + go func() { + defer wg.Done() + // update safepoint of gc_worker's service with ttl not infinity should be failed. + _, updated, err := manager.UpdateServiceGCSafePoint(gcworkerServiceID, 10000, 10, time.Now()) + re.Error(err) + re.False(updated) + }() + + // update safepoint with negative ttl should be failed. + go func() { + defer wg.Done() + brTTL := int64(-100) + _, updated, err := manager.UpdateServiceGCSafePoint(brServiceID, uint64(10000), brTTL, time.Now()) + re.NoError(err) + re.False(updated) + }() + + wg.Wait() + // update safepoint to 15(>10 for cdc) for gc_worker + gcWorkerSafePoint = uint64(15) + min, updated, err := manager.UpdateServiceGCSafePoint(gcworkerServiceID, gcWorkerSafePoint, math.MaxInt64, time.Now()) + re.NoError(err) + re.True(updated) + re.Equal(cdcServiceID, min.ServiceID) + re.Equal(cdcServiceSafePoint, min.SafePoint) + + // the value shouldn't be updated with current safepoint smaller than the min safepoint. 
+ brTTL := int64(100) + brSafePoint = min.SafePoint - 5 + min, updated, err = manager.UpdateServiceGCSafePoint(brServiceID, brSafePoint, brTTL, time.Now()) + re.NoError(err) + re.False(updated) + + brSafePoint = min.SafePoint + 10 + _, updated, err = manager.UpdateServiceGCSafePoint(brServiceID, brSafePoint, brTTL, time.Now()) + re.NoError(err) + re.True(updated) +} diff --git a/server/grpc_service.go b/server/grpc_service.go index 53b74ba517d..c02e51ed510 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "io" - "math" "strconv" "sync/atomic" "time" @@ -1358,8 +1357,6 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update // UpdateServiceGCSafePoint update the safepoint for specific service func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { - s.serviceSafePointLock.Lock() - defer s.serviceSafePointLock.Unlock() fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request) } @@ -1385,36 +1382,17 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb return nil, err } now, _ := tsoutil.ParseTimestamp(nowTSO) - min, err := storage.LoadMinServiceGCSafePoint(now) + serviceID := string(request.ServiceId) + min, updated, err := s.gcSafePointManager.UpdateServiceGCSafePoint(serviceID, request.GetSafePoint(), request.GetTTL(), now) if err != nil { return nil, err } - - if request.TTL > 0 && request.SafePoint >= min.SafePoint { - ssp := &endpoint.ServiceSafePoint{ - ServiceID: string(request.ServiceId), - ExpiredAt: now.Unix() + request.TTL, - SafePoint: request.SafePoint, - } - if math.MaxInt64-now.Unix() <= request.TTL { - ssp.ExpiredAt = math.MaxInt64 - } - if err := storage.SaveServiceGCSafePoint(ssp); err != nil { - return nil, err - } + if updated { log.Info("update service GC safe point", - zap.String("service-id", ssp.ServiceID), - zap.Int64("expire-at", ssp.ExpiredAt), - zap.Uint64("safepoint", ssp.SafePoint)) - // If the min safepoint is updated, load the next one - if string(request.ServiceId) == min.ServiceID { - min, err = storage.LoadMinServiceGCSafePoint(now) - if err != nil { - return nil, err - } - } + zap.String("service-id", serviceID), + zap.Int64("expire-at", now.Unix()+request.GetTTL()), + zap.Uint64("safepoint", request.GetSafePoint())) } - return &pdpb.UpdateServiceGCSafePointResponse{ Header: s.header(), ServiceId: []byte(min.ServiceID), diff --git a/server/server.go b/server/server.go index b618f97aeed..c7902692552 100644 --- a/server/server.go +++ b/server/server.go @@ -46,7 +46,6 @@ import ( "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/logutil" "github.com/tikv/pd/pkg/ratelimit" - "github.com/tikv/pd/pkg/syncutil" "github.com/tikv/pd/pkg/systimemon" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/cluster" @@ -146,9 +145,6 @@ type Server struct { startCallbacks []func() closeCallbacks []func() - // serviceSafePointLock is a lock for UpdateServiceGCSafePoint - serviceSafePointLock syncutil.Mutex - // hot region history info storeage hotRegionStorage *storage.HotRegionStorage // Store as map[string]*grpc.ClientConn @@ -404,7 +400,7 @@ func (s *Server) startServer(ctx context.Context) error { } defaultStorage := storage.NewStorageWithEtcdBackend(s.client, s.rootPath) s.storage = storage.NewCoreStorage(defaultStorage, regionStorage) - 
s.gcSafePointManager = gc.NewSafepointManager(s.storage) + s.gcSafePointManager = gc.NewSafePointManager(s.storage) s.basicCluster = core.NewBasicCluster() s.cluster = cluster.NewRaftCluster(ctx, s.clusterID, syncer.NewRegionSyncer(s), s.client, s.httpClient) s.hbStreams = hbstream.NewHeartbeatStreams(ctx, s.clusterID, s.cluster) From b58a7d5ae2a18216617d4ff70b29753583c99f14 Mon Sep 17 00:00:00 2001 From: Shirly Date: Wed, 15 Jun 2022 16:20:34 +0800 Subject: [PATCH 48/82] replica_strategy:speed up and reduce the complexity of selectStore to O(n) (#5144) close tikv/pd#5143 Signed-off-by: shirly Co-authored-by: Ti Chi Robot --- server/schedule/checker/replica_strategy.go | 11 +++-- server/schedule/filter/candidates.go | 49 ++++++++++++++------- server/schedule/filter/candidates_test.go | 19 +++++--- 3 files changed, 51 insertions(+), 28 deletions(-) diff --git a/server/schedule/checker/replica_strategy.go b/server/schedule/checker/replica_strategy.go index 5a249c46a8b..6ccad30a32d 100644 --- a/server/schedule/checker/replica_strategy.go +++ b/server/schedule/checker/replica_strategy.go @@ -72,12 +72,12 @@ func (s *ReplicaStrategy) SelectStoreToAdd(coLocationStores []*core.StoreInfo, e strictStateFilter := &filter.StoreStateFilter{ActionScope: s.checkerName, MoveRegion: true} targetCandidate := filter.NewCandidates(s.cluster.GetStores()). FilterTarget(s.cluster.GetOpts(), filters...). - Sort(isolationComparer).Reverse().Top(isolationComparer). // greater isolation score is better - Sort(filter.RegionScoreComparer(s.cluster.GetOpts())) // less region score is better + KeepTheTopStores(isolationComparer, false) // greater isolation score is better if targetCandidate.Len() == 0 { return 0, false } - target := targetCandidate.FilterTarget(s.cluster.GetOpts(), strictStateFilter).PickFirst() // the filter does not ignore temp states + target := targetCandidate.FilterTarget(s.cluster.GetOpts(), strictStateFilter). + PickTheTopStore(filter.RegionScoreComparer(s.cluster.GetOpts()), true) // less region score is better if target == nil { return 0, true // filter by temporary states } @@ -124,9 +124,8 @@ func (s *ReplicaStrategy) SelectStoreToRemove(coLocationStores []*core.StoreInfo isolationComparer := filter.IsolationComparer(s.locationLabels, coLocationStores) source := filter.NewCandidates(coLocationStores). FilterSource(s.cluster.GetOpts(), &filter.StoreStateFilter{ActionScope: replicaCheckerName, MoveRegion: true}). - Sort(isolationComparer).Top(isolationComparer). - Sort(filter.RegionScoreComparer(s.cluster.GetOpts())).Reverse(). - PickFirst() + KeepTheTopStores(isolationComparer, true). + PickTheTopStore(filter.RegionScoreComparer(s.cluster.GetOpts()), false) if source == nil { log.Debug("no removable store", zap.Uint64("region-id", s.region.GetID())) return 0 diff --git a/server/schedule/filter/candidates.go b/server/schedule/filter/candidates.go index 969fec34d38..dcbe89710a8 100644 --- a/server/schedule/filter/candidates.go +++ b/server/schedule/filter/candidates.go @@ -51,32 +51,49 @@ func (c *StoreCandidates) Sort(less StoreComparer) *StoreCandidates { return c } -// Reverse reverses the candidate store list. -func (c *StoreCandidates) Reverse() *StoreCandidates { - for i := len(c.Stores)/2 - 1; i >= 0; i-- { - opp := len(c.Stores) - 1 - i - c.Stores[i], c.Stores[opp] = c.Stores[opp], c.Stores[i] - } - return c -} - // Shuffle reorders all candidates randomly. 
func (c *StoreCandidates) Shuffle() *StoreCandidates { rand.Shuffle(len(c.Stores), func(i, j int) { c.Stores[i], c.Stores[j] = c.Stores[j], c.Stores[i] }) return c } -// Top keeps all stores that have the same priority with the first store. -// The store list should be sorted before calling Top. -func (c *StoreCandidates) Top(less StoreComparer) *StoreCandidates { - var i int - for i < len(c.Stores) && less(c.Stores[0], c.Stores[i]) == 0 { - i++ +// KeepTheTopStores keeps the slice of the stores in the front order by asc. +func (c *StoreCandidates) KeepTheTopStores(cmp StoreComparer, asc bool) *StoreCandidates { + if len(c.Stores) <= 1 { + return c + } + topIdx := 0 + for idx := 1; idx < c.Len(); idx++ { + compare := cmp(c.Stores[topIdx], c.Stores[idx]) + if compare == 0 { + topIdx++ + } else if (compare > 0 && asc) || (!asc && compare < 0) { + topIdx = 0 + } else { + continue + } + c.Stores[idx], c.Stores[topIdx] = c.Stores[topIdx], c.Stores[idx] } - c.Stores = c.Stores[:i] + c.Stores = c.Stores[:topIdx+1] return c } +// PickTheTopStore returns the first store order by asc. +// It returns the min item when asc is true, returns the max item when asc is false. +func (c *StoreCandidates) PickTheTopStore(cmp StoreComparer, asc bool) *core.StoreInfo { + if len(c.Stores) == 0 { + return nil + } + topIdx := 0 + for idx := 1; idx < len(c.Stores); idx++ { + compare := cmp(c.Stores[topIdx], c.Stores[idx]) + if (compare > 0 && asc) || (!asc && compare < 0) { + topIdx = idx + } + } + return c.Stores[topIdx] +} + // PickFirst returns the first store in candidate list. func (c *StoreCandidates) PickFirst() *core.StoreInfo { if len(c.Stores) == 0 { diff --git a/server/schedule/filter/candidates_test.go b/server/schedule/filter/candidates_test.go index 5150bed9b66..bb86906b081 100644 --- a/server/schedule/filter/candidates_test.go +++ b/server/schedule/filter/candidates_test.go @@ -15,10 +15,10 @@ package filter import ( - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/stretchr/testify/require" "testing" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" ) @@ -71,13 +71,16 @@ func TestCandidates(t *testing.T) { re.Nil(store) cs = newTestCandidates(1, 3, 5, 7, 6, 2, 4) + minStore := cs.PickTheTopStore(idComparer, true) + re.Equal(uint64(1), minStore.GetID()) + maxStore := cs.PickTheTopStore(idComparer, false) + re.Equal(uint64(7), maxStore.GetID()) + cs.Sort(idComparer) check(re, cs, 1, 2, 3, 4, 5, 6, 7) store = cs.PickFirst() re.Equal(uint64(1), store.GetID()) - cs.Reverse() - check(re, cs, 7, 6, 5, 4, 3, 2, 1) - store = cs.PickFirst() + store = cs.PickTheTopStore(idComparer, false) re.Equal(uint64(7), store.GetID()) cs.Shuffle() cs.Sort(idComparer) @@ -87,8 +90,12 @@ func TestCandidates(t *testing.T) { re.Less(store.GetID(), uint64(8)) cs = newTestCandidates(10, 15, 23, 20, 33, 32, 31) - cs.Sort(idComparer).Reverse().Top(idComparer2) + cs.KeepTheTopStores(idComparer2, false) check(re, cs, 33, 32, 31) + + cs = newTestCandidates(10, 15, 23, 20, 33, 32, 31) + cs.KeepTheTopStores(idComparer2, true) + check(re, cs, 10, 15) } func newTestCandidates(ids ...uint64) *StoreCandidates { From c628ff94a9a3251f959d069f0d1ebbbec3d6e0a3 Mon Sep 17 00:00:00 2001 From: "Reg [bot]" <86050514+tidb-dashboard-bot@users.noreply.github.com> Date: Wed, 15 Jun 2022 16:32:34 +0800 Subject: [PATCH 49/82] Update TiDB Dashboard to v2022.06.13.1 [master] (#5147) ref tikv/pd#4257 Signed-off-by: tidb-dashboard-bot Co-authored-by: 
tidb-dashboard-bot Co-authored-by: ShuNing Co-authored-by: Ti Chi Robot --- go.mod | 2 +- go.sum | 4 ++-- tests/client/go.sum | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 5eeb2656499..0a662f0f16b 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a github.com/pingcap/log v0.0.0-20210906054005-afc726e70354 github.com/pingcap/sysutil v0.0.0-20211208032423-041a72e5860d - github.com/pingcap/tidb-dashboard v0.0.0-20220518164040-4d621864a9a0 + github.com/pingcap/tidb-dashboard v0.0.0-20220613053259-1b8920062bd3 github.com/prometheus/client_golang v1.1.0 github.com/prometheus/common v0.6.0 github.com/sasha-s/go-deadlock v0.2.0 diff --git a/go.sum b/go.sum index 50997691e2a..ad3415e8116 100644 --- a/go.sum +++ b/go.sum @@ -408,8 +408,8 @@ github.com/pingcap/log v0.0.0-20210906054005-afc726e70354 h1:SvWCbCPh1YeHd9yQLks github.com/pingcap/log v0.0.0-20210906054005-afc726e70354/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v0.0.0-20211208032423-041a72e5860d h1:k3/APKZjXOyJrFy8VyYwRlZhMelpD3qBLJNsw3bPl/g= github.com/pingcap/sysutil v0.0.0-20211208032423-041a72e5860d/go.mod h1:7j18ezaWTao2LHOyMlsc2Dg1vW+mDY9dEbPzVyOlaeM= -github.com/pingcap/tidb-dashboard v0.0.0-20220518164040-4d621864a9a0 h1:SNfoqt/qZ+tSnFcOIn6rvhmH06UGJ137Of+uK9q1oOk= -github.com/pingcap/tidb-dashboard v0.0.0-20220518164040-4d621864a9a0/go.mod h1:Hc2LXf5Vs+KwyegHd6osyZ2+LfaVSfWEwuR86SNg7tk= +github.com/pingcap/tidb-dashboard v0.0.0-20220613053259-1b8920062bd3 h1:chUUmmcfNVtfR1c7/qaoLLA2SgaP79LLVXoXV9F4lP8= +github.com/pingcap/tidb-dashboard v0.0.0-20220613053259-1b8920062bd3/go.mod h1:Hc2LXf5Vs+KwyegHd6osyZ2+LfaVSfWEwuR86SNg7tk= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/tests/client/go.sum b/tests/client/go.sum index 493b1cd39f4..b9d76d704a8 100644 --- a/tests/client/go.sum +++ b/tests/client/go.sum @@ -417,8 +417,8 @@ github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee h1:VO2t6IBpfvW34TdtD/G github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v0.0.0-20211208032423-041a72e5860d h1:k3/APKZjXOyJrFy8VyYwRlZhMelpD3qBLJNsw3bPl/g= github.com/pingcap/sysutil v0.0.0-20211208032423-041a72e5860d/go.mod h1:7j18ezaWTao2LHOyMlsc2Dg1vW+mDY9dEbPzVyOlaeM= -github.com/pingcap/tidb-dashboard v0.0.0-20220518164040-4d621864a9a0 h1:SNfoqt/qZ+tSnFcOIn6rvhmH06UGJ137Of+uK9q1oOk= -github.com/pingcap/tidb-dashboard v0.0.0-20220518164040-4d621864a9a0/go.mod h1:Hc2LXf5Vs+KwyegHd6osyZ2+LfaVSfWEwuR86SNg7tk= +github.com/pingcap/tidb-dashboard v0.0.0-20220613053259-1b8920062bd3 h1:chUUmmcfNVtfR1c7/qaoLLA2SgaP79LLVXoXV9F4lP8= +github.com/pingcap/tidb-dashboard v0.0.0-20220613053259-1b8920062bd3/go.mod h1:Hc2LXf5Vs+KwyegHd6osyZ2+LfaVSfWEwuR86SNg7tk= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= From 0da658cec9659c08e8dfefea54daac3de25c5331 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Wed, 15 Jun 2022 17:22:34 +0800 Subject: [PATCH 50/82] api: add Rate-limit config update API (#4843) ref tikv/pd#4666, ref 
tikv/pd#4839 add Rate-limit config update API Signed-off-by: Cabinfever_B Co-authored-by: Ryan Leung Co-authored-by: Ti Chi Robot --- pkg/jsonutil/jsonutil.go | 49 ++++++++ pkg/jsonutil/jsonutil_test.go | 65 +++++++++++ server/api/config.go | 43 +------ server/api/router.go | 1 + server/api/service_middleware.go | 123 ++++++++++++++++---- server/api/service_middleware_test.go | 160 +++++++++++++++++++++++++- server/server.go | 40 +++++++ 7 files changed, 420 insertions(+), 61 deletions(-) create mode 100644 pkg/jsonutil/jsonutil.go create mode 100644 pkg/jsonutil/jsonutil_test.go diff --git a/pkg/jsonutil/jsonutil.go b/pkg/jsonutil/jsonutil.go new file mode 100644 index 00000000000..c5ae2f378da --- /dev/null +++ b/pkg/jsonutil/jsonutil.go @@ -0,0 +1,49 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jsonutil + +import ( + "bytes" + "encoding/json" + + "github.com/tikv/pd/pkg/reflectutil" +) + +// AddKeyValue is used to add a key value pair into `old` +func AddKeyValue(old interface{}, key string, value interface{}) (updated bool, found bool, err error) { + data, err := json.Marshal(map[string]interface{}{key: value}) + if err != nil { + return false, false, err + } + return MergeJSONObject(old, data) +} + +// MergeJSONObject is used to merge a marshaled json object into v +func MergeJSONObject(v interface{}, data []byte) (updated bool, found bool, err error) { + old, _ := json.Marshal(v) + if err := json.Unmarshal(data, v); err != nil { + return false, false, err + } + new, _ := json.Marshal(v) + if !bytes.Equal(old, new) { + return true, true, nil + } + m := make(map[string]interface{}) + if err := json.Unmarshal(data, &m); err != nil { + return false, false, err + } + found = reflectutil.FindSameFieldByJSON(v, m) + return false, found, nil +} diff --git a/pkg/jsonutil/jsonutil_test.go b/pkg/jsonutil/jsonutil_test.go new file mode 100644 index 00000000000..a046fbaf70a --- /dev/null +++ b/pkg/jsonutil/jsonutil_test.go @@ -0,0 +1,65 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package jsonutil + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type testJSONStructLevel1 struct { + Name string `json:"name"` + Sub1 testJSONStructLevel2 `json:"sub1"` + Sub2 testJSONStructLevel2 `json:"sub2"` +} + +type testJSONStructLevel2 struct { + SubName string `json:"sub-name"` +} + +func TestJSONUtil(t *testing.T) { + t.Parallel() + re := require.New(t) + father := &testJSONStructLevel1{ + Name: "father", + } + son1 := &testJSONStructLevel2{ + SubName: "son1", + } + update, found, err := AddKeyValue(&father, "sub1", &son1) + re.NoError(err) + re.True(update) + re.True(found) + + son2 := &testJSONStructLevel2{ + SubName: "son2", + } + + update, found, err = AddKeyValue(father, "sub2", &son2) + re.NoError(err) + re.True(update) + re.True(found) + + update, found, err = AddKeyValue(father, "sub3", &son2) + re.NoError(err) + re.False(update) + re.False(found) + + update, found, err = AddKeyValue(father, "sub2", &son2) + re.NoError(err) + re.False(update) + re.True(found) +} diff --git a/server/api/config.go b/server/api/config.go index d4d90735289..b33dd5c5a97 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -15,7 +15,6 @@ package api import ( - "bytes" "encoding/json" "fmt" "io" @@ -29,6 +28,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/jsonutil" "github.com/tikv/pd/pkg/logutil" "github.com/tikv/pd/pkg/reflectutil" "github.com/tikv/pd/server" @@ -166,12 +166,7 @@ func (h *confHandler) updateConfig(cfg *config.Config, key string, value interfa } func (h *confHandler) updateSchedule(config *config.Config, key string, value interface{}) error { - data, err := json.Marshal(map[string]interface{}{key: value}) - if err != nil { - return err - } - - updated, found, err := mergeConfig(&config.Schedule, data) + updated, found, err := jsonutil.AddKeyValue(&config.Schedule, key, value) if err != nil { return err } @@ -187,12 +182,7 @@ func (h *confHandler) updateSchedule(config *config.Config, key string, value in } func (h *confHandler) updateReplication(config *config.Config, key string, value interface{}) error { - data, err := json.Marshal(map[string]interface{}{key: value}) - if err != nil { - return err - } - - updated, found, err := mergeConfig(&config.Replication, data) + updated, found, err := jsonutil.AddKeyValue(&config.Replication, key, value) if err != nil { return err } @@ -214,8 +204,7 @@ func (h *confHandler) updateReplicationModeConfig(config *config.Config, key []s if err != nil { return err } - - updated, found, err := mergeConfig(&config.ReplicationMode, data) + updated, found, err := jsonutil.MergeJSONObject(&config.ReplicationMode, data) if err != nil { return err } @@ -231,12 +220,7 @@ func (h *confHandler) updateReplicationModeConfig(config *config.Config, key []s } func (h *confHandler) updatePDServerConfig(config *config.Config, key string, value interface{}) error { - data, err := json.Marshal(map[string]interface{}{key: value}) - if err != nil { - return err - } - - updated, found, err := mergeConfig(&config.PDServerCfg, data) + updated, found, err := jsonutil.AddKeyValue(&config.PDServerCfg, key, value) if err != nil { return err } @@ -288,23 +272,6 @@ func getConfigMap(cfg map[string]interface{}, key []string, value interface{}) m return cfg } -func mergeConfig(v interface{}, data []byte) (updated bool, found bool, err error) { - old, _ := json.Marshal(v) - if err := json.Unmarshal(data, v); err != nil { - return false, false, err - } - new, _ 
:= json.Marshal(v) - if !bytes.Equal(old, new) { - return true, true, nil - } - m := make(map[string]interface{}) - if err := json.Unmarshal(data, &m); err != nil { - return false, false, err - } - found = reflectutil.FindSameFieldByJSON(v, m) - return false, found, nil -} - // @Tags config // @Summary Get schedule config. // @Produce json diff --git a/server/api/router.go b/server/api/router.go index e755341ebef..3e8061fd74e 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -285,6 +285,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { serviceMiddlewareHandler := newServiceMiddlewareHandler(svr, rd) registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.GetServiceMiddlewareConfig, setMethods("GET")) registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.SetServiceMiddlewareConfig, setMethods("POST"), setAuditBackend(localLog)) + registerFunc(apiRouter, "/service-middleware/config/rate-limit", serviceMiddlewareHandler.SetRatelimitConfig, setMethods("POST"), setAuditBackend(localLog)) logHandler := newLogHandler(svr, rd) registerFunc(apiRouter, "/admin/log", logHandler.SetLogLevel, setMethods("POST"), setAuditBackend(localLog)) diff --git a/server/api/service_middleware.go b/server/api/service_middleware.go index 0f41f8ae725..426399a1d6e 100644 --- a/server/api/service_middleware.go +++ b/server/api/service_middleware.go @@ -23,6 +23,9 @@ import ( "strings" "github.com/pingcap/errors" + "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/jsonutil" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/reflectutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" @@ -107,18 +110,13 @@ func (h *serviceMiddlewareHandler) updateServiceMiddlewareConfig(cfg *config.Ser case "audit": return h.updateAudit(cfg, kp[len(kp)-1], value) case "rate-limit": - return h.updateRateLimit(cfg, kp[len(kp)-1], value) + return h.svr.UpdateRateLimit(&cfg.RateLimitConfig, kp[len(kp)-1], value) } return errors.Errorf("config prefix %s not found", kp[0]) } func (h *serviceMiddlewareHandler) updateAudit(config *config.ServiceMiddlewareConfig, key string, value interface{}) error { - data, err := json.Marshal(map[string]interface{}{key: value}) - if err != nil { - return err - } - - updated, found, err := mergeConfig(&config.AuditConfig, data) + updated, found, err := jsonutil.AddKeyValue(&config.AuditConfig, key, value) if err != nil { return err } @@ -133,23 +131,104 @@ func (h *serviceMiddlewareHandler) updateAudit(config *config.ServiceMiddlewareC return err } -func (h *serviceMiddlewareHandler) updateRateLimit(config *config.ServiceMiddlewareConfig, key string, value interface{}) error { - data, err := json.Marshal(map[string]interface{}{key: value}) - if err != nil { - return err +// @Tags service_middleware +// @Summary update ratelimit config +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." 
+// @Failure 500 {string} string "config item not found" +// @Router /service-middleware/config/rate-limit [POST] +func (h *serviceMiddlewareHandler) SetRatelimitConfig(w http.ResponseWriter, r *http.Request) { + var input map[string]interface{} + if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { + return } - - updated, found, err := mergeConfig(&config.RateLimitConfig, data) - if err != nil { - return err + typeStr, ok := input["type"].(string) + if !ok { + h.rd.JSON(w, http.StatusBadRequest, "The type is empty.") + return } - - if !found { - return errors.Errorf("config item %s not found", key) + var serviceLabel string + switch typeStr { + case "label": + serviceLabel, ok = input["label"].(string) + if !ok || len(serviceLabel) == 0 { + h.rd.JSON(w, http.StatusBadRequest, "The label is empty.") + return + } + if len(h.svr.GetServiceLabels(serviceLabel)) == 0 { + h.rd.JSON(w, http.StatusBadRequest, "There is no label matched.") + return + } + case "path": + method, _ := input["method"].(string) + path, ok := input["path"].(string) + if !ok || len(path) == 0 { + h.rd.JSON(w, http.StatusBadRequest, "The path is empty.") + return + } + serviceLabel = h.svr.GetAPIAccessServiceLabel(apiutil.NewAccessPath(path, method)) + if len(serviceLabel) == 0 { + h.rd.JSON(w, http.StatusBadRequest, "There is no label matched.") + return + } + default: + h.rd.JSON(w, http.StatusBadRequest, "The type is invalid.") + return } - - if updated { - err = h.svr.SetRateLimitConfig(config.RateLimitConfig) + if h.svr.IsInRateLimitAllowList(serviceLabel) { + h.rd.JSON(w, http.StatusBadRequest, "This service is in allow list whose config can not be changed.") + return } - return err + cfg := h.svr.GetRateLimitConfig().LimiterConfig[serviceLabel] + // update concurrency limiter + concurrencyUpdatedFlag := "Concurrency limiter is not changed." + concurrencyFloat, okc := input["concurrency"].(float64) + if okc { + cfg.ConcurrencyLimit = uint64(concurrencyFloat) + } + // update qps rate limiter + qpsRateUpdatedFlag := "QPS rate limiter is not changed." + qps, okq := input["qps"].(float64) + if okq { + brust := 0 + if int(qps) > 1 { + brust = int(qps) + } else if qps > 0 { + brust = 1 + } + cfg.QPS = qps + cfg.QPSBurst = brust + } + if !okc && !okq { + h.rd.JSON(w, http.StatusOK, "No changed.") + } else { + status := h.svr.UpdateServiceRateLimiter(serviceLabel, ratelimit.UpdateDimensionConfig(&cfg)) + switch { + case status&ratelimit.QPSChanged != 0: + qpsRateUpdatedFlag = "QPS rate limiter is changed." + case status&ratelimit.QPSDeleted != 0: + qpsRateUpdatedFlag = "QPS rate limiter is deleted." + } + switch { + case status&ratelimit.ConcurrencyChanged != 0: + concurrencyUpdatedFlag = "Concurrency limiter is changed." + case status&ratelimit.ConcurrencyDeleted != 0: + concurrencyUpdatedFlag = "Concurrency limiter is deleted." 
+ } + err := h.svr.UpdateRateLimitConfig("limiter-config", serviceLabel, cfg) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + } else { + result := rateLimitResult{concurrencyUpdatedFlag, qpsRateUpdatedFlag, h.svr.GetServiceMiddlewareConfig().RateLimitConfig.LimiterConfig} + h.rd.JSON(w, http.StatusOK, result) + } + } +} + +type rateLimitResult struct { + ConcurrencyUpdatedFlag string `json:"concurrency"` + QPSRateUpdatedFlag string `json:"qps"` + LimiterConfig map[string]ratelimit.DimensionConfig `json:"limiter-config"` } diff --git a/server/api/service_middleware_test.go b/server/api/service_middleware_test.go index a1d4804650c..6ea0343f53b 100644 --- a/server/api/service_middleware_test.go +++ b/server/api/service_middleware_test.go @@ -57,7 +57,8 @@ func (s *testAuditMiddlewareSuite) TestConfigAuditSwitch(c *C) { c.Assert(sc.EnableAudit, Equals, false) ms := map[string]interface{}{ - "enable-audit": "true", + "enable-audit": "true", + "enable-rate-limit": "true", } postData, err := json.Marshal(ms) c.Assert(err, IsNil) @@ -65,8 +66,10 @@ func (s *testAuditMiddlewareSuite) TestConfigAuditSwitch(c *C) { sc = &config.ServiceMiddlewareConfig{} c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) c.Assert(sc.EnableAudit, Equals, true) + c.Assert(sc.EnableRateLimit, Equals, true) ms = map[string]interface{}{ "audit.enable-audit": "false", + "enable-rate-limit": "false", } postData, err = json.Marshal(ms) c.Assert(err, IsNil) @@ -74,6 +77,7 @@ func (s *testAuditMiddlewareSuite) TestConfigAuditSwitch(c *C) { sc = &config.ServiceMiddlewareConfig{} c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) c.Assert(sc.EnableAudit, Equals, false) + c.Assert(sc.EnableRateLimit, Equals, false) // test empty ms = map[string]interface{}{} @@ -124,6 +128,160 @@ func (s *testRateLimitConfigSuite) TearDownSuite(c *C) { s.cleanup() } +func (s *testRateLimitConfigSuite) TestUpdateRateLimitConfig(c *C) { + urlPrefix := fmt.Sprintf("%s%s/api/v1/service-middleware/config/rate-limit", s.svr.GetAddr(), apiPrefix) + + // test empty type + input := make(map[string]interface{}) + input["type"] = 123 + jsonBody, err := json.Marshal(input) + c.Assert(err, IsNil) + + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The type is empty.\"\n")) + c.Assert(err, IsNil) + // test invalid type + input = make(map[string]interface{}) + input["type"] = "url" + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The type is invalid.\"\n")) + c.Assert(err, IsNil) + + // test empty label + input = make(map[string]interface{}) + input["type"] = "label" + input["label"] = "" + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The label is empty.\"\n")) + c.Assert(err, IsNil) + // test no label matched + input = make(map[string]interface{}) + input["type"] = "label" + input["label"] = "TestLabel" + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"There is no label matched.\"\n")) + c.Assert(err, IsNil) + + // test empty path + input = make(map[string]interface{}) + input["type"] = "path" + input["path"] = "" + jsonBody, err = 
json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The path is empty.\"\n")) + c.Assert(err, IsNil) + + // test path but no label matched + input = make(map[string]interface{}) + input["type"] = "path" + input["path"] = "/pd/api/v1/test" + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"There is no label matched.\"\n")) + c.Assert(err, IsNil) + + // no change + input = make(map[string]interface{}) + input["type"] = "label" + input["label"] = "GetHealthStatus" + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(c), tu.StringEqual(c, "\"No changed.\"\n")) + c.Assert(err, IsNil) + + // change concurrency + input = make(map[string]interface{}) + input["type"] = "path" + input["path"] = "/pd/api/v1/health" + input["method"] = "GET" + input["concurrency"] = 100 + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(c), tu.StringContain(c, "Concurrency limiter is changed.")) + c.Assert(err, IsNil) + input["concurrency"] = 0 + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(c), tu.StringContain(c, "Concurrency limiter is deleted.")) + c.Assert(err, IsNil) + + // change qps + input = make(map[string]interface{}) + input["type"] = "path" + input["path"] = "/pd/api/v1/health" + input["method"] = "GET" + input["qps"] = 100 + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(c), tu.StringContain(c, "QPS rate limiter is changed.")) + c.Assert(err, IsNil) + + input = make(map[string]interface{}) + input["type"] = "path" + input["path"] = "/pd/api/v1/health" + input["method"] = "GET" + input["qps"] = 0.3 + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(c), tu.StringContain(c, "QPS rate limiter is changed.")) + c.Assert(err, IsNil) + c.Assert(s.svr.GetRateLimitConfig().LimiterConfig["GetHealthStatus"].QPSBurst, Equals, 1) + + input["qps"] = -1 + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(c), tu.StringContain(c, "QPS rate limiter is deleted.")) + c.Assert(err, IsNil) + + // change both + input = make(map[string]interface{}) + input["type"] = "path" + input["path"] = "/pd/api/v1/debug/pprof/profile" + input["qps"] = 100 + input["concurrency"] = 100 + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + result := rateLimitResult{} + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(c), tu.StringContain(c, "Concurrency limiter is changed."), + tu.StringContain(c, "QPS rate limiter is changed."), + tu.ExtractJSON(c, &result), + ) + c.Assert(result.LimiterConfig["Profile"].QPS, Equals, 100.) 
+ c.Assert(result.LimiterConfig["Profile"].QPSBurst, Equals, 100) + c.Assert(result.LimiterConfig["Profile"].ConcurrencyLimit, Equals, uint64(100)) + c.Assert(err, IsNil) + + limiter := s.svr.GetServiceRateLimiter() + limiter.Update("SetRatelimitConfig", ratelimit.AddLabelAllowList()) + + // Allow list + input = make(map[string]interface{}) + input["type"] = "label" + input["label"] = "SetRatelimitConfig" + input["qps"] = 100 + input["concurrency"] = 100 + jsonBody, err = json.Marshal(input) + c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusNotOK(c), tu.StringEqual(c, "\"This service is in allow list whose config can not be changed.\"\n")) + c.Assert(err, IsNil) +} + func (s *testRateLimitConfigSuite) TestConfigRateLimitSwitch(c *C) { addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) diff --git a/server/server.go b/server/server.go index c7902692552..871c2e60738 100644 --- a/server/server.go +++ b/server/server.go @@ -44,6 +44,7 @@ import ( "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/grpcutil" + "github.com/tikv/pd/pkg/jsonutil" "github.com/tikv/pd/pkg/logutil" "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/systimemon" @@ -255,6 +256,7 @@ func CreateServer(ctx context.Context, cfg *config.Config, serviceBuilders ...Ha audit.NewLocalLogBackend(true), audit.NewPrometheusHistogramBackend(serviceAuditHistogram, false), } + s.serviceRateLimiter = ratelimit.NewLimiter() s.serviceAuditBackendLabels = make(map[string]*audit.BackendLabels) s.serviceRateLimiter = ratelimit.NewLimiter() s.serviceLabels = make(map[string][]apiutil.AccessPath) @@ -978,6 +980,34 @@ func (s *Server) SetAuditConfig(cfg config.AuditConfig) error { return nil } +// UpdateRateLimitConfig is used to update rate-limit config which will reserve old limiter-config +func (s *Server) UpdateRateLimitConfig(key, label string, value ratelimit.DimensionConfig) error { + cfg := s.GetServiceMiddlewareConfig() + rateLimitCfg := make(map[string]ratelimit.DimensionConfig) + for label, item := range cfg.LimiterConfig { + rateLimitCfg[label] = item + } + rateLimitCfg[label] = value + return s.UpdateRateLimit(&cfg.RateLimitConfig, key, &rateLimitCfg) +} + +// UpdateRateLimit is used to update rate-limit config which will overwrite limiter-config +func (s *Server) UpdateRateLimit(cfg *config.RateLimitConfig, key string, value interface{}) error { + updated, found, err := jsonutil.AddKeyValue(cfg, key, value) + if err != nil { + return err + } + + if !found { + return errors.Errorf("config item %s not found", key) + } + + if updated { + err = s.SetRateLimitConfig(*cfg) + } + return err +} + // GetRateLimitConfig gets the rate limit config information. func (s *Server) GetRateLimitConfig() *config.RateLimitConfig { return s.serviceMiddlewarePersistOptions.GetRateLimitConfig().Clone() @@ -1221,6 +1251,16 @@ func (s *Server) GetServiceRateLimiter() *ratelimit.Limiter { return s.serviceRateLimiter } +// IsInRateLimitAllowList returns whethis given service label is in allow lost +func (s *Server) IsInRateLimitAllowList(serviceLabel string) bool { + return s.serviceRateLimiter.IsInAllowList(serviceLabel) +} + +// UpdateServiceRateLimiter is used to update RateLimiter +func (s *Server) UpdateServiceRateLimiter(serviceLabel string, opts ...ratelimit.Option) ratelimit.UpdateStatus { + return s.serviceRateLimiter.Update(serviceLabel, opts...) +} + // GetClusterStatus gets cluster status. 
func (s *Server) GetClusterStatus() (*cluster.Status, error) { s.cluster.Lock() From ee302fcd827131afb06f5eec9973a779d254a7c8 Mon Sep 17 00:00:00 2001 From: LLThomas Date: Wed, 15 Jun 2022 21:08:33 +0800 Subject: [PATCH 51/82] *: fix some typos (#5165) ref tikv/pd#4820 Fix some typos. Signed-off-by: LLThomas --- server/server.go | 6 +++--- server/storage/hot_region_storage.go | 4 ++-- server/storage/hot_region_storage_test.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/server/server.go b/server/server.go index 871c2e60738..6e900580d66 100644 --- a/server/server.go +++ b/server/server.go @@ -1409,20 +1409,20 @@ func (s *Server) campaignLeader() { go s.member.KeepLeader(ctx) log.Info("campaign pd leader ok", zap.String("campaign-pd-leader-name", s.Name())) - alllocator, err := s.tsoAllocatorManager.GetAllocator(tso.GlobalDCLocation) + allocator, err := s.tsoAllocatorManager.GetAllocator(tso.GlobalDCLocation) if err != nil { log.Error("failed to get the global TSO allocator", errs.ZapError(err)) return } log.Info("initializing the global TSO allocator") - if err := alllocator.Initialize(0); err != nil { + if err := allocator.Initialize(0); err != nil { log.Error("failed to initialize the global TSO allocator", errs.ZapError(err)) return } defer func() { s.tsoAllocatorManager.ResetAllocatorGroup(tso.GlobalDCLocation) failpoint.Inject("updateAfterResetTSO", func() { - if err = alllocator.UpdateTSO(); err != nil { + if err = allocator.UpdateTSO(); err != nil { panic(err) } }) diff --git a/server/storage/hot_region_storage.go b/server/storage/hot_region_storage.go index 162b631ddb6..597a6fc83ce 100644 --- a/server/storage/hot_region_storage.go +++ b/server/storage/hot_region_storage.go @@ -144,13 +144,13 @@ func NewHotRegionsStorage( if err != nil { return nil, err } - hotRegionInfoCtx, hotRegionInfoCancle := context.WithCancel(ctx) + hotRegionInfoCtx, hotRegionInfoCancel := context.WithCancel(ctx) h := HotRegionStorage{ LevelDBKV: levelDB, ekm: ekm, batchHotInfo: make(map[string]*HistoryHotRegion), hotRegionInfoCtx: hotRegionInfoCtx, - hotRegionInfoCancel: hotRegionInfoCancle, + hotRegionInfoCancel: hotRegionInfoCancel, hotRegionStorageHandler: hotRegionStorageHandler, curReservedDays: hotRegionStorageHandler.GetHotRegionsReservedDays(), curInterval: hotRegionStorageHandler.GetHotRegionsWriteInterval(), diff --git a/server/storage/hot_region_storage_test.go b/server/storage/hot_region_storage_test.go index 29dc4140317..aa3b5b974b9 100644 --- a/server/storage/hot_region_storage_test.go +++ b/server/storage/hot_region_storage_test.go @@ -293,7 +293,7 @@ func newTestHotRegionStorage(pullInterval time.Duration, } packHotRegionInfo.pullInterval = pullInterval packHotRegionInfo.reservedDays = reservedDays - // delete data in between today and tomrrow + // delete data in between today and tomorrow hotRegionStorage, err = NewHotRegionsStorage(ctx, writePath, nil, packHotRegionInfo) if err != nil { From 2efa259d42faaf01efe66d3a694eed318c508ef8 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Thu, 16 Jun 2022 15:52:34 +0800 Subject: [PATCH 52/82] metrics: delete the metrics instead of setting them to 0 (#5162) close tikv/pd#5163 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- server/api/config.go | 1 - server/cluster/coordinator.go | 18 +++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/server/api/config.go b/server/api/config.go index b33dd5c5a97..7ed3a9de56c 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -371,7 +371,6 @@ 
func (h *confHandler) SetReplicationConfig(w http.ResponseWriter, r *http.Reques // @Summary Get label property config. // @Produce json // @Success 200 {object} config.LabelPropertyConfig -// @Failure 400 {string} string "The input is invalid." // @Router /config/label-property [get] func (h *confHandler) GetLabelPropertyConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetLabelProperty()) diff --git a/server/cluster/coordinator.go b/server/cluster/coordinator.go index b3f72a3be5f..85c80365c9a 100644 --- a/server/cluster/coordinator.go +++ b/server/cluster/coordinator.go @@ -581,10 +581,10 @@ func collectHotMetrics(cluster *RaftCluster, stores []*core.StoreInfo, typ stati hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_leader").Set(stat.TotalLoads[queryTyp]) hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_leader").Set(float64(stat.Count)) } else { - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_leader").Set(0) - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_leader").Set(0) - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_leader").Set(0) - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_leader").Set(0) + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_leader") } stat, ok = status.AsPeer[storeID] @@ -594,10 +594,10 @@ func collectHotMetrics(cluster *RaftCluster, stores []*core.StoreInfo, typ stati hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_peer").Set(stat.TotalLoads[queryTyp]) hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_peer").Set(float64(stat.Count)) } else { - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_peer").Set(0) - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_peer").Set(0) - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_peer").Set(0) - hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_peer").Set(0) + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_peer") } } } @@ -673,7 +673,7 @@ func (c *coordinator) removeScheduler(name string) error { } s.Stop() - schedulerStatusGauge.WithLabelValues(name, "allow").Set(0) + schedulerStatusGauge.DeleteLabelValues(name, "allow") delete(c.schedulers, name) return nil From e6fd11821f004001bf441fca46219ab1a86c37f3 Mon Sep 17 00:00:00 2001 From: Shirly Date: Thu, 16 Jun 2022 18:20:35 +0800 Subject: [PATCH 53/82] scheduler/balance_leader: fix data race in the function of clone for config (#5157) close tikv/pd#5156 Signed-off-by: shirly Co-authored-by: Ti Chi Robot --- server/schedulers/balance_leader.go | 4 
++- server/schedulers/balance_leader_test.go | 38 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 server/schedulers/balance_leader_test.go diff --git a/server/schedulers/balance_leader.go b/server/schedulers/balance_leader.go index 1b1f1bd8a64..f5c9c667264 100644 --- a/server/schedulers/balance_leader.go +++ b/server/schedulers/balance_leader.go @@ -128,8 +128,10 @@ func (conf *balanceLeaderSchedulerConfig) validate() bool { func (conf *balanceLeaderSchedulerConfig) Clone() *balanceLeaderSchedulerConfig { conf.mu.RLock() defer conf.mu.RUnlock() + ranges := make([]core.KeyRange, len(conf.Ranges)) + copy(ranges, conf.Ranges) return &balanceLeaderSchedulerConfig{ - Ranges: conf.Ranges, + Ranges: ranges, Batch: conf.Batch, } } diff --git a/server/schedulers/balance_leader_test.go b/server/schedulers/balance_leader_test.go new file mode 100644 index 00000000000..a74709de640 --- /dev/null +++ b/server/schedulers/balance_leader_test.go @@ -0,0 +1,38 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package schedulers + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBalanceLeaderSchedulerConfigClone(t *testing.T) { + re := require.New(t) + keyRanges1, _ := getKeyRanges([]string{"a", "b", "c", "d"}) + conf := &balanceLeaderSchedulerConfig{ + Ranges: keyRanges1, + Batch: 10, + } + conf2 := conf.Clone() + re.Equal(conf.Batch, conf2.Batch) + re.Equal(conf.Ranges, conf2.Ranges) + + keyRanges2, _ := getKeyRanges([]string{"e", "f", "g", "h"}) + // update conf2 + conf2.Ranges[1] = keyRanges2[1] + re.NotEqual(conf.Ranges, conf2.Ranges) +} From 9acd56ad305f42a1cb6670a7c04a5095c37e3609 Mon Sep 17 00:00:00 2001 From: LLThomas Date: Fri, 17 Jun 2022 10:54:34 +0800 Subject: [PATCH 54/82] =?UTF-8?q?server/schedulers:=20fix=20potential=20da?= =?UTF-8?q?ta=20race=20in=20the=20function=20of=20clone=20for=20grantHotRe?= =?UTF-8?q?gionS=E2=80=A6=20(#5173)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ref tikv/pd#5170 As the title says. 
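Like the balance_leader change in #5157, the race here comes from Clone() handing out a struct whose slice or map fields still share backing storage with the locked original, so a later in-place update is still visible through the "clone" and, under concurrent access, races with its readers. A minimal sketch of the shape of the fix, using illustrative names rather than the real scheduler config types:

package main

import (
	"fmt"
	"sync"
)

// demoConfig stands in for a scheduler config guarded by an RWMutex; the
// names here are illustrative only.
type demoConfig struct {
	mu       sync.RWMutex
	StoreIDs []uint64
}

// cloneShared copies the struct but keeps pointing at the same backing array,
// so a later in-place write to the original shows through the "clone" and,
// with concurrent readers, is a data race.
func (c *demoConfig) cloneShared() *demoConfig {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return &demoConfig{StoreIDs: c.StoreIDs}
}

// cloneDeep copies the slice contents under the read lock, which is the shape
// of the fix applied here and in #5157.
func (c *demoConfig) cloneDeep() *demoConfig {
	c.mu.RLock()
	defer c.mu.RUnlock()
	ids := make([]uint64, len(c.StoreIDs))
	copy(ids, c.StoreIDs)
	return &demoConfig{StoreIDs: ids}
}

func main() {
	conf := &demoConfig{StoreIDs: []uint64{1, 2, 3}}
	shared, deep := conf.cloneShared(), conf.cloneDeep()

	conf.mu.Lock()
	conf.StoreIDs[0] = 99 // in-place update, as a config setter might do
	conf.mu.Unlock()

	fmt.Println(shared.StoreIDs[0]) // 99: the shared clone was not isolated
	fmt.Println(deep.StoreIDs[0])   // 1: the deep copy is unaffected
}

Under concurrent access the shared variant is what go test -race flags; copying under the read lock keeps Clone() cheap while making the result safe to read without holding the original's lock.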
Signed-off-by: LLThomas --- server/schedulers/grant_hot_region.go | 4 +++- server/schedulers/grant_leader.go | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/server/schedulers/grant_hot_region.go b/server/schedulers/grant_hot_region.go index e001371899e..4decd1b1340 100644 --- a/server/schedulers/grant_hot_region.go +++ b/server/schedulers/grant_hot_region.go @@ -123,8 +123,10 @@ func (conf *grantHotRegionSchedulerConfig) SetStoreLeaderID(id uint64) { func (conf *grantHotRegionSchedulerConfig) Clone() *grantHotRegionSchedulerConfig { conf.mu.RLock() defer conf.mu.RUnlock() + newStoreIDs := make([]uint64, len(conf.StoreIDs)) + copy(newStoreIDs, conf.StoreIDs) return &grantHotRegionSchedulerConfig{ - StoreIDs: conf.StoreIDs, + StoreIDs: newStoreIDs, StoreLeaderID: conf.StoreLeaderID, } } diff --git a/server/schedulers/grant_leader.go b/server/schedulers/grant_leader.go index 40dd5c8a073..845f2b100c7 100644 --- a/server/schedulers/grant_leader.go +++ b/server/schedulers/grant_leader.go @@ -102,8 +102,12 @@ func (conf *grantLeaderSchedulerConfig) BuildWithArgs(args []string) error { func (conf *grantLeaderSchedulerConfig) Clone() *grantLeaderSchedulerConfig { conf.mu.RLock() defer conf.mu.RUnlock() + newStoreIDWithRanges := make(map[uint64][]core.KeyRange) + for k, v := range conf.StoreIDWithRanges { + newStoreIDWithRanges[k] = v + } return &grantLeaderSchedulerConfig{ - StoreIDWithRanges: conf.StoreIDWithRanges, + StoreIDWithRanges: newStoreIDWithRanges, } } From cb23d6c48cf42d6543b220991b1a88a9ee64c580 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 17 Jun 2022 13:30:35 +0800 Subject: [PATCH 55/82] tests: testify the api and storage tests (#5166) ref tikv/pd#4813 Testify the api and storage tests. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- tests/pdctl/helper.go | 34 -- tests/server/api/api_test.go | 440 +++++++++--------- .../server/storage/hot_region_storage_test.go | 183 ++++---- 3 files changed, 307 insertions(+), 350 deletions(-) diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go index c5aaf948aa2..775f0b40f15 100644 --- a/tests/pdctl/helper.go +++ b/tests/pdctl/helper.go @@ -21,7 +21,6 @@ import ( "sort" "github.com/gogo/protobuf/proto" - "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/spf13/cobra" @@ -105,20 +104,6 @@ func MustPutStore(re *require.Assertions, svr *server.Server, store *metapb.Stor re.NoError(err) } -// MustPutStoreWithCheck is a temporary function for test purpose. -func MustPutStoreWithCheck(c *check.C, svr *server.Server, store *metapb.Store) { - store.Address = fmt.Sprintf("tikv%d", store.GetId()) - if len(store.Version) == 0 { - store.Version = versioninfo.MinSupportedVersion(versioninfo.Version2_0).String() - } - grpcServer := &server.GrpcServer{Server: svr} - _, err := grpcServer.PutStore(context.Background(), &pdpb.PutStoreRequest{ - Header: &pdpb.RequestHeader{ClusterId: svr.ClusterID()}, - Store: store, - }) - c.Assert(err, check.IsNil) -} - // MustPutRegion is used for test purpose. func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { leader := &metapb.Peer{ @@ -138,25 +123,6 @@ func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID, return r } -// MustPutRegionWithCheck is a temporary function for test purpose. 
-func MustPutRegionWithCheck(c *check.C, cluster *tests.TestCluster, regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { - leader := &metapb.Peer{ - Id: regionID, - StoreId: storeID, - } - metaRegion := &metapb.Region{ - Id: regionID, - StartKey: start, - EndKey: end, - Peers: []*metapb.Peer{leader}, - RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, - } - r := core.NewRegionInfo(metaRegion, leader, opts...) - err := cluster.HandleRegionHeartbeat(r) - c.Assert(err, check.IsNil) - return r -} - func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { checker := assertutil.NewChecker(func() { re.FailNow("should be nil") diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index e462adae2ba..49f8026e2af 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -27,11 +27,12 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil/serverapi" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/pkg/typeutil" @@ -51,55 +52,47 @@ var dialClient = &http.Client{ }, } -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&serverTestSuite{}) - -type serverTestSuite struct{} - -func (s *serverTestSuite) TestReconnect(c *C) { +func TestReconnect(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, serverName string) { conf.TickInterval = typeutil.Duration{Duration: 50 * time.Millisecond} conf.ElectionInterval = typeutil.Duration{Duration: 250 * time.Millisecond} }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(cluster.RunInitialServers()) // Make connections to followers. // Make sure they proxy requests to the leader. leader := cluster.WaitLeader() for name, s := range cluster.GetServers() { if name != leader { - res, e := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") - c.Assert(e, IsNil) + res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + re.NoError(err) res.Body.Close() - c.Assert(res.StatusCode, Equals, http.StatusOK) + re.Equal(http.StatusOK, res.StatusCode) } } // Close the leader and wait for a new one. err = cluster.GetServer(leader).Stop() - c.Assert(err, IsNil) + re.NoError(err) newLeader := cluster.WaitLeader() - c.Assert(newLeader, Not(HasLen), 0) + re.NotEmpty(newLeader) // Make sure they proxy requests to the new leader. for name, s := range cluster.GetServers() { if name != leader { - testutil.WaitUntil(c, func() bool { - res, e := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") - c.Assert(e, IsNil) + testutil.Eventually(re, func() bool { + res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + re.NoError(err) defer res.Body.Close() return res.StatusCode == http.StatusOK }) @@ -107,15 +100,14 @@ func (s *serverTestSuite) TestReconnect(c *C) { } // Close the new leader and then we have only one node. 
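// The api_test.go changes above and below follow the same check-to-testify
// mapping used throughout patches 55 and 56; roughly:
//   c.Assert(err, IsNil)           -> re.NoError(err)
//   c.Assert(got, Equals, want)    -> re.Equal(want, got)   (expected value moves first)
//   c.Assert(v, NotNil) / IsTrue   -> re.NotNil(v) / re.True(v)
//   testutil.WaitUntil(c, cond)    -> testutil.Eventually(re, cond)
//   gocheck Suite + SetUpSuite     -> testify suite.Suite + SetupSuite, run via suite.Run(t, ...)
// where re is a *require.Assertions (or the suite's embedded assertions).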
- err = cluster.GetServer(newLeader).Stop() - c.Assert(err, IsNil) + re.NoError(cluster.GetServer(newLeader).Stop()) // Request will fail with no leader. for name, s := range cluster.GetServers() { if name != leader && name != newLeader { - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") - c.Assert(err, IsNil) + re.NoError(err) defer res.Body.Close() return res.StatusCode == http.StatusServiceUnavailable }) @@ -123,77 +115,80 @@ func (s *serverTestSuite) TestReconnect(c *C) { } } -var _ = Suite(&testMiddlewareSuite{}) - -type testMiddlewareSuite struct { +type middlewareTestSuite struct { + suite.Suite cleanup func() cluster *tests.TestCluster } -func (s *testMiddlewareSuite) SetUpSuite(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/api/enableFailpointAPI", "return(true)"), IsNil) +func TestMiddlewareTestSuite(t *testing.T) { + suite.Run(t, new(middlewareTestSuite)) +} + +func (suite *middlewareTestSuite) SetupSuite() { + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/api/enableFailpointAPI", "return(true)")) ctx, cancel := context.WithCancel(context.Background()) - s.cleanup = cancel + suite.cleanup = cancel cluster, err := tests.NewTestCluster(ctx, 3) - c.Assert(err, IsNil) - c.Assert(cluster.RunInitialServers(), IsNil) - c.Assert(cluster.WaitLeader(), Not(HasLen), 0) - s.cluster = cluster + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) + suite.NotEmpty(cluster.WaitLeader()) + suite.cluster = cluster } -func (s *testMiddlewareSuite) TearDownSuite(c *C) { - c.Assert(failpoint.Disable("github.com/tikv/pd/server/api/enableFailpointAPI"), IsNil) - s.cleanup() - s.cluster.Destroy() +func (suite *middlewareTestSuite) TearDownSuite() { + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/api/enableFailpointAPI")) + suite.cleanup() + suite.cluster.Destroy() } -func (s *testMiddlewareSuite) TestRequestInfoMiddleware(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/api/addRequestInfoMiddleware", "return(true)"), IsNil) - leader := s.cluster.GetServer(s.cluster.GetLeader()) +func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/api/addRequestInfoMiddleware", "return(true)")) + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) input := map[string]interface{}{ "enable-audit": "true", } data, err := json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err := dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, true) + suite.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) labels := make(map[string]interface{}) labels["testkey"] = "testvalue" data, _ = json.Marshal(labels) resp, err = dialClient.Post(leader.GetAddr()+"/pd/api/v1/debug/pprof/profile?force=true", "application/json", bytes.NewBuffer(data)) - c.Assert(err, IsNil) + suite.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) - c.Assert(resp.Header.Get("service-label"), Equals, "Profile") - c.Assert(resp.Header.Get("url-param"), Equals, "{\"force\":[\"true\"]}") - 
c.Assert(resp.Header.Get("body-param"), Equals, "{\"testkey\":\"testvalue\"}") - c.Assert(resp.Header.Get("method"), Equals, "HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile") - c.Assert(resp.Header.Get("component"), Equals, "anonymous") - c.Assert(resp.Header.Get("ip"), Equals, "127.0.0.1") + suite.Equal("Profile", resp.Header.Get("service-label")) + suite.Equal("{\"force\":[\"true\"]}", resp.Header.Get("url-param")) + suite.Equal("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param")) + suite.Equal("HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile", resp.Header.Get("method")) + suite.Equal("anonymous", resp.Header.Get("component")) + suite.Equal("127.0.0.1", resp.Header.Get("ip")) input = map[string]interface{}{ "enable-audit": "false", } data, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err = dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, false) + suite.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) - header := mustRequestSuccess(c, leader.GetServer()) - c.Assert(header.Get("service-label"), Equals, "") + header := mustRequestSuccess(suite.Require(), leader.GetServer()) + suite.Equal("", header.Get("service-label")) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/api/addRequestInfoMiddleware"), IsNil) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/api/addRequestInfoMiddleware")) } func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { @@ -248,96 +243,96 @@ func doTestRequest(srv *tests.TestServer) { resp.Body.Close() } -func (s *testMiddlewareSuite) TestAuditPrometheusBackend(c *C) { - leader := s.cluster.GetServer(s.cluster.GetLeader()) +func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) input := map[string]interface{}{ "enable-audit": "true", } data, err := json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err := dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, true) + suite.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) timeUnix := time.Now().Unix() - 20 req, _ = http.NewRequest("GET", fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), nil) resp, err = dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() - c.Assert(err, IsNil) + suite.NoError(err) req, _ = http.NewRequest("GET", leader.GetAddr()+"/metrics", nil) resp, err = dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - c.Assert(strings.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",method=\"HTTP\",service=\"GetTrend\"} 1"), Equals, true) + suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",method=\"HTTP\",service=\"GetTrend\"} 1") // resign to test persist config oldLeaderName := leader.GetServer().Name() leader.GetServer().GetMember().ResignEtcdLeader(leader.GetServer().Context(), oldLeaderName, 
"") - mustWaitLeader(c, s.cluster.GetServers()) - leader = s.cluster.GetServer(s.cluster.GetLeader()) + suite.mustWaitLeader() + leader = suite.cluster.GetServer(suite.cluster.GetLeader()) timeUnix = time.Now().Unix() - 20 req, _ = http.NewRequest("GET", fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), nil) resp, err = dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() - c.Assert(err, IsNil) + suite.NoError(err) req, _ = http.NewRequest("GET", leader.GetAddr()+"/metrics", nil) resp, err = dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) defer resp.Body.Close() content, _ = io.ReadAll(resp.Body) output = string(content) - c.Assert(strings.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",method=\"HTTP\",service=\"GetTrend\"} 2"), Equals, true) + suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",method=\"HTTP\",service=\"GetTrend\"} 2") input = map[string]interface{}{ "enable-audit": "false", } data, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err = dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, false) + suite.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) } -func (s *testMiddlewareSuite) TestAuditLocalLogBackend(c *C) { +func (suite *middlewareTestSuite) TestAuditLocalLogBackend() { tempStdoutFile, _ := os.CreateTemp("/tmp", "pd_tests") cfg := &log.Config{} cfg.File.Filename = tempStdoutFile.Name() cfg.Level = "info" lg, p, _ := log.InitLogger(cfg) log.ReplaceGlobals(lg, p) - leader := s.cluster.GetServer(s.cluster.GetLeader()) + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) input := map[string]interface{}{ "enable-audit": "true", } data, err := json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err := dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, true) + suite.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) resp, err = dialClient.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() b, _ := os.ReadFile(tempStdoutFile.Name()) - c.Assert(strings.Contains(string(b), "Audit Log"), Equals, true) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.Contains(string(b), "Audit Log") + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) os.Remove(tempStdoutFile.Name()) } @@ -386,50 +381,54 @@ func BenchmarkDoRequestWithoutLocalLogAudit(b *testing.B) { cluster.Destroy() } -var _ = Suite(&testRedirectorSuite{}) - -type testRedirectorSuite struct { +type redirectorTestSuite struct { + suite.Suite cleanup func() cluster *tests.TestCluster } -func (s *testRedirectorSuite) SetUpSuite(c *C) { +func TestRedirectorTestSuite(t *testing.T) { + suite.Run(t, new(redirectorTestSuite)) +} + +func (suite *redirectorTestSuite) SetupSuite() { ctx, cancel := 
context.WithCancel(context.Background()) - s.cleanup = cancel + suite.cleanup = cancel cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, serverName string) { conf.TickInterval = typeutil.Duration{Duration: 50 * time.Millisecond} conf.ElectionInterval = typeutil.Duration{Duration: 250 * time.Millisecond} }) - c.Assert(err, IsNil) - c.Assert(cluster.RunInitialServers(), IsNil) - c.Assert(cluster.WaitLeader(), Not(HasLen), 0) - s.cluster = cluster + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) + suite.NotEmpty(cluster.WaitLeader(), 0) + suite.cluster = cluster } -func (s *testRedirectorSuite) TearDownSuite(c *C) { - s.cleanup() - s.cluster.Destroy() +func (suite *redirectorTestSuite) TearDownSuite() { + suite.cleanup() + suite.cluster.Destroy() } -func (s *testRedirectorSuite) TestRedirect(c *C) { - leader := s.cluster.GetServer(s.cluster.GetLeader()) - c.Assert(leader, NotNil) - header := mustRequestSuccess(c, leader.GetServer()) +func (suite *redirectorTestSuite) TestRedirect() { + re := suite.Require() + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + suite.NotNil(leader) + header := mustRequestSuccess(re, leader.GetServer()) header.Del("Date") - for _, svr := range s.cluster.GetServers() { + for _, svr := range suite.cluster.GetServers() { if svr != leader { - h := mustRequestSuccess(c, svr.GetServer()) + h := mustRequestSuccess(re, svr.GetServer()) h.Del("Date") - c.Assert(header, DeepEquals, h) + suite.Equal(h, header) } } } -func (s *testRedirectorSuite) TestAllowFollowerHandle(c *C) { +func (suite *redirectorTestSuite) TestAllowFollowerHandle() { // Find a follower. var follower *server.Server - leader := s.cluster.GetServer(s.cluster.GetLeader()) - for _, svr := range s.cluster.GetServers() { + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + for _, svr := range suite.cluster.GetServers() { if svr != leader { follower = svr.GetServer() break @@ -438,22 +437,22 @@ func (s *testRedirectorSuite) TestAllowFollowerHandle(c *C) { addr := follower.GetAddr() + "/pd/api/v1/version" request, err := http.NewRequest(http.MethodGet, addr, nil) - c.Assert(err, IsNil) + suite.NoError(err) request.Header.Add(serverapi.AllowFollowerHandle, "true") resp, err := dialClient.Do(request) - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(serverapi.RedirectorHeader), Equals, "") + suite.NoError(err) + suite.Equal("", resp.Header.Get(serverapi.RedirectorHeader)) defer resp.Body.Close() - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) _, err = io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) } -func (s *testRedirectorSuite) TestNotLeader(c *C) { +func (suite *redirectorTestSuite) TestNotLeader() { // Find a follower. var follower *server.Server - leader := s.cluster.GetServer(s.cluster.GetLeader()) - for _, svr := range s.cluster.GetServers() { + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + for _, svr := range suite.cluster.GetServers() { if svr != leader { follower = svr.GetServer() break @@ -463,55 +462,52 @@ func (s *testRedirectorSuite) TestNotLeader(c *C) { addr := follower.GetAddr() + "/pd/api/v1/version" // Request to follower without redirectorHeader is OK. 
request, err := http.NewRequest(http.MethodGet, addr, nil) - c.Assert(err, IsNil) + suite.NoError(err) resp, err := dialClient.Do(request) - c.Assert(err, IsNil) + suite.NoError(err) defer resp.Body.Close() - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) _, err = io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) // Request to follower with redirectorHeader will fail. request.RequestURI = "" request.Header.Set(serverapi.RedirectorHeader, "pd") resp1, err := dialClient.Do(request) - c.Assert(err, IsNil) + suite.NoError(err) defer resp1.Body.Close() - c.Assert(resp1.StatusCode, Not(Equals), http.StatusOK) + suite.NotEqual(http.StatusOK, resp1.StatusCode) _, err = io.ReadAll(resp1.Body) - c.Assert(err, IsNil) + suite.NoError(err) } -func mustRequestSuccess(c *C, s *server.Server) http.Header { +func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header { resp, err := dialClient.Get(s.GetAddr() + "/pd/api/v1/version") - c.Assert(err, IsNil) + re.NoError(err) defer resp.Body.Close() _, err = io.ReadAll(resp.Body) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + re.NoError(err) + re.Equal(http.StatusOK, resp.StatusCode) return resp.Header } -var _ = Suite(&testProgressSuite{}) - -type testProgressSuite struct{} - -func (s *testProgressSuite) TestRemovingProgress(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`), IsNil) +func TestRemovingProgress(t *testing.T) { + re := require.New(t) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.Replication.MaxReplicas = 1 }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leader.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leader.GetAddr()) clusterID := leader.GetClusterID() req := &pdpb.BootstrapRequest{ Header: testutil.NewRequestHeader(clusterID), @@ -519,7 +515,7 @@ func (s *testProgressSuite) TestRemovingProgress(c *C) { Region: &metapb.Region{Id: 2, Peers: []*metapb.Peer{{Id: 3, StoreId: 1, Role: metapb.PeerRole_Voter}}}, } _, err = grpcPDClient.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) + re.NoError(err) stores := []*metapb.Store{ { Id: 1, @@ -542,92 +538,93 @@ func (s *testProgressSuite) TestRemovingProgress(c *C) { } for _, store := range stores { - pdctl.MustPutStoreWithCheck(c, leader.GetServer(), store) + pdctl.MustPutStore(re, leader.GetServer(), store) } - pdctl.MustPutRegionWithCheck(c, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) - pdctl.MustPutRegionWithCheck(c, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(30)) - pdctl.MustPutRegionWithCheck(c, cluster, 1002, 1, []byte("e"), []byte("f"), core.SetApproximateSize(50)) - pdctl.MustPutRegionWithCheck(c, cluster, 1003, 2, []byte("g"), []byte("h"), core.SetApproximateSize(40)) + pdctl.MustPutRegion(re, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) + pdctl.MustPutRegion(re, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(30)) + pdctl.MustPutRegion(re, cluster, 
1002, 1, []byte("e"), []byte("f"), core.SetApproximateSize(50)) + pdctl.MustPutRegion(re, cluster, 1003, 2, []byte("g"), []byte("h"), core.SetApproximateSize(40)) // no store removing - output := sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusNotFound) - c.Assert(strings.Contains((string(output)), "no progress found for the action"), IsTrue) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?id=2", http.MethodGet, http.StatusNotFound) - c.Assert(strings.Contains((string(output)), "no progress found for the given store ID"), IsTrue) + output := sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusNotFound) + re.Contains((string(output)), "no progress found for the action") + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?id=2", http.MethodGet, http.StatusNotFound) + re.Contains((string(output)), "no progress found for the given store ID") // remove store 1 and store 2 - _ = sendRequest(c, leader.GetAddr()+"/pd/api/v1/store/1", http.MethodDelete, http.StatusOK) - _ = sendRequest(c, leader.GetAddr()+"/pd/api/v1/store/2", http.MethodDelete, http.StatusOK) + _ = sendRequest(re, leader.GetAddr()+"/pd/api/v1/store/1", http.MethodDelete, http.StatusOK) + _ = sendRequest(re, leader.GetAddr()+"/pd/api/v1/store/2", http.MethodDelete, http.StatusOK) // size is not changed. - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusOK) + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusOK) var p api.Progress - c.Assert(json.Unmarshal(output, &p), IsNil) - c.Assert(p.Action, Equals, "removing") - c.Assert(p.Progress, Equals, 0.0) - c.Assert(p.CurrentSpeed, Equals, 0.0) - c.Assert(p.LeftSeconds, Equals, math.MaxFloat64) + re.NoError(json.Unmarshal(output, &p)) + re.Equal("removing", p.Action) + re.Equal(0.0, p.Progress) + re.Equal(0.0, p.CurrentSpeed) + re.Equal(math.MaxFloat64, p.LeftSeconds) // update size - pdctl.MustPutRegionWithCheck(c, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(20)) - pdctl.MustPutRegionWithCheck(c, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(10)) + pdctl.MustPutRegion(re, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(20)) + pdctl.MustPutRegion(re, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(10)) // is not prepared time.Sleep(2 * time.Second) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusOK) - c.Assert(json.Unmarshal(output, &p), IsNil) - c.Assert(p.Action, Equals, "removing") - c.Assert(p.Progress, Equals, 0.0) - c.Assert(p.CurrentSpeed, Equals, 0.0) - c.Assert(p.LeftSeconds, Equals, math.MaxFloat64) + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusOK) + re.NoError(json.Unmarshal(output, &p)) + re.Equal("removing", p.Action) + re.Equal(0.0, p.Progress) + re.Equal(0.0, p.CurrentSpeed) + re.Equal(math.MaxFloat64, p.LeftSeconds) leader.GetRaftCluster().SetPrepared() time.Sleep(2 * time.Second) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusOK) - c.Assert(json.Unmarshal(output, &p), IsNil) - c.Assert(p.Action, Equals, "removing") + output = sendRequest(re, 
leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusOK) + re.NoError(json.Unmarshal(output, &p)) + re.Equal("removing", p.Action) // store 1: (60-20)/(60+50) ~= 0.36 // store 2: (30-10)/(30+40) ~= 0.28 // average progress ~= (0.36+0.28)/2 = 0.32 - c.Assert(fmt.Sprintf("%.2f", p.Progress), Equals, "0.32") + re.Equal("0.32", fmt.Sprintf("%.2f", p.Progress)) // store 1: 40/10s = 4 // store 2: 20/10s = 2 // average speed = (2+4)/2 = 33 - c.Assert(p.CurrentSpeed, Equals, 3.0) + re.Equal(3.0, p.CurrentSpeed) // store 1: (20+50)/4 = 17.5s // store 2: (10+40)/2 = 25s // average time = (17.5+25)/2 = 21.25s - c.Assert(p.LeftSeconds, Equals, 21.25) + re.Equal(21.25, p.LeftSeconds) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?id=2", http.MethodGet, http.StatusOK) - c.Assert(json.Unmarshal(output, &p), IsNil) - c.Assert(p.Action, Equals, "removing") + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?id=2", http.MethodGet, http.StatusOK) + re.NoError(json.Unmarshal(output, &p)) + re.Equal("removing", p.Action) // store 2: (30-10)/(30+40) ~= 0.285 - c.Assert(fmt.Sprintf("%.2f", p.Progress), Equals, "0.29") + re.Equal("0.29", fmt.Sprintf("%.2f", p.Progress)) // store 2: 20/10s = 2 - c.Assert(p.CurrentSpeed, Equals, 2.0) + re.Equal(2.0, p.CurrentSpeed) // store 2: (10+40)/2 = 25s - c.Assert(p.LeftSeconds, Equals, 25.0) + re.Equal(25.0, p.LeftSeconds) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs")) } -func (s *testProgressSuite) TestPreparingProgress(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`), IsNil) +func TestPreparingProgress(t *testing.T) { + re := require.New(t) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.Replication.MaxReplicas = 1 }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leader.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leader.GetAddr()) clusterID := leader.GetClusterID() req := &pdpb.BootstrapRequest{ Header: testutil.NewRequestHeader(clusterID), @@ -635,7 +632,7 @@ func (s *testProgressSuite) TestPreparingProgress(c *C) { Region: &metapb.Region{Id: 2, Peers: []*metapb.Peer{{Id: 3, StoreId: 1, Role: metapb.PeerRole_Voter}}}, } _, err = grpcPDClient.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) + re.NoError(err) stores := []*metapb.Store{ { Id: 1, @@ -675,80 +672,79 @@ func (s *testProgressSuite) TestPreparingProgress(c *C) { } for _, store := range stores { - pdctl.MustPutStoreWithCheck(c, leader.GetServer(), store) + pdctl.MustPutStore(re, leader.GetServer(), store) } for i := 0; i < 100; i++ { - pdctl.MustPutRegionWithCheck(c, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("p%d", i)), []byte(fmt.Sprintf("%d", i+1)), core.SetApproximateSize(10)) + pdctl.MustPutRegion(re, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("p%d", i)), []byte(fmt.Sprintf("%d", i+1)), core.SetApproximateSize(10)) } // no 
store preparing - output := sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) - c.Assert(strings.Contains((string(output)), "no progress found for the action"), IsTrue) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?id=4", http.MethodGet, http.StatusNotFound) - c.Assert(strings.Contains((string(output)), "no progress found for the given store ID"), IsTrue) + output := sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) + re.Contains((string(output)), "no progress found for the action") + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?id=4", http.MethodGet, http.StatusNotFound) + re.Contains((string(output)), "no progress found for the given store ID") // is not prepared time.Sleep(2 * time.Second) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) - c.Assert(strings.Contains((string(output)), "no progress found for the action"), IsTrue) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?id=4", http.MethodGet, http.StatusNotFound) - c.Assert(strings.Contains((string(output)), "no progress found for the given store ID"), IsTrue) + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) + re.Contains((string(output)), "no progress found for the action") + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?id=4", http.MethodGet, http.StatusNotFound) + re.Contains((string(output)), "no progress found for the given store ID") // size is not changed. leader.GetRaftCluster().SetPrepared() time.Sleep(2 * time.Second) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusOK) + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusOK) var p api.Progress - c.Assert(json.Unmarshal(output, &p), IsNil) - c.Assert(p.Action, Equals, "preparing") - c.Assert(p.Progress, Equals, 0.0) - c.Assert(p.CurrentSpeed, Equals, 0.0) - c.Assert(p.LeftSeconds, Equals, math.MaxFloat64) + re.NoError(json.Unmarshal(output, &p)) + re.Equal("preparing", p.Action) + re.Equal(0.0, p.Progress) + re.Equal(0.0, p.CurrentSpeed) + re.Equal(math.MaxFloat64, p.LeftSeconds) // update size - pdctl.MustPutRegionWithCheck(c, cluster, 1000, 4, []byte(fmt.Sprintf("%d", 1000)), []byte(fmt.Sprintf("%d", 1001)), core.SetApproximateSize(10)) - pdctl.MustPutRegionWithCheck(c, cluster, 1001, 5, []byte(fmt.Sprintf("%d", 1001)), []byte(fmt.Sprintf("%d", 1002)), core.SetApproximateSize(40)) + pdctl.MustPutRegion(re, cluster, 1000, 4, []byte(fmt.Sprintf("%d", 1000)), []byte(fmt.Sprintf("%d", 1001)), core.SetApproximateSize(10)) + pdctl.MustPutRegion(re, cluster, 1001, 5, []byte(fmt.Sprintf("%d", 1001)), []byte(fmt.Sprintf("%d", 1002)), core.SetApproximateSize(40)) time.Sleep(2 * time.Second) - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusOK) - c.Assert(json.Unmarshal(output, &p), IsNil) - c.Assert(p.Action, Equals, "preparing") + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusOK) + re.NoError(json.Unmarshal(output, &p)) + re.Equal("preparing", p.Action) // store 4: 10/(210*0.9) ~= 0.05 // store 5: 40/(210*0.9) ~= 0.21 // average 
progress ~= (0.05+0.21)/2 = 0.13 - c.Assert(fmt.Sprintf("%.2f", p.Progress), Equals, "0.13") + re.Equal("0.13", fmt.Sprintf("%.2f", p.Progress)) // store 4: 10/10s = 1 // store 5: 40/10s = 4 // average speed = (1+4)/2 = 2.5 - c.Assert(p.CurrentSpeed, Equals, 2.5) + re.Equal(2.5, p.CurrentSpeed) // store 4: 179/1 ~= 179 // store 5: 149/4 ~= 37.25 // average time ~= (179+37.25)/2 = 108.125 - c.Assert(p.LeftSeconds, Equals, 108.125) - - output = sendRequest(c, leader.GetAddr()+"/pd/api/v1/stores/progress?id=4", http.MethodGet, http.StatusOK) - c.Assert(json.Unmarshal(output, &p), IsNil) - c.Assert(p.Action, Equals, "preparing") - c.Assert(fmt.Sprintf("%.2f", p.Progress), Equals, "0.05") - c.Assert(p.CurrentSpeed, Equals, 1.0) - c.Assert(p.LeftSeconds, Equals, 179.0) - - c.Assert(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs"), IsNil) + re.Equal(108.125, p.LeftSeconds) + + output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?id=4", http.MethodGet, http.StatusOK) + re.NoError(json.Unmarshal(output, &p)) + re.Equal("preparing", p.Action) + re.Equal("0.05", fmt.Sprintf("%.2f", p.Progress)) + re.Equal(1.0, p.CurrentSpeed) + re.Equal(179.0, p.LeftSeconds) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs")) } -func sendRequest(c *C, url string, method string, statusCode int) []byte { +func sendRequest(re *require.Assertions, url string, method string, statusCode int) []byte { req, _ := http.NewRequest(method, url, nil) resp, err := dialClient.Do(req) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, statusCode) + re.NoError(err) + re.Equal(statusCode, resp.StatusCode) output, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + re.NoError(err) resp.Body.Close() return output } -func mustWaitLeader(c *C, svrs map[string]*tests.TestServer) *server.Server { +func (suite *middlewareTestSuite) mustWaitLeader() *server.Server { var leader *server.Server - testutil.WaitUntil(c, func() bool { - for _, s := range svrs { + testutil.Eventually(suite.Require(), func() bool { + for _, s := range suite.cluster.GetServers() { if !s.GetServer().IsClosed() && s.GetServer().GetMember().IsLeader() { leader = s.GetServer() return true diff --git a/tests/server/storage/hot_region_storage_test.go b/tests/server/storage/hot_region_storage_test.go index 662f128dd1b..9432ceb0c77 100644 --- a/tests/server/storage/hot_region_storage_test.go +++ b/tests/server/storage/hot_region_storage_test.go @@ -19,9 +19,9 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/statistics" @@ -30,15 +30,8 @@ import ( "github.com/tikv/pd/tests/pdctl" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&hotRegionHistorySuite{}) - -type hotRegionHistorySuite struct{} - -func (s *hotRegionHistorySuite) TestHotRegionStorage(c *C) { +func TestHotRegionStorage(t *testing.T) { + re := require.New(t) statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -49,9 +42,9 @@ func (s *hotRegionHistorySuite) TestHotRegionStorage(c *C) { cfg.Schedule.HotRegionsReservedDays = 1 }, ) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() stores := []*metapb.Store{ { @@ -67,16 +60,16 @@ func (s *hotRegionHistorySuite) TestHotRegionStorage(c *C) { } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStoreWithCheck(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) } defer cluster.Destroy() startTime := time.Now().UnixNano() / int64(time.Millisecond) - pdctl.MustPutRegionWithCheck(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegionWithCheck(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) - pdctl.MustPutRegionWithCheck(c, cluster, 3, 1, []byte("e"), []byte("f")) - pdctl.MustPutRegionWithCheck(c, cluster, 4, 2, []byte("g"), []byte("h")) + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f")) + pdctl.MustPutRegion(re, cluster, 4, 2, []byte("g"), []byte("h")) storeStats := []*pdpb.StoreStats{ { StoreId: 1, @@ -108,39 +101,40 @@ func (s *hotRegionHistorySuite) TestHotRegionStorage(c *C) { hotRegionStorage := leaderServer.GetServer().GetHistoryHotRegionStorage() iter := hotRegionStorage.NewIterator([]string{storage.WriteType.String()}, startTime, endTime) next, err := iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(1)) - c.Assert(next.StoreID, Equals, uint64(1)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(1), next.RegionID) + re.Equal(uint64(1), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(2)) - c.Assert(next.StoreID, Equals, uint64(2)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(2), next.RegionID) + re.Equal(uint64(2), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(next, IsNil) - c.Assert(err, IsNil) + re.NoError(err) + 
re.Nil(next) iter = hotRegionStorage.NewIterator([]string{storage.ReadType.String()}, startTime, endTime) next, err = iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(3)) - c.Assert(next.StoreID, Equals, uint64(1)) - c.Assert(next.HotRegionType, Equals, storage.ReadType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(3), next.RegionID) + re.Equal(uint64(1), next.StoreID) + re.Equal(storage.ReadType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(4)) - c.Assert(next.StoreID, Equals, uint64(2)) - c.Assert(next.HotRegionType, Equals, storage.ReadType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(4), next.RegionID) + re.Equal(uint64(2), next.StoreID) + re.Equal(storage.ReadType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(next, IsNil) - c.Assert(err, IsNil) + re.NoError(err) + re.Nil(next) } -func (s *hotRegionHistorySuite) TestHotRegionStorageReservedDayConfigChange(c *C) { +func TestHotRegionStorageReservedDayConfigChange(t *testing.T) { + re := require.New(t) statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) interval := 100 * time.Millisecond @@ -152,9 +146,9 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageReservedDayConfigChange(c *C cfg.Schedule.HotRegionsReservedDays = 1 }, ) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() stores := []*metapb.Store{ { @@ -170,46 +164,46 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageReservedDayConfigChange(c *C } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStoreWithCheck(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) } defer cluster.Destroy() startTime := time.Now().UnixNano() / int64(time.Millisecond) - pdctl.MustPutRegionWithCheck(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) // wait hot scheduler starts time.Sleep(5000 * time.Millisecond) endTime := time.Now().UnixNano() / int64(time.Millisecond) hotRegionStorage := leaderServer.GetServer().GetHistoryHotRegionStorage() iter := hotRegionStorage.NewIterator([]string{storage.WriteType.String()}, startTime, endTime) next, err := iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(1)) - c.Assert(next.StoreID, Equals, uint64(1)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(1), next.RegionID) + re.Equal(uint64(1), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(err, IsNil) - c.Assert(next, IsNil) + re.NoError(err) + re.Nil(next) schedule := leaderServer.GetConfig().Schedule // set reserved day to zero,close hot region storage schedule.HotRegionsReservedDays = 0 leaderServer.GetServer().SetScheduleConfig(schedule) time.Sleep(3 * interval) - pdctl.MustPutRegionWithCheck(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), 
core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) time.Sleep(10 * interval) endTime = time.Now().UnixNano() / int64(time.Millisecond) hotRegionStorage = leaderServer.GetServer().GetHistoryHotRegionStorage() iter = hotRegionStorage.NewIterator([]string{storage.WriteType.String()}, startTime, endTime) next, err = iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(1)) - c.Assert(next.StoreID, Equals, uint64(1)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(1), next.RegionID) + re.Equal(uint64(1), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(err, IsNil) - c.Assert(next, IsNil) + re.NoError(err) + re.Nil(next) // set reserved day to one,open hot region storage schedule.HotRegionsReservedDays = 1 leaderServer.GetServer().SetScheduleConfig(schedule) @@ -218,20 +212,21 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageReservedDayConfigChange(c *C hotRegionStorage = leaderServer.GetServer().GetHistoryHotRegionStorage() iter = hotRegionStorage.NewIterator([]string{storage.WriteType.String()}, startTime, endTime) next, err = iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(1)) - c.Assert(next.StoreID, Equals, uint64(1)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(1), next.RegionID) + re.Equal(uint64(1), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(2)) - c.Assert(next.StoreID, Equals, uint64(2)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(2), next.RegionID) + re.Equal(uint64(2), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) } -func (s *hotRegionHistorySuite) TestHotRegionStorageWriteIntervalConfigChange(c *C) { +func TestHotRegionStorageWriteIntervalConfigChange(t *testing.T) { + re := require.New(t) statistics.Denoising = false ctx, cancel := context.WithCancel(context.Background()) interval := 100 * time.Millisecond @@ -243,9 +238,9 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageWriteIntervalConfigChange(c cfg.Schedule.HotRegionsReservedDays = 1 }, ) - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() stores := []*metapb.Store{ { @@ -261,45 +256,45 @@ func (s *hotRegionHistorySuite) TestHotRegionStorageWriteIntervalConfigChange(c } leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStoreWithCheck(c, leaderServer.GetServer(), store) + pdctl.MustPutStore(re, leaderServer.GetServer(), store) } defer cluster.Destroy() startTime := time.Now().UnixNano() / int64(time.Millisecond) - pdctl.MustPutRegionWithCheck(c, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), 
core.SetReportInterval(statistics.WriteReportInterval)) // wait hot scheduler starts time.Sleep(5000 * time.Millisecond) endTime := time.Now().UnixNano() / int64(time.Millisecond) hotRegionStorage := leaderServer.GetServer().GetHistoryHotRegionStorage() iter := hotRegionStorage.NewIterator([]string{storage.WriteType.String()}, startTime, endTime) next, err := iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(1)) - c.Assert(next.StoreID, Equals, uint64(1)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(1), next.RegionID) + re.Equal(uint64(1), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(err, IsNil) - c.Assert(next, IsNil) + re.NoError(err) + re.Nil(next) schedule := leaderServer.GetConfig().Schedule // set the time to 20 times the interval schedule.HotRegionsWriteInterval.Duration = 20 * interval leaderServer.GetServer().SetScheduleConfig(schedule) time.Sleep(3 * interval) - pdctl.MustPutRegionWithCheck(c, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) + pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(statistics.WriteReportInterval)) time.Sleep(10 * interval) endTime = time.Now().UnixNano() / int64(time.Millisecond) // it cant get new hot region because wait time smaller than hot region write interval hotRegionStorage = leaderServer.GetServer().GetHistoryHotRegionStorage() iter = hotRegionStorage.NewIterator([]string{storage.WriteType.String()}, startTime, endTime) next, err = iter.Next() - c.Assert(next, NotNil) - c.Assert(err, IsNil) - c.Assert(next.RegionID, Equals, uint64(1)) - c.Assert(next.StoreID, Equals, uint64(1)) - c.Assert(next.HotRegionType, Equals, storage.WriteType.String()) + re.NoError(err) + re.NotNil(next) + re.Equal(uint64(1), next.RegionID) + re.Equal(uint64(1), next.StoreID) + re.Equal(storage.WriteType.String(), next.HotRegionType) next, err = iter.Next() - c.Assert(err, IsNil) - c.Assert(next, IsNil) + re.NoError(err) + re.Nil(next) } From eb2ed76f35a8098a756e742c44196af8dd22b58c Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 17 Jun 2022 15:30:35 +0800 Subject: [PATCH 56/82] tests: testify the cluster tests (#5167) ref tikv/pd#4813 Testify the cluster tests. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- tests/server/cluster/cluster_test.go | 888 +++++++++++----------- tests/server/cluster/cluster_work_test.go | 83 +- 2 files changed, 490 insertions(+), 481 deletions(-) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index f493ed21b00..5c3ea03827e 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -24,11 +24,11 @@ import ( "time" "github.com/coreos/go-semver/semver" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/kvproto/pkg/replication_modepb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/dashboard" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/pkg/testutil" @@ -47,10 +47,6 @@ import ( "google.golang.org/grpc/status" ) -func Test(t *testing.T) { - TestingT(t) -} - const ( initEpochVersion uint64 = 1 initEpochConfVer uint64 = 1 @@ -59,73 +55,62 @@ const ( testStoreAddr = "127.0.0.1:0" ) -var _ = Suite(&clusterTestSuite{}) - -type clusterTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clusterTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - // to prevent GetStorage - dashboard.SetCheckInterval(30 * time.Minute) -} - -func (s *clusterTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *clusterTestSuite) TestBootstrap(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestBootstrap(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() // IsBootstrapped returns false. req := newIsBootstrapRequest(clusterID) resp, err := grpcPDClient.IsBootstrapped(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp, NotNil) - c.Assert(resp.GetBootstrapped(), IsFalse) + re.NoError(err) + re.NotNil(resp) + re.False(resp.GetBootstrapped()) // Bootstrap the cluster. - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) // IsBootstrapped returns true. req = newIsBootstrapRequest(clusterID) resp, err = grpcPDClient.IsBootstrapped(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetBootstrapped(), IsTrue) + re.NoError(err) + re.True(resp.GetBootstrapped()) // check bootstrapped error. 
reqBoot := newBootstrapRequest(clusterID) respBoot, err := grpcPDClient.Bootstrap(context.Background(), reqBoot) - c.Assert(err, IsNil) - c.Assert(respBoot.GetHeader().GetError(), NotNil) - c.Assert(respBoot.GetHeader().GetError().GetType(), Equals, pdpb.ErrorType_ALREADY_BOOTSTRAPPED) + re.NoError(err) + re.NotNil(respBoot.GetHeader().GetError()) + re.Equal(pdpb.ErrorType_ALREADY_BOOTSTRAPPED, respBoot.GetHeader().GetError().GetType()) } -func (s *clusterTestSuite) TestDamagedRegion(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestDamagedRegion(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() region := &metapb.Region{ @@ -142,7 +127,7 @@ func (s *clusterTestSuite) TestDamagedRegion(c *C) { // To put region. regionInfo := core.NewRegionInfo(region, region.Peers[0], core.SetApproximateSize(30)) err = tc.HandleRegionHeartbeat(regionInfo) - c.Assert(err, IsNil) + re.NoError(err) stores := []*pdpb.PutStoreRequest{ { @@ -175,7 +160,7 @@ func (s *clusterTestSuite) TestDamagedRegion(c *C) { svr := &server.GrpcServer{Server: leaderServer.GetServer()} for _, store := range stores { _, err = svr.PutStore(context.Background(), store) - c.Assert(err, IsNil) + re.NoError(err) } // To validate remove peer op be added. @@ -183,50 +168,53 @@ func (s *clusterTestSuite) TestDamagedRegion(c *C) { Header: testutil.NewRequestHeader(clusterID), Stats: &pdpb.StoreStats{StoreId: 2, DamagedRegionsId: []uint64{10}}, } - c.Assert(rc.GetOperatorController().OperatorCount(operator.OpAdmin), Equals, uint64(0)) + re.Equal(uint64(0), rc.GetOperatorController().OperatorCount(operator.OpAdmin)) _, err1 := grpcPDClient.StoreHeartbeat(context.Background(), req1) - c.Assert(err1, IsNil) - c.Assert(rc.GetOperatorController().OperatorCount(operator.OpAdmin), Equals, uint64(1)) + re.NoError(err1) + re.Equal(uint64(1), rc.GetOperatorController().OperatorCount(operator.OpAdmin)) } -func (s *clusterTestSuite) TestGetPutConfig(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestGetPutConfig(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) // Get region. 
- region := getRegion(c, clusterID, grpcPDClient, []byte("abc")) - c.Assert(region.GetPeers(), HasLen, 1) + region := getRegion(re, clusterID, grpcPDClient, []byte("abc")) + re.Len(region.GetPeers(), 1) peer := region.GetPeers()[0] // Get region by id. - regionByID := getRegionByID(c, clusterID, grpcPDClient, region.GetId()) - c.Assert(region, DeepEquals, regionByID) + regionByID := getRegionByID(re, clusterID, grpcPDClient, region.GetId()) + re.Equal(regionByID, region) r := core.NewRegionInfo(region, region.Peers[0], core.SetApproximateSize(30)) err = tc.HandleRegionHeartbeat(r) - c.Assert(err, IsNil) + re.NoError(err) // Get store. storeID := peer.GetStoreId() - store := getStore(c, clusterID, grpcPDClient, storeID) + store := getStore(re, clusterID, grpcPDClient, storeID) // Update store. store.Address = "127.0.0.1:1" - testPutStore(c, clusterID, rc, grpcPDClient, store) + testPutStore(re, clusterID, rc, grpcPDClient, store) // Remove store. - testRemoveStore(c, clusterID, rc, grpcPDClient, store) + testRemoveStore(re, clusterID, rc, grpcPDClient, store) // Update cluster config. req := &pdpb.PutClusterConfigRequest{ @@ -237,73 +225,73 @@ func (s *clusterTestSuite) TestGetPutConfig(c *C) { }, } resp, err := grpcPDClient.PutClusterConfig(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp, NotNil) - meta := getClusterConfig(c, clusterID, grpcPDClient) - c.Assert(meta.GetMaxPeerCount(), Equals, uint32(5)) + re.NoError(err) + re.NotNil(resp) + meta := getClusterConfig(re, clusterID, grpcPDClient) + re.Equal(uint32(5), meta.GetMaxPeerCount()) } -func testPutStore(c *C, clusterID uint64, rc *cluster.RaftCluster, grpcPDClient pdpb.PDClient, store *metapb.Store) { +func testPutStore(re *require.Assertions, clusterID uint64, rc *cluster.RaftCluster, grpcPDClient pdpb.PDClient, store *metapb.Store) { // Update store. _, err := putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) - updatedStore := getStore(c, clusterID, grpcPDClient, store.GetId()) - c.Assert(updatedStore, DeepEquals, store) + re.NoError(err) + updatedStore := getStore(re, clusterID, grpcPDClient, store.GetId()) + re.Equal(store, updatedStore) // Update store again. _, err = putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) + re.NoError(err) rc.GetAllocator().Alloc() id, err := rc.GetAllocator().Alloc() - c.Assert(err, IsNil) + re.NoError(err) // Put new store with a duplicated address when old store is up will fail. _, err = putStore(grpcPDClient, clusterID, newMetaStore(id, store.GetAddress(), "2.1.0", metapb.StoreState_Up, getTestDeployPath(id))) - c.Assert(err, NotNil) + re.Error(err) id, err = rc.GetAllocator().Alloc() - c.Assert(err, IsNil) + re.NoError(err) // Put new store with a duplicated address when old store is offline will fail. - resetStoreState(c, rc, store.GetId(), metapb.StoreState_Offline) + resetStoreState(re, rc, store.GetId(), metapb.StoreState_Offline) _, err = putStore(grpcPDClient, clusterID, newMetaStore(id, store.GetAddress(), "2.1.0", metapb.StoreState_Up, getTestDeployPath(id))) - c.Assert(err, NotNil) + re.Error(err) id, err = rc.GetAllocator().Alloc() - c.Assert(err, IsNil) + re.NoError(err) // Put new store with a duplicated address when old store is tombstone is OK. 
- resetStoreState(c, rc, store.GetId(), metapb.StoreState_Tombstone) + resetStoreState(re, rc, store.GetId(), metapb.StoreState_Tombstone) rc.GetStore(store.GetId()) _, err = putStore(grpcPDClient, clusterID, newMetaStore(id, store.GetAddress(), "2.1.0", metapb.StoreState_Up, getTestDeployPath(id))) - c.Assert(err, IsNil) + re.NoError(err) id, err = rc.GetAllocator().Alloc() - c.Assert(err, IsNil) + re.NoError(err) deployPath := getTestDeployPath(id) // Put a new store. _, err = putStore(grpcPDClient, clusterID, newMetaStore(id, testMetaStoreAddr, "2.1.0", metapb.StoreState_Up, deployPath)) - c.Assert(err, IsNil) + re.NoError(err) s := rc.GetStore(id).GetMeta() - c.Assert(s.DeployPath, Equals, deployPath) + re.Equal(deployPath, s.DeployPath) deployPath = fmt.Sprintf("move/test/store%d", id) _, err = putStore(grpcPDClient, clusterID, newMetaStore(id, testMetaStoreAddr, "2.1.0", metapb.StoreState_Up, deployPath)) - c.Assert(err, IsNil) + re.NoError(err) s = rc.GetStore(id).GetMeta() - c.Assert(s.DeployPath, Equals, deployPath) + re.Equal(deployPath, s.DeployPath) // Put an existed store with duplicated address with other old stores. - resetStoreState(c, rc, store.GetId(), metapb.StoreState_Up) + resetStoreState(re, rc, store.GetId(), metapb.StoreState_Up) _, err = putStore(grpcPDClient, clusterID, newMetaStore(store.GetId(), testMetaStoreAddr, "2.1.0", metapb.StoreState_Up, getTestDeployPath(store.GetId()))) - c.Assert(err, NotNil) + re.Error(err) } func getTestDeployPath(storeID uint64) string { return fmt.Sprintf("test/store%d", storeID) } -func resetStoreState(c *C, rc *cluster.RaftCluster, storeID uint64, state metapb.StoreState) { +func resetStoreState(re *require.Assertions, rc *cluster.RaftCluster, storeID uint64, state metapb.StoreState) { store := rc.GetStore(storeID) - c.Assert(store, NotNil) + re.NotNil(store) newStore := store.Clone(core.OfflineStore(false)) if state == metapb.StoreState_Up { newStore = newStore.Clone(core.UpStore()) @@ -319,7 +307,7 @@ func resetStoreState(c *C, rc *cluster.RaftCluster, storeID uint64, state metapb } } -func testStateAndLimit(c *C, clusterID uint64, rc *cluster.RaftCluster, grpcPDClient pdpb.PDClient, store *metapb.Store, beforeState metapb.StoreState, run func(*cluster.RaftCluster) error, expectStates ...metapb.StoreState) { +func testStateAndLimit(re *require.Assertions, clusterID uint64, rc *cluster.RaftCluster, grpcPDClient pdpb.PDClient, store *metapb.Store, beforeState metapb.StoreState, run func(*cluster.RaftCluster) error, expectStates ...metapb.StoreState) { // prepare storeID := store.GetId() oc := rc.GetOperatorController() @@ -330,68 +318,68 @@ func testStateAndLimit(c *C, clusterID uint64, rc *cluster.RaftCluster, grpcPDCl op = operator.NewTestOperator(2, &metapb.RegionEpoch{}, operator.OpRegion, operator.RemovePeer{FromStore: storeID}) oc.AddOperator(op) - resetStoreState(c, rc, store.GetId(), beforeState) + resetStoreState(re, rc, store.GetId(), beforeState) _, isOKBefore := rc.GetAllStoresLimit()[storeID] // run err := run(rc) // judge _, isOKAfter := rc.GetAllStoresLimit()[storeID] if len(expectStates) != 0 { - c.Assert(err, IsNil) + re.NoError(err) expectState := expectStates[0] - c.Assert(getStore(c, clusterID, grpcPDClient, storeID).GetState(), Equals, expectState) + re.Equal(expectState, getStore(re, clusterID, grpcPDClient, storeID).GetState()) if expectState == metapb.StoreState_Offline { - c.Assert(isOKAfter, IsTrue) + re.True(isOKAfter) } else if expectState == metapb.StoreState_Tombstone { - c.Assert(isOKAfter, IsFalse) 
+ re.False(isOKAfter) } } else { - c.Assert(err, NotNil) - c.Assert(isOKBefore, Equals, isOKAfter) + re.Error(err) + re.Equal(isOKAfter, isOKBefore) } } -func testRemoveStore(c *C, clusterID uint64, rc *cluster.RaftCluster, grpcPDClient pdpb.PDClient, store *metapb.Store) { +func testRemoveStore(re *require.Assertions, clusterID uint64, rc *cluster.RaftCluster, grpcPDClient pdpb.PDClient, store *metapb.Store) { rc.GetOpts().SetMaxReplicas(2) defer rc.GetOpts().SetMaxReplicas(3) { beforeState := metapb.StoreState_Up // When store is up // Case 1: RemoveStore should be OK; - testStateAndLimit(c, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { + testStateAndLimit(re, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { return cluster.RemoveStore(store.GetId(), false) }, metapb.StoreState_Offline) // Case 2: RemoveStore with physically destroyed should be OK; - testStateAndLimit(c, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { + testStateAndLimit(re, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { return cluster.RemoveStore(store.GetId(), true) }, metapb.StoreState_Offline) } { beforeState := metapb.StoreState_Offline // When store is offline // Case 1: RemoveStore should be OK; - testStateAndLimit(c, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { + testStateAndLimit(re, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { return cluster.RemoveStore(store.GetId(), false) }, metapb.StoreState_Offline) // Case 2: remove store with physically destroyed should succeed - testStateAndLimit(c, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { + testStateAndLimit(re, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { return cluster.RemoveStore(store.GetId(), true) }, metapb.StoreState_Offline) } { beforeState := metapb.StoreState_Tombstone // When store is tombstone // Case 1: RemoveStore should fail; - testStateAndLimit(c, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { + testStateAndLimit(re, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { return cluster.RemoveStore(store.GetId(), false) }) // Case 2: RemoveStore with physically destroyed should fail; - testStateAndLimit(c, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { + testStateAndLimit(re, clusterID, rc, grpcPDClient, store, beforeState, func(cluster *cluster.RaftCluster) error { return cluster.RemoveStore(store.GetId(), true) }) } { // Put after removed should return tombstone error. resp, err := putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) - c.Assert(resp.GetHeader().GetError().GetType(), Equals, pdpb.ErrorType_STORE_TOMBSTONE) + re.NoError(err) + re.Equal(pdpb.ErrorType_STORE_TOMBSTONE, resp.GetHeader().GetError().GetType()) } { // Update after removed should return tombstone error.
@@ -400,182 +388,196 @@ func testRemoveStore(c *C, clusterID uint64, rc *cluster.RaftCluster, grpcPDClie Stats: &pdpb.StoreStats{StoreId: store.GetId()}, } resp, err := grpcPDClient.StoreHeartbeat(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetHeader().GetError().GetType(), Equals, pdpb.ErrorType_STORE_TOMBSTONE) + re.NoError(err) + re.Equal(pdpb.ErrorType_STORE_TOMBSTONE, resp.GetHeader().GetError().GetType()) } } // Make sure PD will not panic if it starts and stops again and again. -func (s *clusterTestSuite) TestRaftClusterRestart(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestRaftClusterRestart(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) rc.Stop() err = rc.Start(leaderServer.GetServer()) - c.Assert(err, IsNil) + re.NoError(err) rc = leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) rc.Stop() } // Make sure PD will not deadlock if it starts and stops again and again. -func (s *clusterTestSuite) TestRaftClusterMultipleRestart(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestRaftClusterMultipleRestart(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) // add an offline store storeID, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) + re.NoError(err) store := newMetaStore(storeID, "127.0.0.1:4", "2.1.0", metapb.StoreState_Offline, getTestDeployPath(storeID)) rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) err = rc.PutStore(store) - c.Assert(err, IsNil) - c.Assert(tc, NotNil) + re.NoError(err) + re.NotNil(tc) // let the job run at a small interval - c.Assert(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`)) for i := 0; i < 100; i++ { err = rc.Start(leaderServer.GetServer()) - c.Assert(err, IsNil) + re.NoError(err) time.Sleep(time.Millisecond) rc = leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) rc.Stop() } - c.Assert(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs")) } func newMetaStore(storeID uint64, addr, version string, state
metapb.StoreState, deployPath string) *metapb.Store { return &metapb.Store{Id: storeID, Address: addr, Version: version, State: state, DeployPath: deployPath} } -func (s *clusterTestSuite) TestGetPDMembers(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestGetPDMembers(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.GetMembersRequest{Header: testutil.NewRequestHeader(clusterID)} resp, err := grpcPDClient.GetMembers(context.Background(), req) - c.Assert(err, IsNil) + re.NoError(err) // A more strict test can be found at api/member_test.go - c.Assert(resp.GetMembers(), Not(HasLen), 0) + re.NotEmpty(resp.GetMembers()) } -func (s *clusterTestSuite) TestNotLeader(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 2) +func TestNotLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 2) defer tc.Destroy() - c.Assert(err, IsNil) - c.Assert(tc.RunInitialServers(), IsNil) - + re.NoError(err) + re.NoError(tc.RunInitialServers()) tc.WaitLeader() followerServer := tc.GetServer(tc.GetFollower()) - grpcPDClient := testutil.MustNewGrpcClient(c, followerServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, followerServer.GetAddr()) clusterID := followerServer.GetClusterID() req := &pdpb.AllocIDRequest{Header: testutil.NewRequestHeader(clusterID)} resp, err := grpcPDClient.AllocID(context.Background(), req) - c.Assert(resp, IsNil) + re.Nil(resp) grpcStatus, ok := status.FromError(err) - c.Assert(ok, IsTrue) - c.Assert(grpcStatus.Code(), Equals, codes.Unavailable) - c.Assert(grpcStatus.Message(), Equals, "not leader") + re.True(ok) + re.Equal(codes.Unavailable, grpcStatus.Code()) + re.Equal("not leader", grpcStatus.Message()) } -func (s *clusterTestSuite) TestStoreVersionChange(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestStoreVersionChange(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) svr := leaderServer.GetServer() svr.SetClusterVersion("2.0.0") storeID, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) + re.NoError(err) store := newMetaStore(storeID, "127.0.0.1:4", "2.1.0", metapb.StoreState_Up, getTestDeployPath(storeID)) var wg sync.WaitGroup - c.Assert(failpoint.Enable("github.com/tikv/pd/server/versionChangeConcurrency", `return(true)`), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/versionChangeConcurrency", `return(true)`)) 
wg.Add(1) go func() { defer wg.Done() _, err = putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) + re.NoError(err) }() time.Sleep(100 * time.Millisecond) svr.SetClusterVersion("1.0.0") wg.Wait() v, err := semver.NewVersion("1.0.0") - c.Assert(err, IsNil) - c.Assert(svr.GetClusterVersion(), Equals, *v) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/versionChangeConcurrency"), IsNil) + re.NoError(err) + re.Equal(*v, svr.GetClusterVersion()) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/versionChangeConcurrency")) } -func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestConcurrentHandleRegion(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dashboard.SetCheckInterval(30 * time.Minute) + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) - + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) - + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) rc.SetStorage(storage.NewStorageWithMemoryBackend()) stores := make([]*metapb.Store, 0, len(storeAddrs)) id := leaderServer.GetAllocator() for _, addr := range storeAddrs { storeID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) store := newMetaStore(storeID, addr, "2.1.0", metapb.StoreState_Up, getTestDeployPath(storeID)) stores = append(stores, store) _, err = putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) + re.NoError(err) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() var wg sync.WaitGroup // register store and bind stream for i, store := range stores { @@ -589,13 +591,13 @@ func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { } grpcServer := &server.GrpcServer{Server: leaderServer.GetServer()} _, err := grpcServer.StoreHeartbeat(context.TODO(), req) - c.Assert(err, IsNil) + re.NoError(err) stream, err := grpcPDClient.RegionHeartbeat(ctx) - c.Assert(err, IsNil) + re.NoError(err) peerID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) regionID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) peer := &metapb.Peer{Id: peerID, StoreId: store.GetId()} regionReq := &pdpb.RegionHeartbeatRequest{ Header: testutil.NewRequestHeader(clusterID), @@ -606,7 +608,7 @@ func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { Leader: peer, } err = stream.Send(regionReq) - c.Assert(err, IsNil) + re.NoError(err) // make sure the first store can receive one response if i == 0 { wg.Add(1) @@ -614,7 +616,7 @@ func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { go func(isReceiver bool) { if isReceiver { _, err := stream.Recv() - c.Assert(err, IsNil) + re.NoError(err) wg.Done() } for { @@ -631,9 +633,9 @@ func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { concurrent := 1000 for i := 0; i < concurrent; i++ { peerID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) regionID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) region := &metapb.Region{ Id: regionID, StartKey: 
[]byte(fmt.Sprintf("%5d", i)), @@ -654,33 +656,36 @@ func (s *clusterTestSuite) TestConcurrentHandleRegion(c *C) { go func() { defer wg.Done() err := rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0])) - c.Assert(err, IsNil) + re.NoError(err) }() } wg.Wait() } -func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { +func TestSetScheduleOpt(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // TODO: enable placementrules - tc, err := tests.NewTestCluster(s.ctx, 1, func(cfg *config.Config, svr string) { cfg.Replication.EnablePlacementRules = false }) + tc, err := tests.NewTestCluster(ctx, 1, func(cfg *config.Config, svr string) { cfg.Replication.EnablePlacementRules = false }) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) cfg := config.NewConfig() cfg.Schedule.TolerantSizeRatio = 5 err = cfg.Adjust(nil, false) - c.Assert(err, IsNil) + re.NoError(err) opt := config.NewPersistOptions(cfg) - c.Assert(err, IsNil) + re.NoError(err) svr := leaderServer.GetServer() scheduleCfg := opt.GetScheduleConfig() @@ -693,68 +698,63 @@ func (s *clusterTestSuite) TestSetScheduleOpt(c *C) { scheduleCfg.MaxSnapshotCount = 10 pdServerCfg.UseRegionStorage = true typ, labelKey, labelValue := "testTyp", "testKey", "testValue" - - c.Assert(svr.SetScheduleConfig(*scheduleCfg), IsNil) - c.Assert(svr.SetPDServerConfig(*pdServerCfg), IsNil) - c.Assert(svr.SetLabelProperty(typ, labelKey, labelValue), IsNil) - c.Assert(svr.SetReplicationConfig(*replicationCfg), IsNil) - - c.Assert(persistOptions.GetMaxReplicas(), Equals, 5) - c.Assert(persistOptions.GetMaxSnapshotCount(), Equals, uint64(10)) - c.Assert(persistOptions.IsUseRegionStorage(), IsTrue) - c.Assert(persistOptions.GetLabelPropertyConfig()[typ][0].Key, Equals, "testKey") - c.Assert(persistOptions.GetLabelPropertyConfig()[typ][0].Value, Equals, "testValue") - - c.Assert(svr.DeleteLabelProperty(typ, labelKey, labelValue), IsNil) - - c.Assert(persistOptions.GetLabelPropertyConfig()[typ], HasLen, 0) + re.NoError(svr.SetScheduleConfig(*scheduleCfg)) + re.NoError(svr.SetPDServerConfig(*pdServerCfg)) + re.NoError(svr.SetLabelProperty(typ, labelKey, labelValue)) + re.NoError(svr.SetReplicationConfig(*replicationCfg)) + re.Equal(5, persistOptions.GetMaxReplicas()) + re.Equal(uint64(10), persistOptions.GetMaxSnapshotCount()) + re.True(persistOptions.IsUseRegionStorage()) + re.Equal("testKey", persistOptions.GetLabelPropertyConfig()[typ][0].Key) + re.Equal("testValue", persistOptions.GetLabelPropertyConfig()[typ][0].Value) + re.NoError(svr.DeleteLabelProperty(typ, labelKey, labelValue)) + re.Len(persistOptions.GetLabelPropertyConfig()[typ], 0) // PUT GET failed - c.Assert(failpoint.Enable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed", `return(true)`), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed", `return(true)`)) replicationCfg.MaxReplicas = 7 scheduleCfg.MaxSnapshotCount = 20 pdServerCfg.UseRegionStorage = false - - c.Assert(svr.SetScheduleConfig(*scheduleCfg), NotNil) - 
c.Assert(svr.SetReplicationConfig(*replicationCfg), NotNil) - c.Assert(svr.SetPDServerConfig(*pdServerCfg), NotNil) - c.Assert(svr.SetLabelProperty(typ, labelKey, labelValue), NotNil) - - c.Assert(persistOptions.GetMaxReplicas(), Equals, 5) - c.Assert(persistOptions.GetMaxSnapshotCount(), Equals, uint64(10)) - c.Assert(persistOptions.GetPDServerConfig().UseRegionStorage, IsTrue) - c.Assert(persistOptions.GetLabelPropertyConfig()[typ], HasLen, 0) + re.Error(svr.SetScheduleConfig(*scheduleCfg)) + re.Error(svr.SetReplicationConfig(*replicationCfg)) + re.Error(svr.SetPDServerConfig(*pdServerCfg)) + re.Error(svr.SetLabelProperty(typ, labelKey, labelValue)) + re.Equal(5, persistOptions.GetMaxReplicas()) + re.Equal(uint64(10), persistOptions.GetMaxSnapshotCount()) + re.True(persistOptions.GetPDServerConfig().UseRegionStorage) + re.Len(persistOptions.GetLabelPropertyConfig()[typ], 0) // DELETE failed - c.Assert(failpoint.Disable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed"), IsNil) - c.Assert(svr.SetReplicationConfig(*replicationCfg), IsNil) - - c.Assert(failpoint.Enable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed", `return(true)`), IsNil) - c.Assert(svr.DeleteLabelProperty(typ, labelKey, labelValue), NotNil) - - c.Assert(persistOptions.GetLabelPropertyConfig()[typ][0].Key, Equals, "testKey") - c.Assert(persistOptions.GetLabelPropertyConfig()[typ][0].Value, Equals, "testValue") - c.Assert(failpoint.Disable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed")) + re.NoError(svr.SetReplicationConfig(*replicationCfg)) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed", `return(true)`)) + re.Error(svr.DeleteLabelProperty(typ, labelKey, labelValue)) + re.Equal("testKey", persistOptions.GetLabelPropertyConfig()[typ][0].Key) + re.Equal("testValue", persistOptions.GetLabelPropertyConfig()[typ][0].Value) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed")) } -func (s *clusterTestSuite) TestLoadClusterInfo(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestLoadClusterInfo(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) svr := leaderServer.GetServer() - rc := cluster.NewRaftCluster(s.ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) + rc := cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) // Cluster is not bootstrapped. rc.InitCluster(svr.GetAllocator(), svr.GetPersistOptions(), svr.GetStorage(), svr.GetBasicCluster()) raftCluster, err := rc.LoadClusterInfo() - c.Assert(err, IsNil) - c.Assert(raftCluster, IsNil) + re.NoError(err) + re.Nil(raftCluster) storage := rc.GetStorage() basicCluster := rc.GetBasicCluster() @@ -762,7 +762,7 @@ func (s *clusterTestSuite) TestLoadClusterInfo(c *C) { // Save meta, stores and regions. 
n := 10 meta := &metapb.Cluster{Id: 123} - c.Assert(storage.SaveMeta(meta), IsNil) + re.NoError(storage.SaveMeta(meta)) stores := make([]*metapb.Store, 0, n) for i := 0; i < n; i++ { store := &metapb.Store{Id: uint64(i)} @@ -770,7 +770,7 @@ func (s *clusterTestSuite) TestLoadClusterInfo(c *C) { } for _, store := range stores { - c.Assert(storage.SaveStore(store), IsNil) + re.NoError(storage.SaveStore(store)) } regions := make([]*metapb.Region, 0, n) @@ -785,25 +785,25 @@ func (s *clusterTestSuite) TestLoadClusterInfo(c *C) { } for _, region := range regions { - c.Assert(storage.SaveRegion(region), IsNil) + re.NoError(storage.SaveRegion(region)) } - c.Assert(storage.Flush(), IsNil) + re.NoError(storage.Flush()) - raftCluster = cluster.NewRaftCluster(s.ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) + raftCluster = cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) raftCluster.InitCluster(mockid.NewIDAllocator(), opt, storage, basicCluster) raftCluster, err = raftCluster.LoadClusterInfo() - c.Assert(err, IsNil) - c.Assert(raftCluster, NotNil) + re.NoError(err) + re.NotNil(raftCluster) // Check meta, stores, and regions. - c.Assert(raftCluster.GetMetaCluster(), DeepEquals, meta) - c.Assert(raftCluster.GetStoreCount(), Equals, n) + re.Equal(meta, raftCluster.GetMetaCluster()) + re.Equal(n, raftCluster.GetStoreCount()) for _, store := range raftCluster.GetMetaStores() { - c.Assert(store, DeepEquals, stores[store.GetId()]) + re.Equal(stores[store.GetId()], store) } - c.Assert(raftCluster.GetRegionCount(), Equals, n) + re.Equal(n, raftCluster.GetRegionCount()) for _, region := range raftCluster.GetMetaRegions() { - c.Assert(region, DeepEquals, regions[region.GetId()]) + re.Equal(regions[region.GetId()], region) } m := 20 @@ -819,23 +819,26 @@ func (s *clusterTestSuite) TestLoadClusterInfo(c *C) { } for _, region := range regions { - c.Assert(storage.SaveRegion(region), IsNil) + re.NoError(storage.SaveRegion(region)) } - raftCluster.GetStorage().LoadRegionsOnce(s.ctx, raftCluster.GetBasicCluster().PutRegion) - c.Assert(raftCluster.GetRegionCount(), Equals, n) + raftCluster.GetStorage().LoadRegionsOnce(ctx, raftCluster.GetBasicCluster().PutRegion) + re.Equal(n, raftCluster.GetRegionCount()) } -func (s *clusterTestSuite) TestTiFlashWithPlacementRules(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1, func(cfg *config.Config, name string) { cfg.Replication.EnablePlacementRules = false }) +func TestTiFlashWithPlacementRules(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1, func(cfg *config.Config, name string) { cfg.Replication.EnablePlacementRules = false }) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) tiflashStore := &metapb.Store{ Id: 11, @@ -846,163 +849,155 @@ func (s *clusterTestSuite) TestTiFlashWithPlacementRules(c *C) { // cannot put TiFlash node without placement rules _, err = putStore(grpcPDClient, clusterID, tiflashStore) - c.Assert(err, NotNil) + 
re.Error(err) rep := leaderServer.GetConfig().Replication rep.EnablePlacementRules = true svr := leaderServer.GetServer() err = svr.SetReplicationConfig(rep) - c.Assert(err, IsNil) + re.NoError(err) _, err = putStore(grpcPDClient, clusterID, tiflashStore) - c.Assert(err, IsNil) + re.NoError(err) // test TiFlash store limit expect := map[uint64]config.StoreLimitConfig{11: {AddPeer: 30, RemovePeer: 30}} - c.Assert(svr.GetScheduleConfig().StoreLimit, DeepEquals, expect) + re.Equal(expect, svr.GetScheduleConfig().StoreLimit) // cannot disable placement rules with TiFlash nodes rep.EnablePlacementRules = false err = svr.SetReplicationConfig(rep) - c.Assert(err, NotNil) + re.Error(err) err = svr.GetRaftCluster().BuryStore(11, true) - c.Assert(err, IsNil) + re.NoError(err) err = svr.SetReplicationConfig(rep) - c.Assert(err, IsNil) - c.Assert(len(svr.GetScheduleConfig().StoreLimit), Equals, 0) + re.NoError(err) + re.Equal(0, len(svr.GetScheduleConfig().StoreLimit)) } -func (s *clusterTestSuite) TestReplicationModeStatus(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { +func TestReplicationModeStatus(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.ReplicationMode.ReplicationMode = "dr-auto-sync" }) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := newBootstrapRequest(clusterID) res, err := grpcPDClient.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(res.GetReplicationStatus().GetMode(), Equals, replication_modepb.ReplicationMode_DR_AUTO_SYNC) // check status in bootstrap response + re.NoError(err) + re.Equal(replication_modepb.ReplicationMode_DR_AUTO_SYNC, res.GetReplicationStatus().GetMode()) // check status in bootstrap response store := &metapb.Store{Id: 11, Address: "127.0.0.1:1", Version: "v4.1.0"} putRes, err := putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) - c.Assert(putRes.GetReplicationStatus().GetMode(), Equals, replication_modepb.ReplicationMode_DR_AUTO_SYNC) // check status in putStore response + re.NoError(err) + re.Equal(replication_modepb.ReplicationMode_DR_AUTO_SYNC, putRes.GetReplicationStatus().GetMode()) // check status in putStore response hbReq := &pdpb.StoreHeartbeatRequest{ Header: testutil.NewRequestHeader(clusterID), Stats: &pdpb.StoreStats{StoreId: store.GetId()}, } hbRes, err := grpcPDClient.StoreHeartbeat(context.Background(), hbReq) - c.Assert(err, IsNil) - c.Assert(hbRes.GetReplicationStatus().GetMode(), Equals, replication_modepb.ReplicationMode_DR_AUTO_SYNC) // check status in store heartbeat response + re.NoError(err) + re.Equal(replication_modepb.ReplicationMode_DR_AUTO_SYNC, hbRes.GetReplicationStatus().GetMode()) // check status in store heartbeat response } func newIsBootstrapRequest(clusterID uint64) *pdpb.IsBootstrappedRequest { - req := &pdpb.IsBootstrappedRequest{ + return &pdpb.IsBootstrappedRequest{ Header: testutil.NewRequestHeader(clusterID), } - - return req } func newBootstrapRequest(clusterID uint64) *pdpb.BootstrapRequest { - req := &pdpb.BootstrapRequest{ + return 
&pdpb.BootstrapRequest{ Header: testutil.NewRequestHeader(clusterID), Store: &metapb.Store{Id: 1, Address: testStoreAddr}, Region: &metapb.Region{Id: 2, Peers: []*metapb.Peer{{Id: 3, StoreId: 1, Role: metapb.PeerRole_Voter}}}, } - - return req } // helper function to check and bootstrap. -func bootstrapCluster(c *C, clusterID uint64, grpcPDClient pdpb.PDClient) { +func bootstrapCluster(re *require.Assertions, clusterID uint64, grpcPDClient pdpb.PDClient) { req := newBootstrapRequest(clusterID) _, err := grpcPDClient.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) + re.NoError(err) } func putStore(grpcPDClient pdpb.PDClient, clusterID uint64, store *metapb.Store) (*pdpb.PutStoreResponse, error) { - req := &pdpb.PutStoreRequest{ + return grpcPDClient.PutStore(context.Background(), &pdpb.PutStoreRequest{ Header: testutil.NewRequestHeader(clusterID), Store: store, - } - resp, err := grpcPDClient.PutStore(context.Background(), req) - return resp, err + }) } -func getStore(c *C, clusterID uint64, grpcPDClient pdpb.PDClient, storeID uint64) *metapb.Store { - req := &pdpb.GetStoreRequest{ +func getStore(re *require.Assertions, clusterID uint64, grpcPDClient pdpb.PDClient, storeID uint64) *metapb.Store { + resp, err := grpcPDClient.GetStore(context.Background(), &pdpb.GetStoreRequest{ Header: testutil.NewRequestHeader(clusterID), StoreId: storeID, - } - resp, err := grpcPDClient.GetStore(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetStore().GetId(), Equals, storeID) - + }) + re.NoError(err) + re.Equal(storeID, resp.GetStore().GetId()) return resp.GetStore() } -func getRegion(c *C, clusterID uint64, grpcPDClient pdpb.PDClient, regionKey []byte) *metapb.Region { - req := &pdpb.GetRegionRequest{ +func getRegion(re *require.Assertions, clusterID uint64, grpcPDClient pdpb.PDClient, regionKey []byte) *metapb.Region { + resp, err := grpcPDClient.GetRegion(context.Background(), &pdpb.GetRegionRequest{ Header: testutil.NewRequestHeader(clusterID), RegionKey: regionKey, - } - - resp, err := grpcPDClient.GetRegion(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetRegion(), NotNil) - + }) + re.NoError(err) + re.NotNil(resp.GetRegion()) return resp.GetRegion() } -func getRegionByID(c *C, clusterID uint64, grpcPDClient pdpb.PDClient, regionID uint64) *metapb.Region { - req := &pdpb.GetRegionByIDRequest{ +func getRegionByID(re *require.Assertions, clusterID uint64, grpcPDClient pdpb.PDClient, regionID uint64) *metapb.Region { + resp, err := grpcPDClient.GetRegionByID(context.Background(), &pdpb.GetRegionByIDRequest{ Header: testutil.NewRequestHeader(clusterID), RegionId: regionID, - } - - resp, err := grpcPDClient.GetRegionByID(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetRegion(), NotNil) - + }) + re.NoError(err) + re.NotNil(resp.GetRegion()) return resp.GetRegion() } -func getClusterConfig(c *C, clusterID uint64, grpcPDClient pdpb.PDClient) *metapb.Cluster { - req := &pdpb.GetClusterConfigRequest{Header: testutil.NewRequestHeader(clusterID)} - - resp, err := grpcPDClient.GetClusterConfig(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetCluster(), NotNil) - +func getClusterConfig(re *require.Assertions, clusterID uint64, grpcPDClient pdpb.PDClient) *metapb.Cluster { + resp, err := grpcPDClient.GetClusterConfig(context.Background(), &pdpb.GetClusterConfigRequest{ + Header: testutil.NewRequestHeader(clusterID), + }) + re.NoError(err) + re.NotNil(resp.GetCluster()) return resp.GetCluster() } -func (s *clusterTestSuite) 
TestOfflineStoreLimit(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestOfflineStoreLimit(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dashboard.SetCheckInterval(30 * time.Minute) + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1"} rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) rc.SetStorage(storage.NewStorageWithMemoryBackend()) id := leaderServer.GetAllocator() for _, addr := range storeAddrs { storeID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) store := newMetaStore(storeID, addr, "4.0.0", metapb.StoreState_Up, getTestDeployPath(storeID)) _, err = putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) + re.NoError(err) } for i := uint64(1); i <= 2; i++ { r := &metapb.Region{ @@ -1018,7 +1013,7 @@ func (s *clusterTestSuite) TestOfflineStoreLimit(c *C) { region := core.NewRegionInfo(r, r.Peers[0], core.SetApproximateSize(10)) err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) } oc := rc.GetOperatorController() @@ -1027,22 +1022,22 @@ func (s *clusterTestSuite) TestOfflineStoreLimit(c *C) { // only can add 5 remove peer operators on store 1 for i := uint64(1); i <= 5; i++ { op := operator.NewTestOperator(1, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 1}) - c.Assert(oc.AddOperator(op), IsTrue) - c.Assert(oc.RemoveOperator(op), IsTrue) + re.True(oc.AddOperator(op)) + re.True(oc.RemoveOperator(op)) } op := operator.NewTestOperator(1, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 1}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + re.False(oc.AddOperator(op)) + re.False(oc.RemoveOperator(op)) // only can add 5 remove peer operators on store 2 for i := uint64(1); i <= 5; i++ { op := operator.NewTestOperator(2, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsTrue) - c.Assert(oc.RemoveOperator(op), IsTrue) + re.True(oc.AddOperator(op)) + re.True(oc.RemoveOperator(op)) } op = operator.NewTestOperator(2, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + re.False(oc.AddOperator(op)) + re.False(oc.RemoveOperator(op)) // reset all store limit opt.SetAllStoresLimit(storelimit.RemovePeer, 2) @@ -1050,12 +1045,12 @@ func (s *clusterTestSuite) TestOfflineStoreLimit(c *C) { // only can add 5 remove peer operators on store 2 for i := uint64(1); i <= 5; i++ { op := operator.NewTestOperator(2, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsTrue) - c.Assert(oc.RemoveOperator(op), IsTrue) + re.True(oc.AddOperator(op)) + re.True(oc.RemoveOperator(op)) } op = operator.NewTestOperator(2, 
&metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + re.False(oc.AddOperator(op)) + re.False(oc.RemoveOperator(op)) // offline store 1 rc.SetStoreLimit(1, storelimit.RemovePeer, storelimit.Unlimited) @@ -1064,28 +1059,32 @@ func (s *clusterTestSuite) TestOfflineStoreLimit(c *C) { // can add unlimited remove peer operators on store 1 for i := uint64(1); i <= 30; i++ { op := operator.NewTestOperator(1, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 1}) - c.Assert(oc.AddOperator(op), IsTrue) - c.Assert(oc.RemoveOperator(op), IsTrue) + re.True(oc.AddOperator(op)) + re.True(oc.RemoveOperator(op)) } } -func (s *clusterTestSuite) TestUpgradeStoreLimit(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) +func TestUpgradeStoreLimit(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dashboard.SetCheckInterval(30 * time.Minute) + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) rc.SetStorage(storage.NewStorageWithMemoryBackend()) store := newMetaStore(1, "127.0.1.1:0", "4.0.0", metapb.StoreState_Up, "test/store1") _, err = putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) + re.NoError(err) r := &metapb.Region{ Id: 1, RegionEpoch: &metapb.RegionEpoch{ @@ -1099,58 +1098,60 @@ func (s *clusterTestSuite) TestUpgradeStoreLimit(c *C) { region := core.NewRegionInfo(r, r.Peers[0], core.SetApproximateSize(10)) err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) // restart PD // Here we use an empty storelimit to simulate the upgrade progress. 
opt := rc.GetOpts() scheduleCfg := opt.GetScheduleConfig().Clone() scheduleCfg.StoreLimit = map[uint64]config.StoreLimitConfig{} - c.Assert(leaderServer.GetServer().SetScheduleConfig(*scheduleCfg), IsNil) + re.NoError(leaderServer.GetServer().SetScheduleConfig(*scheduleCfg)) err = leaderServer.Stop() - c.Assert(err, IsNil) + re.NoError(err) err = leaderServer.Run() - c.Assert(err, IsNil) + re.NoError(err) oc := rc.GetOperatorController() // only can add 5 remove peer operators on store 1 for i := uint64(1); i <= 5; i++ { op := operator.NewTestOperator(1, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 1}) - c.Assert(oc.AddOperator(op), IsTrue) - c.Assert(oc.RemoveOperator(op), IsTrue) + re.True(oc.AddOperator(op)) + re.True(oc.RemoveOperator(op)) } op := operator.NewTestOperator(1, &metapb.RegionEpoch{ConfVer: 1, Version: 1}, operator.OpRegion, operator.RemovePeer{FromStore: 1}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + re.False(oc.AddOperator(op)) + re.False(oc.RemoveOperator(op)) } -func (s *clusterTestSuite) TestStaleTermHeartbeat(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestStaleTermHeartbeat(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dashboard.SetCheckInterval(30 * time.Minute) + tc, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer tc.Destroy() - err = tc.RunInitialServers() - c.Assert(err, IsNil) - + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) rc.SetStorage(storage.NewStorageWithMemoryBackend()) peers := make([]*metapb.Peer, 0, len(storeAddrs)) id := leaderServer.GetAllocator() for _, addr := range storeAddrs { storeID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) peerID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) store := newMetaStore(storeID, addr, "3.0.0", metapb.StoreState_Up, getTestDeployPath(storeID)) _, err = putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) + re.NoError(err) peers = append(peers, &metapb.Peer{ Id: peerID, StoreId: storeID, @@ -1176,45 +1177,45 @@ func (s *clusterTestSuite) TestStaleTermHeartbeat(c *C) { region := core.RegionFromHeartbeat(regionReq) err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) // Transfer leader regionReq.Term = 6 regionReq.Leader = peers[1] region = core.RegionFromHeartbeat(regionReq) err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) // issue #3379 regionReq.KeysWritten = uint64(18446744073709551615) // -1 regionReq.BytesWritten = uint64(18446744073709550602) // -1024 region = core.RegionFromHeartbeat(regionReq) - c.Assert(region.GetKeysWritten(), Equals, uint64(0)) - c.Assert(region.GetBytesWritten(), Equals, uint64(0)) + re.Equal(uint64(0), region.GetKeysWritten()) + re.Equal(uint64(0), region.GetBytesWritten()) err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) // Stale heartbeat, update check should fail regionReq.Term = 5 
regionReq.Leader = peers[0] region = core.RegionFromHeartbeat(regionReq) err = rc.HandleRegionHeartbeat(region) - c.Assert(err, NotNil) + re.Error(err) // Allow regions that are created by unsafe recover to send a heartbeat, even though they // are considered "stale" because their conf ver and version are both equal to 1. regionReq.Region.RegionEpoch.ConfVer = 1 region = core.RegionFromHeartbeat(regionReq) err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) } -func (s *clusterTestSuite) putRegionWithLeader(c *C, rc *cluster.RaftCluster, id id.Allocator, storeID uint64) { +func putRegionWithLeader(re *require.Assertions, rc *cluster.RaftCluster, id id.Allocator, storeID uint64) { for i := 0; i < 3; i++ { regionID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) peerID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) region := &metapb.Region{ Id: regionID, Peers: []*metapb.Peer{{Id: peerID, StoreId: storeID}}, @@ -1223,43 +1224,46 @@ func (s *clusterTestSuite) putRegionWithLeader(c *C, rc *cluster.RaftCluster, id } rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0])) } - c.Assert(rc.GetStore(storeID).GetLeaderCount(), Equals, 3) + re.Equal(3, rc.GetStore(storeID).GetLeaderCount()) } -func (s *clusterTestSuite) checkMinResolvedTSFromStorage(c *C, rc *cluster.RaftCluster, expect uint64) { +func checkMinResolvedTSFromStorage(re *require.Assertions, rc *cluster.RaftCluster, expect uint64) { time.Sleep(time.Millisecond * 10) ts2, err := rc.GetStorage().LoadMinResolvedTS() - c.Assert(err, IsNil) - c.Assert(ts2, Equals, expect) + re.NoError(err) + re.Equal(expect, ts2) } -func (s *clusterTestSuite) setMinResolvedTSPersistenceInterval(c *C, rc *cluster.RaftCluster, svr *server.Server, interval time.Duration) { +func setMinResolvedTSPersistenceInterval(re *require.Assertions, rc *cluster.RaftCluster, svr *server.Server, interval time.Duration) { cfg := rc.GetOpts().GetPDServerConfig().Clone() cfg.MinResolvedTSPersistenceInterval = typeutil.NewDuration(interval) err := svr.SetPDServerConfig(*cfg) - c.Assert(err, IsNil) + re.NoError(err) time.Sleep(time.Millisecond + interval) } -func (s *clusterTestSuite) TestMinResolvedTS(c *C) { +func TestMinResolvedTS(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() cluster.DefaultMinResolvedTSPersistenceInterval = time.Millisecond - tc, err := tests.NewTestCluster(s.ctx, 1) + tc, err := tests.NewTestCluster(ctx, 1) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) id := leaderServer.GetAllocator() - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) svr := leaderServer.GetServer() - addStoreAndCheckMinResolvedTS := func(c *C, isTiflash bool, minResolvedTS, expect uint64) uint64 { + addStoreAndCheckMinResolvedTS := func(re *require.Assertions, isTiflash bool, minResolvedTS, expect uint64) uint64 { storeID, err := id.Alloc() - c.Assert(err, IsNil) + re.NoError(err) store := &metapb.Store{ Id: storeID, Version: "v6.0.0", @@ -1269,95 +1273,104 @@ func (s *clusterTestSuite) TestMinResolvedTS(c 
*C) { store.Labels = []*metapb.StoreLabel{{Key: "engine", Value: "tiflash"}} } _, err = putStore(grpcPDClient, clusterID, store) - c.Assert(err, IsNil) + re.NoError(err) req := &pdpb.ReportMinResolvedTsRequest{ Header: testutil.NewRequestHeader(clusterID), StoreId: storeID, MinResolvedTs: minResolvedTS, } _, err = grpcPDClient.ReportMinResolvedTS(context.Background(), req) - c.Assert(err, IsNil) + re.NoError(err) ts := rc.GetMinResolvedTS() - c.Assert(ts, Equals, expect) + re.Equal(expect, ts) return storeID } // case1: cluster is not initialized // min resolved ts should not be available status, err := rc.LoadClusterStatus() - c.Assert(err, IsNil) - c.Assert(status.IsInitialized, IsFalse) + re.NoError(err) + re.False(status.IsInitialized) store1TS := uint64(233) - store1 := addStoreAndCheckMinResolvedTS(c, false /* not tiflash */, store1TS, math.MaxUint64) + store1 := addStoreAndCheckMinResolvedTS(re, false /* not tiflash */, store1TS, math.MaxUint64) // case2: add leader peer to store1 but do not run the job // min resolved ts should be zero - s.putRegionWithLeader(c, rc, id, store1) + putRegionWithLeader(re, rc, id, store1) + time.Sleep(time.Millisecond) ts := rc.GetMinResolvedTS() - c.Assert(ts, Equals, uint64(0)) + re.Equal(uint64(0), ts) // case3: add leader peer to store1 and run job // min resolved ts should be store1TS - s.setMinResolvedTSPersistenceInterval(c, rc, svr, time.Millisecond) + setMinResolvedTSPersistenceInterval(re, rc, svr, time.Millisecond) + time.Sleep(time.Millisecond) ts = rc.GetMinResolvedTS() - c.Assert(ts, Equals, store1TS) - s.checkMinResolvedTSFromStorage(c, rc, ts) + re.Equal(store1TS, ts) + checkMinResolvedTSFromStorage(re, rc, ts) // case4: add tiflash store // min resolved ts should not change - addStoreAndCheckMinResolvedTS(c, true /* is tiflash */, 0, store1TS) + addStoreAndCheckMinResolvedTS(re, true /* is tiflash */, 0, store1TS) // case5: add new store with larger min resolved ts // min resolved ts should not change store3TS := store1TS + 10 - store3 := addStoreAndCheckMinResolvedTS(c, false /* not tiflash */, store3TS, store1TS) - s.putRegionWithLeader(c, rc, id, store3) + store3 := addStoreAndCheckMinResolvedTS(re, false /* not tiflash */, store3TS, store1TS) + putRegionWithLeader(re, rc, id, store3) // case6: set store1 to tombstone // min resolved ts should change to store 3 - resetStoreState(c, rc, store1, metapb.StoreState_Tombstone) + resetStoreState(re, rc, store1, metapb.StoreState_Tombstone) + time.Sleep(time.Millisecond) ts = rc.GetMinResolvedTS() - c.Assert(ts, Equals, store3TS) + re.Equal(store3TS, ts) // case7: add a store with leader peer but do not report min resolved ts // min resolved ts should not change - s.checkMinResolvedTSFromStorage(c, rc, store3TS) - store4 := addStoreAndCheckMinResolvedTS(c, false /* not tiflash */, 0, store3TS) - s.putRegionWithLeader(c, rc, id, store4) + checkMinResolvedTSFromStorage(re, rc, store3TS) + store4 := addStoreAndCheckMinResolvedTS(re, false /* not tiflash */, 0, store3TS) + putRegionWithLeader(re, rc, id, store4) + time.Sleep(time.Millisecond) ts = rc.GetMinResolvedTS() - c.Assert(ts, Equals, store3TS) - s.checkMinResolvedTSFromStorage(c, rc, store3TS) - resetStoreState(c, rc, store4, metapb.StoreState_Tombstone) + re.Equal(store3TS, ts) + checkMinResolvedTSFromStorage(re, rc, store3TS) + resetStoreState(re, rc, store4, metapb.StoreState_Tombstone) // case8: set min resolved ts persist interval to zero // although min resolved ts increases, it should not be persisted until the job runs.
store5TS := store3TS + 10 - s.setMinResolvedTSPersistenceInterval(c, rc, svr, 0) - store5 := addStoreAndCheckMinResolvedTS(c, false /* not tiflash */, store5TS, store3TS) - resetStoreState(c, rc, store3, metapb.StoreState_Tombstone) - s.putRegionWithLeader(c, rc, id, store5) + setMinResolvedTSPersistenceInterval(re, rc, svr, 0) + store5 := addStoreAndCheckMinResolvedTS(re, false /* not tiflash */, store5TS, store3TS) + resetStoreState(re, rc, store3, metapb.StoreState_Tombstone) + putRegionWithLeader(re, rc, id, store5) + time.Sleep(time.Millisecond) ts = rc.GetMinResolvedTS() - c.Assert(ts, Equals, store3TS) - s.setMinResolvedTSPersistenceInterval(c, rc, svr, time.Millisecond) + re.Equal(store3TS, ts) + setMinResolvedTSPersistenceInterval(re, rc, svr, time.Millisecond) + time.Sleep(time.Millisecond) ts = rc.GetMinResolvedTS() - c.Assert(ts, Equals, store5TS) + re.Equal(store5TS, ts) } // See https://github.com/tikv/pd/issues/4941 -func (s *clusterTestSuite) TestTransferLeaderBack(c *C) { - tc, err := tests.NewTestCluster(s.ctx, 2) +func TestTransferLeaderBack(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, err := tests.NewTestCluster(ctx, 2) defer tc.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = tc.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) svr := leaderServer.GetServer() - rc := cluster.NewRaftCluster(s.ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) + rc := cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) rc.InitCluster(svr.GetAllocator(), svr.GetPersistOptions(), svr.GetStorage(), svr.GetBasicCluster()) storage := rc.GetStorage() meta := &metapb.Cluster{Id: 123} - c.Assert(storage.SaveMeta(meta), IsNil) + re.NoError(storage.SaveMeta(meta)) n := 4 stores := make([]*metapb.Store, 0, n) for i := 1; i <= n; i++ { @@ -1366,14 +1379,14 @@ func (s *clusterTestSuite) TestTransferLeaderBack(c *C) { } for _, store := range stores { - c.Assert(storage.SaveStore(store), IsNil) + re.NoError(storage.SaveStore(store)) } rc, err = rc.LoadClusterInfo() - c.Assert(err, IsNil) - c.Assert(rc, NotNil) + re.NoError(err) + re.NotNil(rc) // offline a store - c.Assert(rc.RemoveStore(1, false), IsNil) - c.Assert(rc.GetStore(1).GetState(), Equals, metapb.StoreState_Offline) + re.NoError(rc.RemoveStore(1, false)) + re.Equal(metapb.StoreState_Offline, rc.GetStore(1).GetState()) // transfer PD leader to another PD tc.ResignLeader() @@ -1381,11 +1394,12 @@ func (s *clusterTestSuite) TestTransferLeaderBack(c *C) { leaderServer = tc.GetServer(tc.GetLeader()) svr1 := leaderServer.GetServer() rc1 := svr1.GetRaftCluster() - c.Assert(err, IsNil) - c.Assert(rc1, NotNil) + re.NoError(err) + re.NotNil(rc1) + // tombstone a store, and remove its record - c.Assert(rc1.BuryStore(1, false), IsNil) - c.Assert(rc1.RemoveTombStoneRecords(), IsNil) + re.NoError(rc1.BuryStore(1, false)) + re.NoError(rc1.RemoveTombStoneRecords()) // transfer PD leader back to the previous PD tc.ResignLeader() @@ -1393,9 +1407,9 @@ func (s *clusterTestSuite) TestTransferLeaderBack(c *C) { leaderServer = tc.GetServer(tc.GetLeader()) svr = leaderServer.GetServer() rc = svr.GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) // check store count - c.Assert(rc.GetMetaCluster(), DeepEquals, meta) - c.Assert(rc.GetStoreCount(), Equals, 3) + re.Equal(meta, rc.GetMetaCluster()) + re.Equal(3, 
rc.GetStoreCount()) } diff --git a/tests/server/cluster/cluster_work_test.go b/tests/server/cluster/cluster_work_test.go index b3d9fdcf9e0..5dee7da02cd 100644 --- a/tests/server/cluster/cluster_work_test.go +++ b/tests/server/cluster/cluster_work_test.go @@ -17,44 +17,33 @@ package cluster_test import ( "context" "sort" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/core" "github.com/tikv/pd/tests" ) -var _ = Suite(&clusterWorkerTestSuite{}) - -type clusterWorkerTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clusterWorkerTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *clusterWorkerTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *clusterWorkerTestSuite) TestValidRequestRegion(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func TestValidRequestRegion(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() r1 := core.NewRegionInfo(&metapb.Region{ @@ -71,31 +60,34 @@ func (s *clusterWorkerTestSuite) TestValidRequestRegion(c *C) { StoreId: 1, }) err = rc.HandleRegionHeartbeat(r1) - c.Assert(err, IsNil) + re.NoError(err) r2 := &metapb.Region{Id: 2, StartKey: []byte("a"), EndKey: []byte("b")} - c.Assert(rc.ValidRequestRegion(r2), NotNil) + re.Error(rc.ValidRequestRegion(r2)) r3 := &metapb.Region{Id: 1, StartKey: []byte(""), EndKey: []byte("a"), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}} - c.Assert(rc.ValidRequestRegion(r3), NotNil) + re.Error(rc.ValidRequestRegion(r3)) r4 := &metapb.Region{Id: 1, StartKey: []byte(""), EndKey: []byte("a"), RegionEpoch: &metapb.RegionEpoch{ConfVer: 2, Version: 1}} - c.Assert(rc.ValidRequestRegion(r4), NotNil) + re.Error(rc.ValidRequestRegion(r4)) r5 := &metapb.Region{Id: 1, StartKey: []byte(""), EndKey: []byte("a"), RegionEpoch: &metapb.RegionEpoch{ConfVer: 2, Version: 2}} - c.Assert(rc.ValidRequestRegion(r5), IsNil) + re.NoError(rc.ValidRequestRegion(r5)) rc.Stop() } -func (s *clusterWorkerTestSuite) TestAskSplit(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func TestAskSplit(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, 
clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() opt := rc.GetOpts() opt.SetSplitMergeInterval(time.Hour) @@ -109,7 +101,7 @@ func (s *clusterWorkerTestSuite) TestAskSplit(c *C) { } _, err = rc.HandleAskSplit(req) - c.Assert(err, IsNil) + re.NoError(err) req1 := &pdpb.AskBatchSplitRequest{ Header: &pdpb.RequestHeader{ @@ -120,27 +112,30 @@ func (s *clusterWorkerTestSuite) TestAskSplit(c *C) { } _, err = rc.HandleAskBatchSplit(req1) - c.Assert(err, IsNil) + re.NoError(err) // test region id whether valid opt.SetSplitMergeInterval(time.Duration(0)) mergeChecker := rc.GetMergeChecker() mergeChecker.Check(regions[0]) - c.Assert(err, IsNil) + re.NoError(err) } -func (s *clusterWorkerTestSuite) TestSuspectRegions(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func TestSuspectRegions(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() - bootstrapCluster(c, clusterID, grpcPDClient) + bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() opt := rc.GetOpts() opt.SetSplitMergeInterval(time.Hour) @@ -154,10 +149,10 @@ func (s *clusterWorkerTestSuite) TestSuspectRegions(c *C) { SplitCount: 2, } res, err := rc.HandleAskBatchSplit(req) - c.Assert(err, IsNil) + re.NoError(err) ids := []uint64{regions[0].GetMeta().GetId(), res.Ids[0].NewRegionId, res.Ids[1].NewRegionId} sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) suspects := rc.GetSuspectRegions() sort.Slice(suspects, func(i, j int) bool { return suspects[i] < suspects[j] }) - c.Assert(suspects, DeepEquals, ids) + re.Equal(ids, suspects) } From 32bccb7ca1ce6ca5ffaf81e419bc341f4200fee9 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Mon, 20 Jun 2022 16:06:37 +0800 Subject: [PATCH 57/82] checker: migrate test framework to testify (#5174) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- pkg/testutil/operator_check.go | 60 ++ .../checker/joint_state_checker_test.go | 61 +- .../schedule/checker/learner_checker_test.go | 44 +- server/schedule/checker/merge_checker_test.go | 445 +++++------ .../checker/priority_inspector_test.go | 49 +- .../schedule/checker/replica_checker_test.go | 251 +++--- server/schedule/checker/rule_checker_test.go | 741 +++++++++--------- server/schedule/checker/split_checker_test.go | 66 +- 8 files changed, 833 insertions(+), 884 deletions(-) diff --git a/pkg/testutil/operator_check.go b/pkg/testutil/operator_check.go index 90779b7059a..1df641e7e0a 100644 --- a/pkg/testutil/operator_check.go +++ b/pkg/testutil/operator_check.go @@ -16,6 +16,7 @@ package testutil import ( "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/schedule/operator" ) @@ -141,3 +142,62 @@ func CheckTransferPeerWithLeaderTransferFrom(c *check.C, op *operator.Operator, kind |= operator.OpRegion | operator.OpLeader c.Assert(op.Kind()&kind, check.Equals, kind) } + +// CheckAddPeerWithTestify checks if the operator is to add peer on specified store. 
+func CheckAddPeerWithTestify(re *require.Assertions, op *operator.Operator, kind operator.OpKind, storeID uint64) { + re.NotNil(op) + re.Equal(2, op.Len()) + re.Equal(storeID, op.Step(0).(operator.AddLearner).ToStore) + re.IsType(operator.PromoteLearner{}, op.Step(1)) + kind |= operator.OpRegion + re.Equal(kind, op.Kind()&kind) +} + +// CheckRemovePeerWithTestify checks if the operator is to remove peer on specified store. +func CheckRemovePeerWithTestify(re *require.Assertions, op *operator.Operator, storeID uint64) { + re.NotNil(op) + if op.Len() == 1 { + re.Equal(storeID, op.Step(0).(operator.RemovePeer).FromStore) + } else { + re.Equal(2, op.Len()) + re.Equal(storeID, op.Step(0).(operator.TransferLeader).FromStore) + re.Equal(storeID, op.Step(1).(operator.RemovePeer).FromStore) + } +} + +// CheckTransferPeerWithTestify checks if the operator is to transfer peer between the specified source and target stores. +func CheckTransferPeerWithTestify(re *require.Assertions, op *operator.Operator, kind operator.OpKind, sourceID, targetID uint64) { + re.NotNil(op) + + steps, _ := trimTransferLeaders(op) + re.Len(steps, 3) + re.Equal(targetID, steps[0].(operator.AddLearner).ToStore) + re.IsType(operator.PromoteLearner{}, steps[1]) + re.Equal(sourceID, steps[2].(operator.RemovePeer).FromStore) + kind |= operator.OpRegion + re.Equal(kind, op.Kind()&kind) +} + +// CheckSteps checks if the operator matches the given steps. +func CheckSteps(re *require.Assertions, op *operator.Operator, steps []operator.OpStep) { + re.NotEqual(0, op.Kind()&operator.OpMerge) + re.NotNil(steps) + re.Len(steps, op.Len()) + for i := range steps { + switch op.Step(i).(type) { + case operator.AddLearner: + re.Equal(steps[i].(operator.AddLearner).ToStore, op.Step(i).(operator.AddLearner).ToStore) + case operator.PromoteLearner: + re.Equal(steps[i].(operator.PromoteLearner).ToStore, op.Step(i).(operator.PromoteLearner).ToStore) + case operator.TransferLeader: + re.Equal(steps[i].(operator.TransferLeader).FromStore, op.Step(i).(operator.TransferLeader).FromStore) + re.Equal(steps[i].(operator.TransferLeader).ToStore, op.Step(i).(operator.TransferLeader).ToStore) + case operator.RemovePeer: + re.Equal(steps[i].(operator.RemovePeer).FromStore, op.Step(i).(operator.RemovePeer).FromStore) + case operator.MergeRegion: + re.Equal(steps[i].(operator.MergeRegion).IsPassive, op.Step(i).(operator.MergeRegion).IsPassive) + default: + re.FailNow("unknown operator step type") + } + } +} diff --git a/server/schedule/checker/joint_state_checker_test.go b/server/schedule/checker/joint_state_checker_test.go index 5d759c51e67..b350de469c4 100644 --- a/server/schedule/checker/joint_state_checker_test.go +++ b/server/schedule/checker/joint_state_checker_test.go @@ -16,42 +16,25 @@ package checker import ( "context" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule/operator" ) -var _ = Suite(&testJointStateCheckerSuite{}) - -type testJointStateCheckerSuite struct { - cluster *mockcluster.Cluster - jsc *JointStateChecker - ctx context.Context - cancel context.CancelFunc -} - -func (s *testJointStateCheckerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testJointStateCheckerSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testJointStateCheckerSuite) SetUpTest(c *C) { - s.cluster = mockcluster.NewCluster(s.ctx, config.NewTestOptions()) - s.jsc = NewJointStateChecker(s.cluster) +func TestLeaveJointState(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) + jsc := NewJointStateChecker(cluster) for id := uint64(1); id <= 10; id++ { - s.cluster.PutStoreWithLabels(id) + cluster.PutStoreWithLabels(id) } -} - -func (s *testJointStateCheckerSuite) TestLeaveJointState(c *C) { - jsc := s.jsc type testCase struct { Peers []*metapb.Peer // first is leader OpSteps []operator.OpStep @@ -131,38 +114,38 @@ func (s *testJointStateCheckerSuite) TestLeaveJointState(c *C) { for _, tc := range cases { region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: tc.Peers}, tc.Peers[0]) op := jsc.Check(region) - s.checkSteps(c, op, tc.OpSteps) + checkSteps(re, op, tc.OpSteps) } } -func (s *testJointStateCheckerSuite) checkSteps(c *C, op *operator.Operator, steps []operator.OpStep) { +func checkSteps(re *require.Assertions, op *operator.Operator, steps []operator.OpStep) { if len(steps) == 0 { - c.Assert(op, IsNil) + re.Nil(op) return } - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "leave-joint-state") + re.NotNil(op) + re.Equal("leave-joint-state", op.Desc()) - c.Assert(op.Len(), Equals, len(steps)) + re.Len(steps, op.Len()) for i := range steps { switch obtain := op.Step(i).(type) { case operator.ChangePeerV2Leave: expect := steps[i].(operator.ChangePeerV2Leave) - c.Assert(len(obtain.PromoteLearners), Equals, len(expect.PromoteLearners)) - c.Assert(len(obtain.DemoteVoters), Equals, len(expect.DemoteVoters)) + re.Equal(len(expect.PromoteLearners), len(obtain.PromoteLearners)) + re.Equal(len(expect.DemoteVoters), len(obtain.DemoteVoters)) for j, p := range expect.PromoteLearners { - c.Assert(expect.PromoteLearners[j].ToStore, Equals, p.ToStore) + re.Equal(p.ToStore, obtain.PromoteLearners[j].ToStore) } for j, d := range expect.DemoteVoters { - c.Assert(obtain.DemoteVoters[j].ToStore, Equals, d.ToStore) + re.Equal(d.ToStore, obtain.DemoteVoters[j].ToStore) } case operator.TransferLeader: expect := steps[i].(operator.TransferLeader) - c.Assert(obtain.FromStore, Equals, expect.FromStore) - c.Assert(obtain.ToStore, Equals, expect.ToStore) + re.Equal(expect.FromStore, obtain.FromStore) + re.Equal(expect.ToStore, obtain.ToStore) default: - c.Fatal("unknown operator step type") + re.FailNow("unknown operator step type") } } } diff --git a/server/schedule/checker/learner_checker_test.go b/server/schedule/checker/learner_checker_test.go index 1a403e79043..afe4b920313 100644 --- a/server/schedule/checker/learner_checker_test.go +++ b/server/schedule/checker/learner_checker_test.go @@ -16,9 +16,10 @@ package checker import ( "context" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" @@ -26,31 +27,16 @@ import ( "github.com/tikv/pd/server/versioninfo" ) -var _ = Suite(&testLearnerCheckerSuite{}) - -type testLearnerCheckerSuite struct { - cluster *mockcluster.Cluster - lc *LearnerChecker - ctx context.Context - cancel context.CancelFunc -} - -func (s *testLearnerCheckerSuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - s.cluster = mockcluster.NewCluster(s.ctx, config.NewTestOptions()) - s.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) - s.lc = NewLearnerChecker(s.cluster) +func TestPromoteLearner(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) + cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) + lc := NewLearnerChecker(cluster) for id := uint64(1); id <= 10; id++ { - s.cluster.PutStoreWithLabels(id) + cluster.PutStoreWithLabels(id) } -} - -func (s *testLearnerCheckerSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testLearnerCheckerSuite) TestPromoteLearner(c *C) { - lc := s.lc region := core.NewRegionInfo( &metapb.Region{ @@ -62,12 +48,12 @@ func (s *testLearnerCheckerSuite) TestPromoteLearner(c *C) { }, }, &metapb.Peer{Id: 101, StoreId: 1}) op := lc.Check(region) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "promote-learner") - c.Assert(op.Step(0), FitsTypeOf, operator.PromoteLearner{}) - c.Assert(op.Step(0).(operator.PromoteLearner).ToStore, Equals, uint64(3)) + re.NotNil(op) + re.Equal("promote-learner", op.Desc()) + re.IsType(operator.PromoteLearner{}, op.Step(0)) + re.Equal(uint64(3), op.Step(0).(operator.PromoteLearner).ToStore) region = region.Clone(core.WithPendingPeers([]*metapb.Peer{region.GetPeer(103)})) op = lc.Check(region) - c.Assert(op, IsNil) + re.Nil(op) } diff --git a/server/schedule/checker/merge_checker_test.go b/server/schedule/checker/merge_checker_test.go index 21c6eeec410..b0f5c8ae270 100644 --- a/server/schedule/checker/merge_checker_test.go +++ b/server/schedule/checker/merge_checker_test.go @@ -20,8 +20,8 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" @@ -36,17 +36,12 @@ import ( "go.uber.org/goleak" ) -func TestMergeChecker(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) 
} -var _ = Suite(&testMergeCheckerSuite{}) - -type testMergeCheckerSuite struct { +type mergeCheckerTestSuite struct { + suite.Suite ctx context.Context cancel context.CancelFunc cluster *mockcluster.Cluster @@ -54,145 +49,146 @@ type testMergeCheckerSuite struct { regions []*core.RegionInfo } -func (s *testMergeCheckerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) +func TestMergeCheckerTestSuite(t *testing.T) { + suite.Run(t, new(mergeCheckerTestSuite)) } -func (s *testMergeCheckerSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testMergeCheckerSuite) SetUpTest(c *C) { +func (suite *mergeCheckerTestSuite) SetupTest() { cfg := config.NewTestOptions() - s.cluster = mockcluster.NewCluster(s.ctx, cfg) - s.cluster.SetMaxMergeRegionSize(2) - s.cluster.SetMaxMergeRegionKeys(2) - s.cluster.SetLabelPropertyConfig(config.LabelPropertyConfig{ + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.cluster = mockcluster.NewCluster(suite.ctx, cfg) + suite.cluster.SetMaxMergeRegionSize(2) + suite.cluster.SetMaxMergeRegionKeys(2) + suite.cluster.SetLabelPropertyConfig(config.LabelPropertyConfig{ config.RejectLeader: {{Key: "reject", Value: "leader"}}, }) - s.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) + suite.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) stores := map[uint64][]string{ 1: {}, 2: {}, 3: {}, 4: {}, 5: {}, 6: {}, 7: {"reject", "leader"}, 8: {"reject", "leader"}, } for storeID, labels := range stores { - s.cluster.PutStoreWithLabels(storeID, labels...) + suite.cluster.PutStoreWithLabels(storeID, labels...) } - s.regions = []*core.RegionInfo{ + suite.regions = []*core.RegionInfo{ newRegionInfo(1, "", "a", 1, 1, []uint64{101, 1}, []uint64{101, 1}, []uint64{102, 2}), newRegionInfo(2, "a", "t", 200, 200, []uint64{104, 4}, []uint64{103, 1}, []uint64{104, 4}, []uint64{105, 5}), newRegionInfo(3, "t", "x", 1, 1, []uint64{108, 6}, []uint64{106, 2}, []uint64{107, 5}, []uint64{108, 6}), newRegionInfo(4, "x", "", 1, 1, []uint64{109, 4}, []uint64{109, 4}), } - for _, region := range s.regions { - s.cluster.PutRegion(region) + for _, region := range suite.regions { + suite.cluster.PutRegion(region) } - s.mc = NewMergeChecker(s.ctx, s.cluster) + suite.mc = NewMergeChecker(suite.ctx, suite.cluster) } -func (s *testMergeCheckerSuite) TestBasic(c *C) { - s.cluster.SetSplitMergeInterval(0) +func (suite *mergeCheckerTestSuite) TearDownTest() { + suite.cancel() +} + +func (suite *mergeCheckerTestSuite) TestBasic() { + suite.cluster.SetSplitMergeInterval(0) // should with same peer count - ops := s.mc.Check(s.regions[0]) - c.Assert(ops, IsNil) + ops := suite.mc.Check(suite.regions[0]) + suite.Nil(ops) // The size should be small enough. - ops = s.mc.Check(s.regions[1]) - c.Assert(ops, IsNil) + ops = suite.mc.Check(suite.regions[1]) + suite.Nil(ops) // target region size is too large - s.cluster.PutRegion(s.regions[1].Clone(core.SetApproximateSize(600))) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) + suite.cluster.PutRegion(suite.regions[1].Clone(core.SetApproximateSize(600))) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) // it can merge if the max region size of the store is greater than the target region size. 
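Where a test file keeps shared fixtures, as the merge-checker tests here do, the commit moves to testify's suite.Suite rather than plain test functions. The following is a hypothetical, self-contained sketch of that shape with illustrative names: SetupTest and TearDownTest run around every TestXxx method, the embedded suite.Suite supplies the assertion helpers, and one ordinary TestXxx function wires the suite into go test in place of gocheck's TestingT hook.

    package example_test

    import (
    	"testing"

    	"github.com/stretchr/testify/suite"
    )

    // exampleCheckerTestSuite is an illustrative fixture mirroring the pattern above.
    type exampleCheckerTestSuite struct {
    	suite.Suite
    	regions []int
    }

    // SetupTest runs before each TestXxx method (gocheck's SetUpTest equivalent).
    func (s *exampleCheckerTestSuite) SetupTest() {
    	s.regions = []int{1, 2, 3}
    }

    // TearDownTest runs after each TestXxx method (gocheck's TearDownTest equivalent).
    func (s *exampleCheckerTestSuite) TearDownTest() {
    	s.regions = nil
    }

    func (s *exampleCheckerTestSuite) TestRegionCount() {
    	// The embedded suite.Suite exposes s.Equal, s.Len, s.Nil, ... in place of c.Assert.
    	s.Len(s.regions, 3)
    	// s.Require() yields a *require.Assertions for helpers that take one explicitly.
    	s.Require().NotNil(s.regions)
    }

    // TestExampleCheckerTestSuite registers the suite with the standard testing package.
    func TestExampleCheckerTestSuite(t *testing.T) {
    	suite.Run(t, new(exampleCheckerTestSuite))
    }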
- config := s.cluster.GetStoreConfig() + config := suite.cluster.GetStoreConfig() config.RegionMaxSize = "10Gib" - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) config.RegionMaxSize = "144Mib" - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) // change the size back - s.cluster.PutRegion(s.regions[1].Clone(core.SetApproximateSize(200))) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) + suite.cluster.PutRegion(suite.regions[1].Clone(core.SetApproximateSize(200))) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) // Check merge with previous region. - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[1].GetID()) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[1].GetID(), ops[1].RegionID()) // Test the peer store check. - store := s.cluster.GetStore(1) - c.Assert(store, NotNil) + store := suite.cluster.GetStore(1) + suite.NotNil(store) // Test the peer store is deleted. - s.cluster.DeleteStore(store) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) + suite.cluster.DeleteStore(store) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) // Test the store is normal. - s.cluster.PutStore(store) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[1].GetID()) + suite.cluster.PutStore(store) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[1].GetID(), ops[1].RegionID()) // Test the store is offline. - s.cluster.SetStoreOffline(store.GetID()) - ops = s.mc.Check(s.regions[2]) + suite.cluster.SetStoreOffline(store.GetID()) + ops = suite.mc.Check(suite.regions[2]) // Only target region have a peer on the offline store, // so it's not ok to merge. - c.Assert(ops, IsNil) + suite.Nil(ops) // Test the store is up. - s.cluster.SetStoreUp(store.GetID()) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[1].GetID()) - store = s.cluster.GetStore(5) - c.Assert(store, NotNil) + suite.cluster.SetStoreUp(store.GetID()) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[1].GetID(), ops[1].RegionID()) + store = suite.cluster.GetStore(5) + suite.NotNil(store) // Test the peer store is deleted. - s.cluster.DeleteStore(store) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) + suite.cluster.DeleteStore(store) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) // Test the store is normal. - s.cluster.PutStore(store) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[1].GetID()) + suite.cluster.PutStore(store) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[1].GetID(), ops[1].RegionID()) // Test the store is offline. 
- s.cluster.SetStoreOffline(store.GetID()) - ops = s.mc.Check(s.regions[2]) + suite.cluster.SetStoreOffline(store.GetID()) + ops = suite.mc.Check(suite.regions[2]) // Both regions have peers on the offline store, // so it's ok to merge. - c.Assert(ops, NotNil) - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[1].GetID()) + suite.NotNil(ops) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[1].GetID(), ops[1].RegionID()) // Test the store is up. - s.cluster.SetStoreUp(store.GetID()) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[1].GetID()) + suite.cluster.SetStoreUp(store.GetID()) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[1].GetID(), ops[1].RegionID()) // Enable one way merge - s.cluster.SetEnableOneWayMerge(true) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) - s.cluster.SetEnableOneWayMerge(false) + suite.cluster.SetEnableOneWayMerge(true) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) + suite.cluster.SetEnableOneWayMerge(false) // Make up peers for next region. - s.regions[3] = s.regions[3].Clone(core.WithAddPeer(&metapb.Peer{Id: 110, StoreId: 1}), core.WithAddPeer(&metapb.Peer{Id: 111, StoreId: 2})) - s.cluster.PutRegion(s.regions[3]) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) + suite.regions[3] = suite.regions[3].Clone(core.WithAddPeer(&metapb.Peer{Id: 110, StoreId: 1}), core.WithAddPeer(&metapb.Peer{Id: 111, StoreId: 2})) + suite.cluster.PutRegion(suite.regions[3]) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) // Now it merges to next region. - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[3].GetID()) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[3].GetID(), ops[1].RegionID()) // merge cannot across rule key. - s.cluster.SetEnablePlacementRules(true) - s.cluster.RuleManager.SetRule(&placement.Rule{ + suite.cluster.SetEnablePlacementRules(true) + suite.cluster.RuleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "test", Index: 1, @@ -203,83 +199,60 @@ func (s *testMergeCheckerSuite) TestBasic(c *C) { Count: 3, }) // region 2 can only merge with previous region now. - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - c.Assert(ops[0].RegionID(), Equals, s.regions[2].GetID()) - c.Assert(ops[1].RegionID(), Equals, s.regions[1].GetID()) - s.cluster.RuleManager.DeleteRule("pd", "test") + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + suite.Equal(suite.regions[2].GetID(), ops[0].RegionID()) + suite.Equal(suite.regions[1].GetID(), ops[1].RegionID()) + suite.cluster.RuleManager.DeleteRule("pd", "test") // check 'merge_option' label - s.cluster.GetRegionLabeler().SetLabelRule(&labeler.LabelRule{ + suite.cluster.GetRegionLabeler().SetLabelRule(&labeler.LabelRule{ ID: "test", Labels: []labeler.RegionLabel{{Key: mergeOptionLabel, Value: mergeOptionValueDeny}}, RuleType: labeler.KeyRange, Data: makeKeyRanges("", "74"), }) - ops = s.mc.Check(s.regions[0]) - c.Assert(ops, HasLen, 0) - ops = s.mc.Check(s.regions[1]) - c.Assert(ops, HasLen, 0) + ops = suite.mc.Check(suite.regions[0]) + suite.Len(ops, 0) + ops = suite.mc.Check(suite.regions[1]) + suite.Len(ops, 0) // Skip recently split regions. 
- s.cluster.SetSplitMergeInterval(time.Hour) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) - - s.mc.startTime = time.Now().Add(-2 * time.Hour) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - ops = s.mc.Check(s.regions[3]) - c.Assert(ops, NotNil) - - s.mc.RecordRegionSplit([]uint64{s.regions[2].GetID()}) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) - ops = s.mc.Check(s.regions[3]) - c.Assert(ops, IsNil) - - s.cluster.SetSplitMergeInterval(500 * time.Millisecond) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, IsNil) - ops = s.mc.Check(s.regions[3]) - c.Assert(ops, IsNil) + suite.cluster.SetSplitMergeInterval(time.Hour) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) + + suite.mc.startTime = time.Now().Add(-2 * time.Hour) + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + ops = suite.mc.Check(suite.regions[3]) + suite.NotNil(ops) + + suite.mc.RecordRegionSplit([]uint64{suite.regions[2].GetID()}) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) + ops = suite.mc.Check(suite.regions[3]) + suite.Nil(ops) + + suite.cluster.SetSplitMergeInterval(500 * time.Millisecond) + ops = suite.mc.Check(suite.regions[2]) + suite.Nil(ops) + ops = suite.mc.Check(suite.regions[3]) + suite.Nil(ops) time.Sleep(500 * time.Millisecond) - ops = s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - ops = s.mc.Check(s.regions[3]) - c.Assert(ops, NotNil) -} - -func (s *testMergeCheckerSuite) checkSteps(c *C, op *operator.Operator, steps []operator.OpStep) { - c.Assert(op.Kind()&operator.OpMerge, Not(Equals), 0) - c.Assert(steps, NotNil) - c.Assert(op.Len(), Equals, len(steps)) - for i := range steps { - switch op.Step(i).(type) { - case operator.AddLearner: - c.Assert(op.Step(i).(operator.AddLearner).ToStore, Equals, steps[i].(operator.AddLearner).ToStore) - case operator.PromoteLearner: - c.Assert(op.Step(i).(operator.PromoteLearner).ToStore, Equals, steps[i].(operator.PromoteLearner).ToStore) - case operator.TransferLeader: - c.Assert(op.Step(i).(operator.TransferLeader).FromStore, Equals, steps[i].(operator.TransferLeader).FromStore) - c.Assert(op.Step(i).(operator.TransferLeader).ToStore, Equals, steps[i].(operator.TransferLeader).ToStore) - case operator.RemovePeer: - c.Assert(op.Step(i).(operator.RemovePeer).FromStore, Equals, steps[i].(operator.RemovePeer).FromStore) - case operator.MergeRegion: - c.Assert(op.Step(i).(operator.MergeRegion).IsPassive, Equals, steps[i].(operator.MergeRegion).IsPassive) - default: - c.Fatal("unknown operator step type") - } - } + ops = suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + ops = suite.mc.Check(suite.regions[3]) + suite.NotNil(ops) } -func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { - s.cluster.SetSplitMergeInterval(0) +func (suite *mergeCheckerTestSuite) TestMatchPeers() { + suite.cluster.SetSplitMergeInterval(0) // partial store overlap not including leader - ops := s.mc.Check(s.regions[2]) - c.Assert(ops, NotNil) - s.checkSteps(c, ops[0], []operator.OpStep{ + ops := suite.mc.Check(suite.regions[2]) + suite.NotNil(ops) + testutil.CheckSteps(suite.Require(), ops[0], []operator.OpStep{ operator.AddLearner{ToStore: 1}, operator.PromoteLearner{ToStore: 1}, operator.RemovePeer{FromStore: 2}, @@ -288,21 +261,21 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { operator.TransferLeader{FromStore: 6, ToStore: 5}, operator.RemovePeer{FromStore: 6}, operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: 
suite.regions[1].GetMeta(), IsPassive: false, }, }) - s.checkSteps(c, ops[1], []operator.OpStep{ + testutil.CheckSteps(suite.Require(), ops[1], []operator.OpStep{ operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: true, }, }) // partial store overlap including leader - newRegion := s.regions[2].Clone( + newRegion := suite.regions[2].Clone( core.SetPeers([]*metapb.Peer{ {Id: 106, StoreId: 1}, {Id: 107, StoreId: 5}, @@ -310,59 +283,59 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { }), core.WithLeader(&metapb.Peer{Id: 106, StoreId: 1}), ) - s.regions[2] = newRegion - s.cluster.PutRegion(s.regions[2]) - ops = s.mc.Check(s.regions[2]) - s.checkSteps(c, ops[0], []operator.OpStep{ + suite.regions[2] = newRegion + suite.cluster.PutRegion(suite.regions[2]) + ops = suite.mc.Check(suite.regions[2]) + testutil.CheckSteps(suite.Require(), ops[0], []operator.OpStep{ operator.AddLearner{ToStore: 4}, operator.PromoteLearner{ToStore: 4}, operator.RemovePeer{FromStore: 6}, operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: false, }, }) - s.checkSteps(c, ops[1], []operator.OpStep{ + testutil.CheckSteps(suite.Require(), ops[1], []operator.OpStep{ operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: true, }, }) // all stores overlap - s.regions[2] = s.regions[2].Clone(core.SetPeers([]*metapb.Peer{ + suite.regions[2] = suite.regions[2].Clone(core.SetPeers([]*metapb.Peer{ {Id: 106, StoreId: 1}, {Id: 107, StoreId: 5}, {Id: 108, StoreId: 4}, })) - s.cluster.PutRegion(s.regions[2]) - ops = s.mc.Check(s.regions[2]) - s.checkSteps(c, ops[0], []operator.OpStep{ + suite.cluster.PutRegion(suite.regions[2]) + ops = suite.mc.Check(suite.regions[2]) + testutil.CheckSteps(suite.Require(), ops[0], []operator.OpStep{ operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: false, }, }) - s.checkSteps(c, ops[1], []operator.OpStep{ + testutil.CheckSteps(suite.Require(), ops[1], []operator.OpStep{ operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: true, }, }) // all stores not overlap - s.regions[2] = s.regions[2].Clone(core.SetPeers([]*metapb.Peer{ + suite.regions[2] = suite.regions[2].Clone(core.SetPeers([]*metapb.Peer{ {Id: 109, StoreId: 2}, {Id: 110, StoreId: 3}, {Id: 111, StoreId: 6}, }), core.WithLeader(&metapb.Peer{Id: 109, StoreId: 2})) - s.cluster.PutRegion(s.regions[2]) - ops = s.mc.Check(s.regions[2]) - s.checkSteps(c, ops[0], []operator.OpStep{ + suite.cluster.PutRegion(suite.regions[2]) + ops = suite.mc.Check(suite.regions[2]) + testutil.CheckSteps(suite.Require(), ops[0], []operator.OpStep{ operator.AddLearner{ToStore: 1}, operator.PromoteLearner{ToStore: 1}, operator.RemovePeer{FromStore: 3}, @@ -374,21 +347,21 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { operator.TransferLeader{FromStore: 2, ToStore: 1}, operator.RemovePeer{FromStore: 2}, operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: 
s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: false, }, }) - s.checkSteps(c, ops[1], []operator.OpStep{ + testutil.CheckSteps(suite.Require(), ops[1], []operator.OpStep{ operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: true, }, }) // no overlap with reject leader label - s.regions[1] = s.regions[1].Clone( + suite.regions[1] = suite.regions[1].Clone( core.SetPeers([]*metapb.Peer{ {Id: 112, StoreId: 7}, {Id: 113, StoreId: 8}, @@ -396,9 +369,9 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { }), core.WithLeader(&metapb.Peer{Id: 114, StoreId: 1}), ) - s.cluster.PutRegion(s.regions[1]) - ops = s.mc.Check(s.regions[2]) - s.checkSteps(c, ops[0], []operator.OpStep{ + suite.cluster.PutRegion(suite.regions[1]) + ops = suite.mc.Check(suite.regions[2]) + testutil.CheckSteps(suite.Require(), ops[0], []operator.OpStep{ operator.AddLearner{ToStore: 1}, operator.PromoteLearner{ToStore: 1}, operator.RemovePeer{FromStore: 3}, @@ -413,21 +386,21 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { operator.RemovePeer{FromStore: 2}, operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: false, }, }) - s.checkSteps(c, ops[1], []operator.OpStep{ + testutil.CheckSteps(suite.Require(), ops[1], []operator.OpStep{ operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: true, }, }) // overlap with reject leader label - s.regions[1] = s.regions[1].Clone( + suite.regions[1] = suite.regions[1].Clone( core.SetPeers([]*metapb.Peer{ {Id: 115, StoreId: 7}, {Id: 116, StoreId: 8}, @@ -435,7 +408,7 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { }), core.WithLeader(&metapb.Peer{Id: 117, StoreId: 1}), ) - s.regions[2] = s.regions[2].Clone( + suite.regions[2] = suite.regions[2].Clone( core.SetPeers([]*metapb.Peer{ {Id: 118, StoreId: 7}, {Id: 119, StoreId: 3}, @@ -443,9 +416,9 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { }), core.WithLeader(&metapb.Peer{Id: 120, StoreId: 2}), ) - s.cluster.PutRegion(s.regions[1]) - ops = s.mc.Check(s.regions[2]) - s.checkSteps(c, ops[0], []operator.OpStep{ + suite.cluster.PutRegion(suite.regions[1]) + ops = suite.mc.Check(suite.regions[2]) + testutil.CheckSteps(suite.Require(), ops[0], []operator.OpStep{ operator.AddLearner{ToStore: 1}, operator.PromoteLearner{ToStore: 1}, operator.RemovePeer{FromStore: 3}, @@ -454,23 +427,23 @@ func (s *testMergeCheckerSuite) TestMatchPeers(c *C) { operator.TransferLeader{FromStore: 2, ToStore: 1}, operator.RemovePeer{FromStore: 2}, operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: false, }, }) - s.checkSteps(c, ops[1], []operator.OpStep{ + testutil.CheckSteps(suite.Require(), ops[1], []operator.OpStep{ operator.MergeRegion{ - FromRegion: s.regions[2].GetMeta(), - ToRegion: s.regions[1].GetMeta(), + FromRegion: suite.regions[2].GetMeta(), + ToRegion: suite.regions[1].GetMeta(), IsPassive: true, }, }) } -func (s *testMergeCheckerSuite) TestStoreLimitWithMerge(c *C) { +func (suite *mergeCheckerTestSuite) 
TestStoreLimitWithMerge() { cfg := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, cfg) + tc := mockcluster.NewCluster(suite.ctx, cfg) tc.SetMaxMergeRegionSize(2) tc.SetMaxMergeRegionKeys(2) tc.SetSplitMergeInterval(0) @@ -489,9 +462,9 @@ func (s *testMergeCheckerSuite) TestStoreLimitWithMerge(c *C) { tc.PutRegion(region) } - mc := NewMergeChecker(s.ctx, tc) - stream := hbstream.NewTestHeartbeatStreams(s.ctx, tc.ID, tc, false /* no need to run */) - oc := schedule.NewOperatorController(s.ctx, tc, stream) + mc := NewMergeChecker(suite.ctx, tc) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := schedule.NewOperatorController(suite.ctx, tc, stream) regions[2] = regions[2].Clone( core.SetPeers([]*metapb.Peer{ @@ -509,8 +482,8 @@ func (s *testMergeCheckerSuite) TestStoreLimitWithMerge(c *C) { // The size of Region is less or equal than 1MB. for i := 0; i < 50; i++ { ops := mc.Check(regions[2]) - c.Assert(ops, NotNil) - c.Assert(oc.AddOperator(ops...), IsTrue) + suite.NotNil(ops) + suite.True(oc.AddOperator(ops...)) for _, op := range ops { oc.RemoveOperator(op) } @@ -523,49 +496,49 @@ func (s *testMergeCheckerSuite) TestStoreLimitWithMerge(c *C) { // The size of Region is more than 1MB but no more than 20MB. for i := 0; i < 5; i++ { ops := mc.Check(regions[2]) - c.Assert(ops, NotNil) - c.Assert(oc.AddOperator(ops...), IsTrue) + suite.NotNil(ops) + suite.True(oc.AddOperator(ops...)) for _, op := range ops { oc.RemoveOperator(op) } } { ops := mc.Check(regions[2]) - c.Assert(ops, NotNil) - c.Assert(oc.AddOperator(ops...), IsFalse) + suite.NotNil(ops) + suite.False(oc.AddOperator(ops...)) } } -func (s *testMergeCheckerSuite) TestCache(c *C) { +func (suite *mergeCheckerTestSuite) TestCache() { cfg := config.NewTestOptions() - s.cluster = mockcluster.NewCluster(s.ctx, cfg) - s.cluster.SetMaxMergeRegionSize(2) - s.cluster.SetMaxMergeRegionKeys(2) - s.cluster.SetSplitMergeInterval(time.Hour) - s.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) + suite.cluster = mockcluster.NewCluster(suite.ctx, cfg) + suite.cluster.SetMaxMergeRegionSize(2) + suite.cluster.SetMaxMergeRegionKeys(2) + suite.cluster.SetSplitMergeInterval(time.Hour) + suite.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) stores := map[uint64][]string{ 1: {}, 2: {}, 3: {}, 4: {}, 5: {}, 6: {}, } for storeID, labels := range stores { - s.cluster.PutStoreWithLabels(storeID, labels...) + suite.cluster.PutStoreWithLabels(storeID, labels...) 
} - s.regions = []*core.RegionInfo{ + suite.regions = []*core.RegionInfo{ newRegionInfo(2, "a", "t", 200, 200, []uint64{104, 4}, []uint64{103, 1}, []uint64{104, 4}, []uint64{105, 5}), newRegionInfo(3, "t", "x", 1, 1, []uint64{108, 6}, []uint64{106, 2}, []uint64{107, 5}, []uint64{108, 6}), } - for _, region := range s.regions { - s.cluster.PutRegion(region) + for _, region := range suite.regions { + suite.cluster.PutRegion(region) } - s.mc = NewMergeChecker(s.ctx, s.cluster) + suite.mc = NewMergeChecker(suite.ctx, suite.cluster) - ops := s.mc.Check(s.regions[1]) - c.Assert(ops, IsNil) - s.cluster.SetSplitMergeInterval(0) + ops := suite.mc.Check(suite.regions[1]) + suite.Nil(ops) + suite.cluster.SetSplitMergeInterval(0) time.Sleep(time.Second) - ops = s.mc.Check(s.regions[1]) - c.Assert(ops, NotNil) + ops = suite.mc.Check(suite.regions[1]) + suite.NotNil(ops) } func makeKeyRanges(keys ...string) []interface{} { diff --git a/server/schedule/checker/priority_inspector_test.go b/server/schedule/checker/priority_inspector_test.go index 319c330d359..35662846c4a 100644 --- a/server/schedule/checker/priority_inspector_test.go +++ b/server/schedule/checker/priority_inspector_test.go @@ -16,31 +16,20 @@ package checker import ( "context" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testPriorityInspectorSuite{}) - -type testPriorityInspectorSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testPriorityInspectorSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testPriorityInspectorSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testPriorityInspectorSuite) TestCheckPriorityRegions(c *C) { +func TestCheckPriorityRegions(t *testing.T) { + re := require.New(t) opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc := mockcluster.NewCluster(ctx, opt) tc.AddRegionStore(1, 0) tc.AddRegionStore(2, 0) tc.AddRegionStore(3, 0) @@ -49,42 +38,42 @@ func (s *testPriorityInspectorSuite) TestCheckPriorityRegions(c *C) { tc.AddLeaderRegion(3, 2) pc := NewPriorityInspector(tc) - checkPriorityRegionTest(pc, tc, c) + checkPriorityRegionTest(re, pc, tc) opt.SetPlacementRuleEnabled(true) - c.Assert(opt.IsPlacementRulesEnabled(), IsTrue) - checkPriorityRegionTest(pc, tc, c) + re.True(opt.IsPlacementRulesEnabled()) + checkPriorityRegionTest(re, pc, tc) } -func checkPriorityRegionTest(pc *PriorityInspector, tc *mockcluster.Cluster, c *C) { +func checkPriorityRegionTest(re *require.Assertions, pc *PriorityInspector, tc *mockcluster.Cluster) { // case1: inspect region 1, it doesn't lack replica region := tc.GetRegion(1) opt := tc.GetOpts() pc.Inspect(region) - c.Assert(0, Equals, pc.queue.Len()) + re.Equal(0, pc.queue.Len()) // case2: inspect region 2, it lacks one replica region = tc.GetRegion(2) pc.Inspect(region) - c.Assert(1, Equals, pc.queue.Len()) + re.Equal(1, pc.queue.Len()) // the region will not rerun after it checks - c.Assert(0, Equals, len(pc.GetPriorityRegions())) + re.Equal(0, len(pc.GetPriorityRegions())) // case3: inspect region 3, it will has high priority region = tc.GetRegion(3) pc.Inspect(region) - c.Assert(2, Equals, pc.queue.Len()) + re.Equal(2, pc.queue.Len()) time.Sleep(opt.GetPatrolRegionInterval() * 10) // region 3 has higher priority ids := pc.GetPriorityRegions() - c.Assert(2, 
Equals, len(ids)) - c.Assert(uint64(3), Equals, ids[0]) - c.Assert(uint64(2), Equals, ids[1]) + re.Equal(2, len(ids)) + re.Equal(uint64(3), ids[0]) + re.Equal(uint64(2), ids[1]) // case4: inspect region 2 again after it fixup replicas tc.AddLeaderRegion(2, 2, 3, 1) region = tc.GetRegion(2) pc.Inspect(region) - c.Assert(1, Equals, pc.queue.Len()) + re.Equal(1, pc.queue.Len()) // recover tc.AddLeaderRegion(2, 2, 3) diff --git a/server/schedule/checker/replica_checker_test.go b/server/schedule/checker/replica_checker_test.go index 87c813d0111..8a0327c09c7 100644 --- a/server/schedule/checker/replica_checker_test.go +++ b/server/schedule/checker/replica_checker_test.go @@ -16,11 +16,12 @@ package checker import ( "context" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/cache" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/pkg/testutil" @@ -35,28 +36,24 @@ const ( MB = 1024 * KB ) -var _ = Suite(&testReplicaCheckerSuite{}) - -type testReplicaCheckerSuite struct { +type replicaCheckerTestSuite struct { + suite.Suite cluster *mockcluster.Cluster rc *ReplicaChecker ctx context.Context cancel context.CancelFunc } -func (s *testReplicaCheckerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testReplicaCheckerSuite) TearDownTest(c *C) { - s.cancel() +func TestReplicaCheckerTestSuite(t *testing.T) { + suite.Run(t, new(replicaCheckerTestSuite)) } -func (s *testReplicaCheckerSuite) SetUpTest(c *C) { +func (suite *replicaCheckerTestSuite) SetupTest() { cfg := config.NewTestOptions() - s.cluster = mockcluster.NewCluster(s.ctx, cfg) - s.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) - s.rc = NewReplicaChecker(s.cluster, cache.NewDefaultCache(10)) + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.cluster = mockcluster.NewCluster(suite.ctx, cfg) + suite.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) + suite.rc = NewReplicaChecker(suite.cluster, cache.NewDefaultCache(10)) stats := &pdpb.StoreStats{ Capacity: 100, Available: 100, @@ -88,12 +85,16 @@ func (s *testReplicaCheckerSuite) SetUpTest(c *C) { ), } for _, store := range stores { - s.cluster.PutStore(store) + suite.cluster.PutStore(store) } - s.cluster.AddLabelsStore(2, 1, map[string]string{"noleader": "true"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"noleader": "true"}) +} + +func (suite *replicaCheckerTestSuite) TearDownTest() { + suite.cancel() } -func (s *testReplicaCheckerSuite) TestReplacePendingPeer(c *C) { +func (suite *replicaCheckerTestSuite) TestReplacePendingPeer() { peers := []*metapb.Peer{ { Id: 2, @@ -109,16 +110,16 @@ func (s *testReplicaCheckerSuite) TestReplacePendingPeer(c *C) { }, } r := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers}, peers[1], core.WithPendingPeers(peers[0:1])) - s.cluster.PutRegion(r) - op := s.rc.Check(r) - c.Assert(op, NotNil) - c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(4)) - c.Assert(op.Step(1).(operator.PromoteLearner).ToStore, Equals, uint64(4)) - c.Assert(op.Step(2).(operator.RemovePeer).FromStore, Equals, uint64(1)) + suite.cluster.PutRegion(r) + op := suite.rc.Check(r) + suite.NotNil(op) + suite.Equal(uint64(4), op.Step(0).(operator.AddLearner).ToStore) + suite.Equal(uint64(4), op.Step(1).(operator.PromoteLearner).ToStore) + suite.Equal(uint64(1), 
op.Step(2).(operator.RemovePeer).FromStore) } -func (s *testReplicaCheckerSuite) TestReplaceOfflinePeer(c *C) { - s.cluster.SetLabelPropertyConfig(config.LabelPropertyConfig{ +func (suite *replicaCheckerTestSuite) TestReplaceOfflinePeer() { + suite.cluster.SetLabelPropertyConfig(config.LabelPropertyConfig{ config.RejectLeader: {{Key: "noleader", Value: "true"}}, }) peers := []*metapb.Peer{ @@ -136,17 +137,17 @@ func (s *testReplicaCheckerSuite) TestReplaceOfflinePeer(c *C) { }, } r := core.NewRegionInfo(&metapb.Region{Id: 2, Peers: peers}, peers[0]) - s.cluster.PutRegion(r) - op := s.rc.Check(r) - c.Assert(op, NotNil) - c.Assert(op.Step(0).(operator.TransferLeader).ToStore, Equals, uint64(3)) - c.Assert(op.Step(1).(operator.AddLearner).ToStore, Equals, uint64(4)) - c.Assert(op.Step(2).(operator.PromoteLearner).ToStore, Equals, uint64(4)) - c.Assert(op.Step(3).(operator.RemovePeer).FromStore, Equals, uint64(1)) + suite.cluster.PutRegion(r) + op := suite.rc.Check(r) + suite.NotNil(op) + suite.Equal(uint64(3), op.Step(0).(operator.TransferLeader).ToStore) + suite.Equal(uint64(4), op.Step(1).(operator.AddLearner).ToStore) + suite.Equal(uint64(4), op.Step(2).(operator.PromoteLearner).ToStore) + suite.Equal(uint64(1), op.Step(3).(operator.RemovePeer).FromStore) } -func (s *testReplicaCheckerSuite) TestOfflineWithOneReplica(c *C) { - s.cluster.SetMaxReplicas(1) +func (suite *replicaCheckerTestSuite) TestOfflineWithOneReplica() { + suite.cluster.SetMaxReplicas(1) peers := []*metapb.Peer{ { Id: 4, @@ -154,27 +155,27 @@ func (s *testReplicaCheckerSuite) TestOfflineWithOneReplica(c *C) { }, } r := core.NewRegionInfo(&metapb.Region{Id: 2, Peers: peers}, peers[0]) - s.cluster.PutRegion(r) - op := s.rc.Check(r) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "replace-offline-replica") + suite.cluster.PutRegion(r) + op := suite.rc.Check(r) + suite.NotNil(op) + suite.Equal("replace-offline-replica", op.Desc()) } -func (s *testReplicaCheckerSuite) TestDownPeer(c *C) { +func (suite *replicaCheckerTestSuite) TestDownPeer() { // down a peer, the number of normal peers(except learner) is enough. - op := s.downPeerAndCheck(c, metapb.PeerRole_Voter) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "remove-extra-down-replica") + op := suite.downPeerAndCheck(metapb.PeerRole_Voter) + suite.NotNil(op) + suite.Equal("remove-extra-down-replica", op.Desc()) // down a peer,the number of peers(except learner) is not enough. 
- op = s.downPeerAndCheck(c, metapb.PeerRole_Learner) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "replace-down-replica") + op = suite.downPeerAndCheck(metapb.PeerRole_Learner) + suite.NotNil(op) + suite.Equal("replace-down-replica", op.Desc()) } -func (s *testReplicaCheckerSuite) downPeerAndCheck(c *C, aliveRole metapb.PeerRole) *operator.Operator { - s.cluster.SetMaxReplicas(2) - s.cluster.SetStoreUp(1) +func (suite *replicaCheckerTestSuite) downPeerAndCheck(aliveRole metapb.PeerRole) *operator.Operator { + suite.cluster.SetMaxReplicas(2) + suite.cluster.SetStoreUp(1) downStoreID := uint64(3) peers := []*metapb.Peer{ { @@ -192,8 +193,8 @@ func (s *testReplicaCheckerSuite) downPeerAndCheck(c *C, aliveRole metapb.PeerRo }, } r := core.NewRegionInfo(&metapb.Region{Id: 2, Peers: peers}, peers[0]) - s.cluster.PutRegion(r) - s.cluster.SetStoreDown(downStoreID) + suite.cluster.PutRegion(r) + suite.cluster.SetStoreDown(downStoreID) downPeer := &pdpb.PeerStats{ Peer: &metapb.Peer{ Id: 14, @@ -202,13 +203,13 @@ func (s *testReplicaCheckerSuite) downPeerAndCheck(c *C, aliveRole metapb.PeerRo DownSeconds: 24 * 60 * 60, } r = r.Clone(core.WithDownPeers(append(r.GetDownPeers(), downPeer))) - c.Assert(r.GetDownPeers(), HasLen, 1) - return s.rc.Check(r) + suite.Len(r.GetDownPeers(), 1) + return suite.rc.Check(r) } -func (s *testReplicaCheckerSuite) TestBasic(c *C) { +func (suite *replicaCheckerTestSuite) TestBasic() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetMaxSnapshotCount(2) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) rc := NewReplicaChecker(tc, cache.NewDefaultCache(10)) @@ -223,41 +224,41 @@ func (s *testReplicaCheckerSuite) TestBasic(c *C) { // Region has 2 peers, we need to add a new peer. region := tc.GetRegion(1) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 4) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 4) // Disable make up replica feature. tc.SetEnableMakeUpReplica(false) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) tc.SetEnableMakeUpReplica(true) // Test healthFilter. // If store 4 is down, we add to store 3. tc.SetStoreDown(4) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 3) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3) tc.SetStoreUp(4) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 4) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 4) // Test snapshotCountFilter. // If snapshotCount > MaxSnapshotCount, we add to store 3. tc.UpdateSnapshotCount(4, 3) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 3) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3) // If snapshotCount < MaxSnapshotCount, we can add peer again. tc.UpdateSnapshotCount(4, 1) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 4) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 4) // Add peer in store 4, and we have enough replicas. peer4, _ := tc.AllocPeer(4) region = region.Clone(core.WithAddPeer(peer4)) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) // Add peer in store 3, and we have redundant replicas. 
peer3, _ := tc.AllocPeer(3) region = region.Clone(core.WithAddPeer(peer3)) - testutil.CheckRemovePeer(c, rc.Check(region), 1) + testutil.CheckRemovePeerWithTestify(suite.Require(), rc.Check(region), 1) // Disable remove extra replica feature. tc.SetEnableRemoveExtraReplica(false) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) tc.SetEnableRemoveExtraReplica(true) region = region.Clone(core.WithRemoveStorePeer(1), core.WithLeader(region.GetStorePeer(3))) @@ -270,18 +271,18 @@ func (s *testReplicaCheckerSuite) TestBasic(c *C) { } region = region.Clone(core.WithDownPeers(append(region.GetDownPeers(), downPeer))) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 2, 1) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 2, 1) region = region.Clone(core.WithDownPeers(nil)) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) // Peer in store 3 is offline, transfer peer to store 1. tc.SetStoreOffline(3) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 3, 1) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3, 1) } -func (s *testReplicaCheckerSuite) TestLostStore(c *C) { +func (suite *replicaCheckerTestSuite) TestLostStore() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) tc.AddRegionStore(1, 1) @@ -295,12 +296,12 @@ func (s *testReplicaCheckerSuite) TestLostStore(c *C) { tc.AddLeaderRegion(1, 1, 2, 3) region := tc.GetRegion(1) op := rc.Check(region) - c.Assert(op, IsNil) + suite.Nil(op) } -func (s *testReplicaCheckerSuite) TestOffline(c *C) { +func (suite *replicaCheckerTestSuite) TestOffline() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) tc.SetMaxReplicas(3) tc.SetLocationLabels([]string{"zone", "rack", "host"}) @@ -316,43 +317,43 @@ func (s *testReplicaCheckerSuite) TestOffline(c *C) { region := tc.GetRegion(1) // Store 2 has different zone and smallest region score. - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 2) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 2) peer2, _ := tc.AllocPeer(2) region = region.Clone(core.WithAddPeer(peer2)) // Store 3 has different zone and smallest region score. - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 3) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3) peer3, _ := tc.AllocPeer(3) region = region.Clone(core.WithAddPeer(peer3)) // Store 4 has the same zone with store 3 and larger region score. peer4, _ := tc.AllocPeer(4) region = region.Clone(core.WithAddPeer(peer4)) - testutil.CheckRemovePeer(c, rc.Check(region), 4) + testutil.CheckRemovePeerWithTestify(suite.Require(), rc.Check(region), 4) // Test offline // the number of region peers more than the maxReplicas // remove the peer tc.SetStoreOffline(3) - testutil.CheckRemovePeer(c, rc.Check(region), 3) + testutil.CheckRemovePeerWithTestify(suite.Require(), rc.Check(region), 3) region = region.Clone(core.WithRemoveStorePeer(4)) // the number of region peers equals the maxReplicas // Transfer peer to store 4. 
- testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 3, 4) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3, 4) // Store 5 has a same label score with store 4, but the region score smaller than store 4, we will choose store 5. tc.AddLabelsStore(5, 3, map[string]string{"zone": "z4", "rack": "r1", "host": "h1"}) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 3, 5) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3, 5) // Store 5 has too many snapshots, choose store 4 tc.UpdateSnapshotCount(5, 100) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 3, 4) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3, 4) tc.UpdatePendingPeerCount(4, 100) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) } -func (s *testReplicaCheckerSuite) TestDistinctScore(c *C) { +func (suite *replicaCheckerTestSuite) TestDistinctScore() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) tc.SetMaxReplicas(3) tc.SetLocationLabels([]string{"zone", "rack", "host"}) @@ -365,73 +366,73 @@ func (s *testReplicaCheckerSuite) TestDistinctScore(c *C) { // We need 3 replicas. tc.AddLeaderRegion(1, 1) region := tc.GetRegion(1) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 2) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 2) peer2, _ := tc.AllocPeer(2) region = region.Clone(core.WithAddPeer(peer2)) // Store 1,2,3 have the same zone, rack, and host. tc.AddLabelsStore(3, 5, map[string]string{"zone": "z1", "rack": "r1", "host": "h1"}) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 3) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 3) // Store 4 has smaller region score. tc.AddLabelsStore(4, 4, map[string]string{"zone": "z1", "rack": "r1", "host": "h1"}) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 4) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 4) // Store 5 has a different host. tc.AddLabelsStore(5, 5, map[string]string{"zone": "z1", "rack": "r1", "host": "h2"}) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 5) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 5) // Store 6 has a different rack. tc.AddLabelsStore(6, 6, map[string]string{"zone": "z1", "rack": "r2", "host": "h1"}) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 6) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 6) // Store 7 has a different zone. tc.AddLabelsStore(7, 7, map[string]string{"zone": "z2", "rack": "r1", "host": "h1"}) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 7) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 7) // Test stateFilter. tc.SetStoreOffline(7) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 6) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 6) tc.SetStoreUp(7) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 7) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 7) // Add peer to store 7. 
peer7, _ := tc.AllocPeer(7) region = region.Clone(core.WithAddPeer(peer7)) // Replace peer in store 1 with store 6 because it has a different rack. - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 1, 6) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 1, 6) // Disable locationReplacement feature. tc.SetEnableLocationReplacement(false) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) tc.SetEnableLocationReplacement(true) peer6, _ := tc.AllocPeer(6) region = region.Clone(core.WithAddPeer(peer6)) - testutil.CheckRemovePeer(c, rc.Check(region), 1) + testutil.CheckRemovePeerWithTestify(suite.Require(), rc.Check(region), 1) region = region.Clone(core.WithRemoveStorePeer(1), core.WithLeader(region.GetStorePeer(2))) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) // Store 8 has the same zone and different rack with store 7. // Store 1 has the same zone and different rack with store 6. // So store 8 and store 1 are equivalent. tc.AddLabelsStore(8, 1, map[string]string{"zone": "z2", "rack": "r2", "host": "h1"}) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) // Store 10 has a different zone. // Store 2 and 6 have the same distinct score, but store 2 has larger region score. // So replace peer in store 2 with store 10. tc.AddLabelsStore(10, 1, map[string]string{"zone": "z3", "rack": "r1", "host": "h1"}) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 2, 10) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 2, 10) peer10, _ := tc.AllocPeer(10) region = region.Clone(core.WithAddPeer(peer10)) - testutil.CheckRemovePeer(c, rc.Check(region), 2) + testutil.CheckRemovePeerWithTestify(suite.Require(), rc.Check(region), 2) region = region.Clone(core.WithRemoveStorePeer(2)) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) } -func (s *testReplicaCheckerSuite) TestDistinctScore2(c *C) { +func (suite *replicaCheckerTestSuite) TestDistinctScore2() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) tc.SetMaxReplicas(5) tc.SetLocationLabels([]string{"zone", "host"}) @@ -448,20 +449,20 @@ func (s *testReplicaCheckerSuite) TestDistinctScore2(c *C) { tc.AddLeaderRegion(1, 1, 2, 4) region := tc.GetRegion(1) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 6) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 6) peer6, _ := tc.AllocPeer(6) region = region.Clone(core.WithAddPeer(peer6)) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 5) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 5) peer5, _ := tc.AllocPeer(5) region = region.Clone(core.WithAddPeer(peer5)) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) } -func (s *testReplicaCheckerSuite) TestStorageThreshold(c *C) { +func (suite *replicaCheckerTestSuite) TestStorageThreshold() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetLocationLabels([]string{"zone"}) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) rc := NewReplicaChecker(tc, cache.NewDefaultCache(10)) @@ -480,24 +481,24 @@ func (s *testReplicaCheckerSuite) TestStorageThreshold(c *C) { // Move peer to better location. 
tc.UpdateStorageRatio(4, 0, 1) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 1, 4) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 1, 4) // If store4 is almost full, do not add peer on it. tc.UpdateStorageRatio(4, 0.9, 0.1) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) tc.AddLeaderRegion(2, 1, 3) region = tc.GetRegion(2) // Add peer on store4. tc.UpdateStorageRatio(4, 0, 1) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 4) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 4) // If store4 is almost full, do not add peer on it. tc.UpdateStorageRatio(4, 0.8, 0) - testutil.CheckAddPeer(c, rc.Check(region), operator.OpReplica, 2) + testutil.CheckAddPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 2) } -func (s *testReplicaCheckerSuite) TestOpts(c *C) { +func (suite *replicaCheckerTestSuite) TestOpts() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) rc := NewReplicaChecker(tc, cache.NewDefaultCache(10)) @@ -518,17 +519,17 @@ func (s *testReplicaCheckerSuite) TestOpts(c *C) { })) tc.SetStoreOffline(2) // RemoveDownReplica has higher priority than replaceOfflineReplica. - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 1, 4) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 1, 4) tc.SetEnableRemoveDownReplica(false) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpReplica, 2, 4) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpReplica, 2, 4) tc.SetEnableReplaceOfflineReplica(false) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) } // See issue: https://github.com/tikv/pd/issues/3705 -func (s *testReplicaCheckerSuite) TestFixDownPeer(c *C) { +func (suite *replicaCheckerTestSuite) TestFixDownPeer() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) tc.SetLocationLabels([]string{"zone"}) rc := NewReplicaChecker(tc, cache.NewDefaultCache(10)) @@ -541,25 +542,25 @@ func (s *testReplicaCheckerSuite) TestFixDownPeer(c *C) { tc.AddLeaderRegion(1, 1, 3, 4) region := tc.GetRegion(1) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) tc.SetStoreDown(4) region = region.Clone(core.WithDownPeers([]*pdpb.PeerStats{ {Peer: region.GetStorePeer(4), DownSeconds: 6000}, })) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpRegion, 4, 5) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpRegion, 4, 5) tc.SetStoreDown(5) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpRegion, 4, 2) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpRegion, 4, 2) tc.SetIsolationLevel("zone") - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) } // See issue: https://github.com/tikv/pd/issues/3705 -func (s *testReplicaCheckerSuite) TestFixOfflinePeer(c *C) { +func (suite *replicaCheckerTestSuite) TestFixOfflinePeer() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) tc.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) 
tc.SetLocationLabels([]string{"zone"}) rc := NewReplicaChecker(tc, cache.NewDefaultCache(10)) @@ -572,14 +573,14 @@ func (s *testReplicaCheckerSuite) TestFixOfflinePeer(c *C) { tc.AddLeaderRegion(1, 1, 3, 4) region := tc.GetRegion(1) - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) tc.SetStoreOffline(4) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpRegion, 4, 5) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpRegion, 4, 5) tc.SetStoreOffline(5) - testutil.CheckTransferPeer(c, rc.Check(region), operator.OpRegion, 4, 2) + testutil.CheckTransferPeerWithTestify(suite.Require(), rc.Check(region), operator.OpRegion, 4, 2) tc.SetIsolationLevel("zone") - c.Assert(rc.Check(region), IsNil) + suite.Nil(rc.Check(region)) } diff --git a/server/schedule/checker/rule_checker_test.go b/server/schedule/checker/rule_checker_test.go index f3a908939bf..ea9a369348e 100644 --- a/server/schedule/checker/rule_checker_test.go +++ b/server/schedule/checker/rule_checker_test.go @@ -16,11 +16,12 @@ package checker import ( "context" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/cache" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/pkg/testutil" @@ -31,36 +32,12 @@ import ( "github.com/tikv/pd/server/versioninfo" ) -var _ = Suite(&testRuleCheckerSuite{}) -var _ = SerialSuites(&testRuleCheckerSerialSuite{}) - -type testRuleCheckerSerialSuite struct { - cluster *mockcluster.Cluster - ruleManager *placement.RuleManager - rc *RuleChecker - ctx context.Context - cancel context.CancelFunc -} - -func (s *testRuleCheckerSerialSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) +func TestRuleCheckerTestSuite(t *testing.T) { + suite.Run(t, new(ruleCheckerTestSuite)) } -func (s *testRuleCheckerSerialSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testRuleCheckerSerialSuite) SetUpTest(c *C) { - cfg := config.NewTestOptions() - cfg.SetPlacementRulesCacheEnabled(true) - s.cluster = mockcluster.NewCluster(s.ctx, cfg) - s.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) - s.cluster.SetEnablePlacementRules(true) - s.ruleManager = s.cluster.RuleManager - s.rc = NewRuleChecker(s.cluster, s.ruleManager, cache.NewDefaultCache(10)) -} - -type testRuleCheckerSuite struct { +type ruleCheckerTestSuite struct { + suite.Suite cluster *mockcluster.Cluster ruleManager *placement.RuleManager rc *RuleChecker @@ -68,42 +45,39 @@ type testRuleCheckerSuite struct { cancel context.CancelFunc } -func (s *testRuleCheckerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testRuleCheckerSuite) TearDownTest(c *C) { - s.cancel() +func (suite *ruleCheckerTestSuite) SetupTest() { + cfg := config.NewTestOptions() + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.cluster = mockcluster.NewCluster(suite.ctx, cfg) + suite.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) + suite.cluster.SetEnablePlacementRules(true) + suite.ruleManager = suite.cluster.RuleManager + suite.rc = NewRuleChecker(suite.cluster, suite.ruleManager, cache.NewDefaultCache(10)) } -func (s *testRuleCheckerSuite) SetUpTest(c *C) { - cfg := config.NewTestOptions() - s.cluster = mockcluster.NewCluster(s.ctx, cfg) - 
s.cluster.SetClusterVersion(versioninfo.MinSupportedVersion(versioninfo.Version4_0)) - s.cluster.SetEnablePlacementRules(true) - s.ruleManager = s.cluster.RuleManager - s.rc = NewRuleChecker(s.cluster, s.ruleManager, cache.NewDefaultCache(10)) +func (suite *ruleCheckerTestSuite) TearDownTest() { + suite.cancel() } -func (s *testRuleCheckerSuite) TestAddRulePeer(c *C) { - s.cluster.AddLeaderStore(1, 1) - s.cluster.AddLeaderStore(2, 1) - s.cluster.AddLeaderStore(3, 1) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "add-rule-peer") - c.Assert(op.GetPriorityLevel(), Equals, core.HighPriority) - c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(3)) +func (suite *ruleCheckerTestSuite) TestAddRulePeer() { + suite.cluster.AddLeaderStore(1, 1) + suite.cluster.AddLeaderStore(2, 1) + suite.cluster.AddLeaderStore(3, 1) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("add-rule-peer", op.Desc()) + suite.Equal(core.HighPriority, op.GetPriorityLevel()) + suite.Equal(uint64(3), op.Step(0).(operator.AddLearner).ToStore) } -func (s *testRuleCheckerSuite) TestAddRulePeerWithIsolationLevel(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1", "rack": "r1", "host": "h1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1", "rack": "r1", "host": "h2"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1", "rack": "r2", "host": "h1"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z1", "rack": "r3", "host": "h1"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) - s.ruleManager.SetRule(&placement.Rule{ +func (suite *ruleCheckerTestSuite) TestAddRulePeerWithIsolationLevel() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1", "rack": "r1", "host": "h1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1", "rack": "r1", "host": "h2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1", "rack": "r2", "host": "h1"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z1", "rack": "r3", "host": "h1"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "test", Index: 100, @@ -113,10 +87,10 @@ func (s *testRuleCheckerSuite) TestAddRulePeerWithIsolationLevel(c *C) { LocationLabels: []string{"zone", "rack", "host"}, IsolationLevel: "zone", }) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3) - s.ruleManager.SetRule(&placement.Rule{ + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "test", Index: 100, @@ -126,75 +100,75 @@ func (s *testRuleCheckerSuite) TestAddRulePeerWithIsolationLevel(c *C) { LocationLabels: []string{"zone", "rack", "host"}, IsolationLevel: "rack", }) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "add-rule-peer") - c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(4)) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("add-rule-peer", op.Desc()) + suite.Equal(uint64(4), op.Step(0).(operator.AddLearner).ToStore) } -func (s *testRuleCheckerSuite) TestFixPeer(c *C) { - s.cluster.AddLeaderStore(1, 1) - 
s.cluster.AddLeaderStore(2, 1) - s.cluster.AddLeaderStore(3, 1) - s.cluster.AddLeaderStore(4, 1) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) - s.cluster.SetStoreDown(2) - r := s.cluster.GetRegion(1) +func (suite *ruleCheckerTestSuite) TestFixPeer() { + suite.cluster.AddLeaderStore(1, 1) + suite.cluster.AddLeaderStore(2, 1) + suite.cluster.AddLeaderStore(3, 1) + suite.cluster.AddLeaderStore(4, 1) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + suite.cluster.SetStoreDown(2) + r := suite.cluster.GetRegion(1) r = r.Clone(core.WithDownPeers([]*pdpb.PeerStats{{Peer: r.GetStorePeer(2), DownSeconds: 60000}})) - op = s.rc.Check(r) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "replace-rule-down-peer") - c.Assert(op.GetPriorityLevel(), Equals, core.HighPriority) + op = suite.rc.Check(r) + suite.NotNil(op) + suite.Equal("replace-rule-down-peer", op.Desc()) + suite.Equal(core.HighPriority, op.GetPriorityLevel()) var add operator.AddLearner - c.Assert(op.Step(0), FitsTypeOf, add) - s.cluster.SetStoreUp(2) - s.cluster.SetStoreOffline(2) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "replace-rule-offline-peer") - c.Assert(op.GetPriorityLevel(), Equals, core.HighPriority) - c.Assert(op.Step(0), FitsTypeOf, add) - - s.cluster.SetStoreUp(2) + suite.IsType(add, op.Step(0)) + suite.cluster.SetStoreUp(2) + suite.cluster.SetStoreOffline(2) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("replace-rule-offline-peer", op.Desc()) + suite.Equal(core.HighPriority, op.GetPriorityLevel()) + suite.IsType(add, op.Step(0)) + + suite.cluster.SetStoreUp(2) // leader store offline - s.cluster.SetStoreOffline(1) - r1 := s.cluster.GetRegion(1) + suite.cluster.SetStoreOffline(1) + r1 := suite.cluster.GetRegion(1) nr1 := r1.Clone(core.WithPendingPeers([]*metapb.Peer{r1.GetStorePeer(3)})) - s.cluster.PutRegion(nr1) + suite.cluster.PutRegion(nr1) hasTransferLeader := false for i := 0; i < 100; i++ { - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) if step, ok := op.Step(0).(operator.TransferLeader); ok { - c.Assert(step.FromStore, Equals, uint64(1)) - c.Assert(step.ToStore, Not(Equals), uint64(3)) + suite.Equal(uint64(1), step.FromStore) + suite.NotEqual(uint64(3), step.ToStore) hasTransferLeader = true } } - c.Assert(hasTransferLeader, IsTrue) + suite.True(hasTransferLeader) } -func (s *testRuleCheckerSuite) TestFixOrphanPeers(c *C) { - s.cluster.AddLeaderStore(1, 1) - s.cluster.AddLeaderStore(2, 1) - s.cluster.AddLeaderStore(3, 1) - s.cluster.AddLeaderStore(4, 1) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3, 4) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "remove-orphan-peer") - c.Assert(op.Step(0).(operator.RemovePeer).FromStore, Equals, uint64(4)) +func (suite *ruleCheckerTestSuite) TestFixOrphanPeers() { + suite.cluster.AddLeaderStore(1, 1) + suite.cluster.AddLeaderStore(2, 1) + suite.cluster.AddLeaderStore(3, 1) + suite.cluster.AddLeaderStore(4, 1) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3, 4) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal(uint64(4), op.Step(0).(operator.RemovePeer).FromStore) } -func (s *testRuleCheckerSuite) 
TestFixOrphanPeers2(c *C) { +func (suite *ruleCheckerTestSuite) TestFixOrphanPeers2() { // check orphan peers can only be handled when all rules are satisfied. - s.cluster.AddLabelsStore(1, 1, map[string]string{"foo": "bar"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"foo": "bar"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"foo": "baz"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3) - s.ruleManager.SetRule(&placement.Rule{ + suite.cluster.AddLabelsStore(1, 1, map[string]string{"foo": "bar"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"foo": "bar"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"foo": "baz"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "r1", Index: 100, @@ -205,32 +179,32 @@ func (s *testRuleCheckerSuite) TestFixOrphanPeers2(c *C) { {Key: "foo", Op: "in", Values: []string{"baz"}}, }, }) - s.cluster.SetStoreDown(2) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) + suite.cluster.SetStoreDown(2) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) } -func (s *testRuleCheckerSuite) TestFixRole(c *C) { - s.cluster.AddLeaderStore(1, 1) - s.cluster.AddLeaderStore(2, 1) - s.cluster.AddLeaderStore(3, 1) - s.cluster.AddLeaderRegionWithRange(1, "", "", 2, 1, 3) - r := s.cluster.GetRegion(1) +func (suite *ruleCheckerTestSuite) TestFixRole() { + suite.cluster.AddLeaderStore(1, 1) + suite.cluster.AddLeaderStore(2, 1) + suite.cluster.AddLeaderStore(3, 1) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 2, 1, 3) + r := suite.cluster.GetRegion(1) p := r.GetStorePeer(1) p.Role = metapb.PeerRole_Learner r = r.Clone(core.WithLearners([]*metapb.Peer{p})) - op := s.rc.Check(r) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "fix-peer-role") - c.Assert(op.Step(0).(operator.PromoteLearner).ToStore, Equals, uint64(1)) + op := suite.rc.Check(r) + suite.NotNil(op) + suite.Equal("fix-peer-role", op.Desc()) + suite.Equal(uint64(1), op.Step(0).(operator.PromoteLearner).ToStore) } -func (s *testRuleCheckerSuite) TestFixRoleLeader(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"role": "follower"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"role": "follower"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"role": "voter"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) - s.ruleManager.SetRule(&placement.Rule{ +func (suite *ruleCheckerTestSuite) TestFixRoleLeader() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"role": "follower"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"role": "follower"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"role": "voter"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "r1", Index: 100, @@ -241,7 +215,7 @@ func (s *testRuleCheckerSuite) TestFixRoleLeader(c *C) { {Key: "role", Op: "in", Values: []string{"voter"}}, }, }) - s.ruleManager.SetRule(&placement.Rule{ + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "r2", Index: 101, @@ -251,17 +225,17 @@ func (s *testRuleCheckerSuite) TestFixRoleLeader(c *C) { {Key: "role", Op: "in", Values: []string{"follower"}}, }, }) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "fix-follower-role") - c.Assert(op.Step(0).(operator.TransferLeader).ToStore, Equals, uint64(3)) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("fix-follower-role", 
op.Desc()) + suite.Equal(uint64(3), op.Step(0).(operator.TransferLeader).ToStore) } -func (s *testRuleCheckerSuite) TestFixRoleLeaderIssue3130(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"role": "follower"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"role": "leader"}) - s.cluster.AddLeaderRegion(1, 1, 2) - s.ruleManager.SetRule(&placement.Rule{ +func (suite *ruleCheckerTestSuite) TestFixRoleLeaderIssue3130() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"role": "follower"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"role": "leader"}) + suite.cluster.AddLeaderRegion(1, 1, 2) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "r1", Index: 100, @@ -272,30 +246,30 @@ func (s *testRuleCheckerSuite) TestFixRoleLeaderIssue3130(c *C) { {Key: "role", Op: "in", Values: []string{"leader"}}, }, }) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "fix-leader-role") - c.Assert(op.Step(0).(operator.TransferLeader).ToStore, Equals, uint64(2)) - - s.cluster.SetStoreBusy(2, true) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) - s.cluster.SetStoreBusy(2, false) - - s.cluster.AddLeaderRegion(1, 2, 1) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "remove-orphan-peer") - c.Assert(op.Step(0).(operator.RemovePeer).FromStore, Equals, uint64(1)) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("fix-leader-role", op.Desc()) + suite.Equal(uint64(2), op.Step(0).(operator.TransferLeader).ToStore) + + suite.cluster.SetStoreBusy(2, true) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + suite.cluster.SetStoreBusy(2, false) + + suite.cluster.AddLeaderRegion(1, 2, 1) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal(uint64(1), op.Step(0).(operator.RemovePeer).FromStore) } -func (s *testRuleCheckerSuite) TestBetterReplacement(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host3"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) - s.ruleManager.SetRule(&placement.Rule{ +func (suite *ruleCheckerTestSuite) TestBetterReplacement() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host3"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "test", Index: 100, @@ -304,22 +278,22 @@ func (s *testRuleCheckerSuite) TestBetterReplacement(c *C) { Count: 3, LocationLabels: []string{"host"}, }) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "move-to-better-location") - c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(4)) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3, 4) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("move-to-better-location", op.Desc()) + suite.Equal(uint64(4), op.Step(0).(operator.AddLearner).ToStore) + 
suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3, 4) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) } -func (s *testRuleCheckerSuite) TestBetterReplacement2(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1", "host": "host1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1", "host": "host2"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1", "host": "host3"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z2", "host": "host1"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) - s.ruleManager.SetRule(&placement.Rule{ +func (suite *ruleCheckerTestSuite) TestBetterReplacement2() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1", "host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1", "host": "host2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1", "host": "host3"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z2", "host": "host1"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "test", Index: 100, @@ -328,21 +302,21 @@ func (s *testRuleCheckerSuite) TestBetterReplacement2(c *C) { Count: 3, LocationLabels: []string{"zone", "host"}, }) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "move-to-better-location") - c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(4)) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3, 4) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("move-to-better-location", op.Desc()) + suite.Equal(uint64(4), op.Step(0).(operator.AddLearner).ToStore) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3, 4) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) } -func (s *testRuleCheckerSuite) TestNoBetterReplacement(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) - s.ruleManager.SetRule(&placement.Rule{ +func (suite *ruleCheckerTestSuite) TestNoBetterReplacement() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) + suite.ruleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "test", Index: 100, @@ -351,72 +325,72 @@ func (s *testRuleCheckerSuite) TestNoBetterReplacement(c *C) { Count: 3, LocationLabels: []string{"host"}, }) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) } -func (s *testRuleCheckerSuite) TestIssue2419(c *C) { - s.cluster.AddLeaderStore(1, 1) - s.cluster.AddLeaderStore(2, 1) - s.cluster.AddLeaderStore(3, 1) - s.cluster.AddLeaderStore(4, 1) - s.cluster.SetStoreOffline(3) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) - r := s.cluster.GetRegion(1) +func (suite *ruleCheckerTestSuite) TestIssue2419() { + suite.cluster.AddLeaderStore(1, 1) + suite.cluster.AddLeaderStore(2, 1) + suite.cluster.AddLeaderStore(3, 1) + suite.cluster.AddLeaderStore(4, 1) + suite.cluster.SetStoreOffline(3) + 
suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) + r := suite.cluster.GetRegion(1) r = r.Clone(core.WithAddPeer(&metapb.Peer{Id: 5, StoreId: 4, Role: metapb.PeerRole_Learner})) - op := s.rc.Check(r) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "remove-orphan-peer") - c.Assert(op.Step(0).(operator.RemovePeer).FromStore, Equals, uint64(4)) + op := suite.rc.Check(r) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal(uint64(4), op.Step(0).(operator.RemovePeer).FromStore) r = r.Clone(core.WithRemoveStorePeer(4)) - op = s.rc.Check(r) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "replace-rule-offline-peer") - c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(4)) - c.Assert(op.Step(1).(operator.PromoteLearner).ToStore, Equals, uint64(4)) - c.Assert(op.Step(2).(operator.RemovePeer).FromStore, Equals, uint64(3)) + op = suite.rc.Check(r) + suite.NotNil(op) + suite.Equal("replace-rule-offline-peer", op.Desc()) + suite.Equal(uint64(4), op.Step(0).(operator.AddLearner).ToStore) + suite.Equal(uint64(4), op.Step(1).(operator.PromoteLearner).ToStore) + suite.Equal(uint64(3), op.Step(2).(operator.RemovePeer).FromStore) } // Ref https://github.com/tikv/pd/issues/3521 // The problem is when offline a store, we may add learner multiple times if // the operator is timeout. -func (s *testRuleCheckerSuite) TestPriorityFixOrphanPeer(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) - s.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) +func (suite *ruleCheckerTestSuite) TestPriorityFixOrphanPeer() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) var add operator.AddLearner var remove operator.RemovePeer - s.cluster.SetStoreOffline(2) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Step(0), FitsTypeOf, add) - c.Assert(op.Desc(), Equals, "replace-rule-offline-peer") - r := s.cluster.GetRegion(1).Clone(core.WithAddPeer( + suite.cluster.SetStoreOffline(2) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.IsType(add, op.Step(0)) + suite.Equal("replace-rule-offline-peer", op.Desc()) + r := suite.cluster.GetRegion(1).Clone(core.WithAddPeer( &metapb.Peer{ Id: 5, StoreId: 4, Role: metapb.PeerRole_Learner, })) - s.cluster.PutRegion(r) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op.Step(0), FitsTypeOf, remove) - c.Assert(op.Desc(), Equals, "remove-orphan-peer") + suite.cluster.PutRegion(r) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.IsType(remove, op.Step(0)) + suite.Equal("remove-orphan-peer", op.Desc()) } -func (s *testRuleCheckerSuite) TestIssue3293(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(2, 1, 
map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) - s.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) - err := s.ruleManager.SetRule(&placement.Rule{ +func (suite *ruleCheckerTestSuite) TestIssue3293() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) + err := suite.ruleManager.SetRule(&placement.Rule{ GroupID: "TiDB_DDL_51", ID: "0", Role: placement.Follower, @@ -431,26 +405,26 @@ func (s *testRuleCheckerSuite) TestIssue3293(c *C) { }, }, }) - c.Assert(err, IsNil) - s.cluster.DeleteStore(s.cluster.GetStore(5)) - err = s.ruleManager.SetRule(&placement.Rule{ + suite.NoError(err) + suite.cluster.DeleteStore(suite.cluster.GetStore(5)) + err = suite.ruleManager.SetRule(&placement.Rule{ GroupID: "TiDB_DDL_51", ID: "default", Role: placement.Voter, Count: 3, }) - c.Assert(err, IsNil) - err = s.ruleManager.DeleteRule("pd", "default") - c.Assert(err, IsNil) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "add-rule-peer") + suite.NoError(err) + err = suite.ruleManager.DeleteRule("pd", "default") + suite.NoError(err) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("add-rule-peer", op.Desc()) } -func (s *testRuleCheckerSuite) TestIssue3299(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"dc": "sh"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) +func (suite *ruleCheckerTestSuite) TestIssue3299() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"dc": "sh"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) testCases := []struct { constraints []placement.LabelConstraint @@ -524,7 +498,7 @@ func (s *testRuleCheckerSuite) TestIssue3299(c *C) { } for _, t := range testCases { - err := s.ruleManager.SetRule(&placement.Rule{ + err := suite.ruleManager.SetRule(&placement.Rule{ GroupID: "p", ID: "0", Role: placement.Follower, @@ -532,21 +506,21 @@ func (s *testRuleCheckerSuite) TestIssue3299(c *C) { LabelConstraints: t.constraints, }) if t.err != "" { - c.Assert(err, ErrorMatches, t.err) + suite.Regexp(t.err, err.Error()) } else { - c.Assert(err, IsNil) + suite.NoError(err) } } } // See issue: https://github.com/tikv/pd/issues/3705 -func (s *testRuleCheckerSuite) TestFixDownPeer(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z2"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z3"}) - s.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z3"}) - s.cluster.AddLeaderRegion(1, 1, 3, 4) +func (suite *ruleCheckerTestSuite) TestFixDownPeer() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(3, 1, 
map[string]string{"zone": "z2"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z3"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z3"}) + suite.cluster.AddLeaderRegion(1, 1, 3, 4) rule := &placement.Rule{ GroupID: "pd", ID: "test", @@ -556,33 +530,33 @@ func (s *testRuleCheckerSuite) TestFixDownPeer(c *C) { Count: 3, LocationLabels: []string{"zone"}, } - s.ruleManager.SetRule(rule) + suite.ruleManager.SetRule(rule) - region := s.cluster.GetRegion(1) - c.Assert(s.rc.Check(region), IsNil) + region := suite.cluster.GetRegion(1) + suite.Nil(suite.rc.Check(region)) - s.cluster.SetStoreDown(4) + suite.cluster.SetStoreDown(4) region = region.Clone(core.WithDownPeers([]*pdpb.PeerStats{ {Peer: region.GetStorePeer(4), DownSeconds: 6000}, })) - testutil.CheckTransferPeer(c, s.rc.Check(region), operator.OpRegion, 4, 5) + testutil.CheckTransferPeerWithTestify(suite.Require(), suite.rc.Check(region), operator.OpRegion, 4, 5) - s.cluster.SetStoreDown(5) - testutil.CheckTransferPeer(c, s.rc.Check(region), operator.OpRegion, 4, 2) + suite.cluster.SetStoreDown(5) + testutil.CheckTransferPeerWithTestify(suite.Require(), suite.rc.Check(region), operator.OpRegion, 4, 2) rule.IsolationLevel = "zone" - s.ruleManager.SetRule(rule) - c.Assert(s.rc.Check(region), IsNil) + suite.ruleManager.SetRule(rule) + suite.Nil(suite.rc.Check(region)) } // See issue: https://github.com/tikv/pd/issues/3705 -func (s *testRuleCheckerSuite) TestFixOfflinePeer(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z2"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z3"}) - s.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z3"}) - s.cluster.AddLeaderRegion(1, 1, 3, 4) +func (suite *ruleCheckerTestSuite) TestFixOfflinePeer() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z2"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z3"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z3"}) + suite.cluster.AddLeaderRegion(1, 1, 3, 4) rule := &placement.Rule{ GroupID: "pd", ID: "test", @@ -592,30 +566,31 @@ func (s *testRuleCheckerSuite) TestFixOfflinePeer(c *C) { Count: 3, LocationLabels: []string{"zone"}, } - s.ruleManager.SetRule(rule) + suite.ruleManager.SetRule(rule) - region := s.cluster.GetRegion(1) - c.Assert(s.rc.Check(region), IsNil) + region := suite.cluster.GetRegion(1) + suite.Nil(suite.rc.Check(region)) - s.cluster.SetStoreOffline(4) - testutil.CheckTransferPeer(c, s.rc.Check(region), operator.OpRegion, 4, 5) + suite.cluster.SetStoreOffline(4) + testutil.CheckTransferPeerWithTestify(suite.Require(), suite.rc.Check(region), operator.OpRegion, 4, 5) - s.cluster.SetStoreOffline(5) - testutil.CheckTransferPeer(c, s.rc.Check(region), operator.OpRegion, 4, 2) + suite.cluster.SetStoreOffline(5) + testutil.CheckTransferPeerWithTestify(suite.Require(), suite.rc.Check(region), operator.OpRegion, 4, 2) rule.IsolationLevel = "zone" - s.ruleManager.SetRule(rule) - c.Assert(s.rc.Check(region), IsNil) + suite.ruleManager.SetRule(rule) + suite.Nil(suite.rc.Check(region)) } -func (s *testRuleCheckerSerialSuite) TestRuleCache(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"zone": 
"z1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z2"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z3"}) - s.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z3"}) - s.cluster.AddRegionStore(999, 1) - s.cluster.AddLeaderRegion(1, 1, 3, 4) +func (suite *ruleCheckerTestSuite) TestRuleCache() { + suite.cluster.PersistOptions.SetPlacementRulesCacheEnabled(true) + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z2"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z3"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z3"}) + suite.cluster.AddRegionStore(999, 1) + suite.cluster.AddLeaderRegion(1, 1, 3, 4) rule := &placement.Rule{ GroupID: "pd", ID: "test", @@ -625,10 +600,10 @@ func (s *testRuleCheckerSerialSuite) TestRuleCache(c *C) { Count: 3, LocationLabels: []string{"zone"}, } - s.ruleManager.SetRule(rule) - region := s.cluster.GetRegion(1) + suite.ruleManager.SetRule(rule) + region := suite.cluster.GetRegion(1) region = region.Clone(core.WithIncConfVer(), core.WithIncVersion()) - c.Assert(s.rc.Check(region), IsNil) + suite.Nil(suite.rc.Check(region)) testcases := []struct { name string @@ -669,35 +644,35 @@ func (s *testRuleCheckerSerialSuite) TestRuleCache(c *C) { }, } for _, testcase := range testcases { - c.Log(testcase.name) + suite.T().Log(testcase.name) if testcase.stillCached { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/checker/assertShouldCache", "return(true)"), IsNil) - s.rc.Check(testcase.region) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/schedule/checker/assertShouldCache"), IsNil) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/checker/assertShouldCache", "return(true)")) + suite.rc.Check(testcase.region) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/schedule/checker/assertShouldCache")) } else { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/checker/assertShouldNotCache", "return(true)"), IsNil) - s.rc.Check(testcase.region) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/schedule/checker/assertShouldNotCache"), IsNil) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/checker/assertShouldNotCache", "return(true)")) + suite.rc.Check(testcase.region) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/schedule/checker/assertShouldNotCache")) } } } // Ref https://github.com/tikv/pd/issues/4045 -func (s *testRuleCheckerSuite) TestSkipFixOrphanPeerIfSelectedPeerisPendingOrDown(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3, 4) +func (suite *ruleCheckerTestSuite) TestSkipFixOrphanPeerIfSelectedPeerisPendingOrDown() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3, 4) // set peer3 and peer4 to pending - r1 := s.cluster.GetRegion(1) + r1 := 
suite.cluster.GetRegion(1) r1 = r1.Clone(core.WithPendingPeers([]*metapb.Peer{r1.GetStorePeer(3), r1.GetStorePeer(4)})) - s.cluster.PutRegion(r1) + suite.cluster.PutRegion(r1) // should not remove extra peer - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) // set peer3 to down-peer r1 = r1.Clone(core.WithPendingPeers([]*metapb.Peer{r1.GetStorePeer(4)})) @@ -707,39 +682,39 @@ func (s *testRuleCheckerSuite) TestSkipFixOrphanPeerIfSelectedPeerisPendingOrDow DownSeconds: 42, }, })) - s.cluster.PutRegion(r1) + suite.cluster.PutRegion(r1) // should not remove extra peer - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) // set peer3 to normal r1 = r1.Clone(core.WithDownPeers(nil)) - s.cluster.PutRegion(r1) + suite.cluster.PutRegion(r1) // should remove extra peer now var remove operator.RemovePeer - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op.Step(0), FitsTypeOf, remove) - c.Assert(op.Desc(), Equals, "remove-orphan-peer") + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.IsType(remove, op.Step(0)) + suite.Equal("remove-orphan-peer", op.Desc()) } -func (s *testRuleCheckerSuite) TestPriorityFitHealthPeers(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3, 4) - r1 := s.cluster.GetRegion(1) +func (suite *ruleCheckerTestSuite) TestPriorityFitHealthPeers() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 3, 4) + r1 := suite.cluster.GetRegion(1) // set peer3 to pending r1 = r1.Clone(core.WithPendingPeers([]*metapb.Peer{r1.GetPeer(3)})) - s.cluster.PutRegion(r1) + suite.cluster.PutRegion(r1) var remove operator.RemovePeer - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op.Step(0), FitsTypeOf, remove) - c.Assert(op.Desc(), Equals, "remove-orphan-peer") + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.IsType(remove, op.Step(0)) + suite.Equal("remove-orphan-peer", op.Desc()) // set peer3 to down r1 = r1.Clone(core.WithDownPeers([]*pdpb.PeerStats{ @@ -749,18 +724,18 @@ func (s *testRuleCheckerSuite) TestPriorityFitHealthPeers(c *C) { }, })) r1 = r1.Clone(core.WithPendingPeers(nil)) - s.cluster.PutRegion(r1) + suite.cluster.PutRegion(r1) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op.Step(0), FitsTypeOf, remove) - c.Assert(op.Desc(), Equals, "remove-orphan-peer") + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.IsType(remove, op.Step(0)) + suite.Equal("remove-orphan-peer", op.Desc()) } // Ref https://github.com/tikv/pd/issues/4140 -func (s *testRuleCheckerSuite) TestDemoteVoter(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z4"}) - region := s.cluster.AddLeaderRegion(1, 1, 4) +func (suite *ruleCheckerTestSuite) TestDemoteVoter() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) + 
suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z4"}) + region := suite.cluster.AddLeaderRegion(1, 1, 4) rule := &placement.Rule{ GroupID: "pd", ID: "test", @@ -787,57 +762,57 @@ func (s *testRuleCheckerSuite) TestDemoteVoter(c *C) { }, }, } - s.ruleManager.SetRule(rule) - s.ruleManager.SetRule(rule2) - s.ruleManager.DeleteRule("pd", "default") - op := s.rc.Check(region) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "fix-demote-voter") + suite.ruleManager.SetRule(rule) + suite.ruleManager.SetRule(rule2) + suite.ruleManager.DeleteRule("pd", "default") + op := suite.rc.Check(region) + suite.NotNil(op) + suite.Equal("fix-demote-voter", op.Desc()) } -func (s *testRuleCheckerSuite) TestOfflineAndDownStore(c *C) { - s.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z4"}) - s.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1"}) - s.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z4"}) - region := s.cluster.AddLeaderRegion(1, 1, 2, 3) - op := s.rc.Check(region) - c.Assert(op, IsNil) +func (suite *ruleCheckerTestSuite) TestOfflineAndDownStore() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z4"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z4"}) + region := suite.cluster.AddLeaderRegion(1, 1, 2, 3) + op := suite.rc.Check(region) + suite.Nil(op) // assert rule checker should generate replace offline peer operator after cached - s.cluster.SetStoreOffline(1) - op = s.rc.Check(region) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "replace-rule-offline-peer") + suite.cluster.SetStoreOffline(1) + op = suite.rc.Check(region) + suite.NotNil(op) + suite.Equal("replace-rule-offline-peer", op.Desc()) // re-cache the regionFit - s.cluster.SetStoreUp(1) - op = s.rc.Check(region) - c.Assert(op, IsNil) + suite.cluster.SetStoreUp(1) + op = suite.rc.Check(region) + suite.Nil(op) // assert rule checker should generate replace down peer operator after cached - s.cluster.SetStoreDown(2) + suite.cluster.SetStoreDown(2) region = region.Clone(core.WithDownPeers([]*pdpb.PeerStats{{Peer: region.GetStorePeer(2), DownSeconds: 60000}})) - op = s.rc.Check(region) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "replace-rule-down-peer") + op = suite.rc.Check(region) + suite.NotNil(op) + suite.Equal("replace-rule-down-peer", op.Desc()) } -func (s *testRuleCheckerSuite) TestPendingList(c *C) { +func (suite *ruleCheckerTestSuite) TestPendingList() { // no enough store - s.cluster.AddLeaderStore(1, 1) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) - op := s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, IsNil) - _, exist := s.rc.pendingList.Get(1) - c.Assert(exist, IsTrue) + suite.cluster.AddLeaderStore(1, 1) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + _, exist := suite.rc.pendingList.Get(1) + suite.True(exist) // add more stores - s.cluster.AddLeaderStore(2, 1) - s.cluster.AddLeaderStore(3, 1) - op = s.rc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Desc(), Equals, "add-rule-peer") - c.Assert(op.GetPriorityLevel(), Equals, core.HighPriority) - c.Assert(op.Step(0).(operator.AddLearner).ToStore, Equals, uint64(3)) - _, exist = s.rc.pendingList.Get(1) - c.Assert(exist, IsFalse) + suite.cluster.AddLeaderStore(2, 1) + 
suite.cluster.AddLeaderStore(3, 1) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("add-rule-peer", op.Desc()) + suite.Equal(core.HighPriority, op.GetPriorityLevel()) + suite.Equal(uint64(3), op.Step(0).(operator.AddLearner).ToStore) + _, exist = suite.rc.pendingList.Get(1) + suite.False(exist) } diff --git a/server/schedule/checker/split_checker_test.go b/server/schedule/checker/split_checker_test.go index 606c5953762..957ca87bc07 100644 --- a/server/schedule/checker/split_checker_test.go +++ b/server/schedule/checker/split_checker_test.go @@ -17,8 +17,9 @@ package checker import ( "context" "encoding/hex" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/schedule/labeler" @@ -26,37 +27,18 @@ import ( "github.com/tikv/pd/server/schedule/placement" ) -var _ = Suite(&testSplitCheckerSuite{}) - -type testSplitCheckerSuite struct { - cluster *mockcluster.Cluster - ruleManager *placement.RuleManager - labeler *labeler.RegionLabeler - sc *SplitChecker - ctx context.Context - cancel context.CancelFunc -} - -func (s *testSplitCheckerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testSplitCheckerSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testSplitCheckerSuite) SetUpTest(c *C) { +func TestSplit(t *testing.T) { + re := require.New(t) cfg := config.NewTestOptions() cfg.GetReplicationConfig().EnablePlacementRules = true - s.cluster = mockcluster.NewCluster(s.ctx, cfg) - s.ruleManager = s.cluster.RuleManager - s.labeler = s.cluster.RegionLabeler - s.sc = NewSplitChecker(s.cluster, s.ruleManager, s.labeler) -} - -func (s *testSplitCheckerSuite) TestSplit(c *C) { - s.cluster.AddLeaderStore(1, 1) - s.ruleManager.SetRule(&placement.Rule{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster := mockcluster.NewCluster(ctx, cfg) + ruleManager := cluster.RuleManager + regionLabeler := cluster.RegionLabeler + sc := NewSplitChecker(cluster, ruleManager, regionLabeler) + cluster.AddLeaderStore(1, 1) + ruleManager.SetRule(&placement.Rule{ GroupID: "test", ID: "test", StartKeyHex: "aa", @@ -64,25 +46,25 @@ func (s *testSplitCheckerSuite) TestSplit(c *C) { Role: placement.Voter, Count: 1, }) - s.cluster.AddLeaderRegionWithRange(1, "", "", 1) - op := s.sc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Len(), Equals, 1) + cluster.AddLeaderRegionWithRange(1, "", "", 1) + op := sc.Check(cluster.GetRegion(1)) + re.NotNil(op) + re.Equal(1, op.Len()) splitKeys := op.Step(0).(operator.SplitRegion).SplitKeys - c.Assert(hex.EncodeToString(splitKeys[0]), Equals, "aa") - c.Assert(hex.EncodeToString(splitKeys[1]), Equals, "cc") + re.Equal("aa", hex.EncodeToString(splitKeys[0])) + re.Equal("cc", hex.EncodeToString(splitKeys[1])) // region label has higher priority. 
- s.labeler.SetLabelRule(&labeler.LabelRule{ + regionLabeler.SetLabelRule(&labeler.LabelRule{ ID: "test", Labels: []labeler.RegionLabel{{Key: "test", Value: "test"}}, RuleType: labeler.KeyRange, Data: makeKeyRanges("bb", "dd"), }) - op = s.sc.Check(s.cluster.GetRegion(1)) - c.Assert(op, NotNil) - c.Assert(op.Len(), Equals, 1) + op = sc.Check(cluster.GetRegion(1)) + re.NotNil(op) + re.Equal(1, op.Len()) splitKeys = op.Step(0).(operator.SplitRegion).SplitKeys - c.Assert(hex.EncodeToString(splitKeys[0]), Equals, "bb") - c.Assert(hex.EncodeToString(splitKeys[1]), Equals, "dd") + re.Equal("bb", hex.EncodeToString(splitKeys[0])) + re.Equal("dd", hex.EncodeToString(splitKeys[1])) } From ddf711bd5f9f39fb71d4fdfc78c11fabe041ee4e Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 20 Jun 2022 18:10:37 +0800 Subject: [PATCH 58/82] tests: testify the TSO tests (#5169) ref tikv/pd#4813 Testify the TSO tests. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- pkg/testutil/testutil.go | 21 +- tests/client/client_test.go | 4 +- tests/cluster.go | 26 +- tests/server/tso/allocator_test.go | 115 ++++----- tests/server/tso/common_test.go | 25 +- tests/server/tso/consistency_test.go | 345 ++++++++++++--------------- tests/server/tso/global_tso_test.go | 122 ++++------ tests/server/tso/manager_test.go | 94 ++++---- tests/server/tso/tso_test.go | 57 ++--- 9 files changed, 351 insertions(+), 458 deletions(-) diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index bc54e901a63..59063aa5385 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -26,8 +26,9 @@ import ( ) const ( - waitMaxRetry = 200 - waitRetrySleep = time.Millisecond * 100 + defaultWaitRetryTimes = 200 + defaultSleepInterval = time.Millisecond * 100 + defaultWaitFor = time.Second * 20 ) // CheckFunc is a condition checker that passed to WaitUntil. Its implementation @@ -38,6 +39,7 @@ type CheckFunc func() bool type WaitOp struct { retryTimes int sleepInterval time.Duration + waitFor time.Duration } // WaitOption configures WaitOp @@ -53,13 +55,18 @@ func WithSleepInterval(sleep time.Duration) WaitOption { return func(op *WaitOp) { op.sleepInterval = sleep } } +// WithWaitFor specify the max wait for duration +func WithWaitFor(waitFor time.Duration) WaitOption { + return func(op *WaitOp) { op.waitFor = waitFor } +} + // WaitUntil repeatedly evaluates f() for a period of time, util it returns true. // NOTICE: this function will be removed soon, please use `Eventually` instead. func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { c.Log("wait start") option := &WaitOp{ - retryTimes: waitMaxRetry, - sleepInterval: waitRetrySleep, + retryTimes: defaultWaitRetryTimes, + sleepInterval: defaultSleepInterval, } for _, opt := range opts { opt(option) @@ -76,15 +83,15 @@ func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { // Eventually asserts that given condition will be met in a period of time. 
func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOption) { option := &WaitOp{ - retryTimes: waitMaxRetry, - sleepInterval: waitRetrySleep, + waitFor: defaultWaitFor, + sleepInterval: defaultSleepInterval, } for _, opt := range opts { opt(option) } re.Eventually( condition, - option.sleepInterval*time.Duration(option.retryTimes), + option.waitFor, option.sleepInterval, ) } diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 975b54d72f8..86824261d12 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -235,7 +235,7 @@ func TestTSOAllocatorLeader(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) - cluster.WaitAllLeadersWithTestify(re, dcLocationConfig) + cluster.WaitAllLeaders(re, dcLocationConfig) var ( testServers = cluster.GetServers() @@ -347,7 +347,7 @@ func TestGlobalAndLocalTSO(t *testing.T) { re.NoError(err) dcLocationConfig["pd4"] = "dc-4" cluster.CheckClusterDCLocation() - cluster.WaitAllLeadersWithTestify(re, dcLocationConfig) + cluster.WaitAllLeaders(re, dcLocationConfig) // Test a nonexistent dc-location for Local TSO p, l, err := cli.GetLocalTS(context.TODO(), "nonexistent-dc") diff --git a/tests/cluster.go b/tests/cluster.go index 0d7efe90ec9..6c79d680c7f 100644 --- a/tests/cluster.go +++ b/tests/cluster.go @@ -22,7 +22,6 @@ import ( "time" "github.com/coreos/go-semver/semver" - "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" @@ -605,27 +604,7 @@ func (c *TestCluster) WaitAllocatorLeader(dcLocation string, ops ...WaitOption) } // WaitAllLeaders will block and wait for the election of PD leader and all Local TSO Allocator leaders. -func (c *TestCluster) WaitAllLeaders(testC *check.C, dcLocations map[string]string) { - c.WaitLeader() - c.CheckClusterDCLocation() - // Wait for each DC's Local TSO Allocator leader - wg := sync.WaitGroup{} - for _, dcLocation := range dcLocations { - wg.Add(1) - go func(dc string) { - testutil.WaitUntil(testC, func() bool { - leaderName := c.WaitAllocatorLeader(dc) - return leaderName != "" - }) - wg.Done() - }(dcLocation) - } - wg.Wait() -} - -// WaitAllLeadersWithTestify will block and wait for the election of PD leader and all Local TSO Allocator leaders. -// NOTICE: this is a temporary function that we will be used to replace `WaitAllLeaders` later. -func (c *TestCluster) WaitAllLeadersWithTestify(re *require.Assertions, dcLocations map[string]string) { +func (c *TestCluster) WaitAllLeaders(re *require.Assertions, dcLocations map[string]string) { c.WaitLeader() c.CheckClusterDCLocation() // Wait for each DC's Local TSO Allocator leader @@ -634,8 +613,7 @@ func (c *TestCluster) WaitAllLeadersWithTestify(re *require.Assertions, dcLocati wg.Add(1) go func(dc string) { testutil.Eventually(re, func() bool { - leaderName := c.WaitAllocatorLeader(dc) - return leaderName != "" + return c.WaitAllocatorLeader(dc) != "" }) wg.Done() }(dcLocation) diff --git a/tests/server/tso/allocator_test.go b/tests/server/tso/allocator_test.go index c7bb38e5d9a..59cedea0783 100644 --- a/tests/server/tso/allocator_test.go +++ b/tests/server/tso/allocator_test.go @@ -21,10 +21,11 @@ import ( "context" "strconv" "sync" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/slice" "github.com/tikv/pd/pkg/testutil" @@ -33,23 +34,10 @@ import ( "github.com/tikv/pd/tests" ) -var _ = Suite(&testAllocatorSuite{}) - -type testAllocatorSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testAllocatorSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testAllocatorSuite) TearDownSuite(c *C) { - s.cancel() -} - -// Make sure we have the correct number of Local TSO Allocator leaders. -func (s *testAllocatorSuite) TestAllocatorLeader(c *C) { +func TestAllocatorLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // There will be three Local TSO Allocator leaders elected dcLocationConfig := map[string]string{ "pd2": "dc-1", @@ -57,19 +45,16 @@ func (s *testAllocatorSuite) TestAllocatorLeader(c *C) { "pd6": "leader", /* Test dc-location name is same as the special key */ } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum*2, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum*2, func(conf *config.Config, serverName string) { if zoneLabel, ok := dcLocationConfig[serverName]; ok { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = zoneLabel } }) + re.NoError(err) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) - - cluster.WaitAllLeaders(c, dcLocationConfig) + re.NoError(cluster.RunInitialServers()) + cluster.WaitAllLeaders(re, dcLocationConfig) // To check whether we have enough Local TSO Allocator leaders allAllocatorLeaders := make([]tso.Allocator, 0, dcLocationNum) for _, server := range cluster.GetServers() { @@ -80,7 +65,7 @@ func (s *testAllocatorSuite) TestAllocatorLeader(c *C) { tso.FilterUninitialized()) // One PD server will have at most three initialized Local TSO Allocators, // which also means three allocator leaders - c.Assert(len(allocators), LessEqual, dcLocationNum) + re.LessOrEqual(len(allocators), dcLocationNum) if len(allocators) == 0 { continue } @@ -96,7 +81,7 @@ func (s *testAllocatorSuite) TestAllocatorLeader(c *C) { } // At the end, we should have three initialized Local TSO Allocator, // i.e., the Local TSO Allocator leaders for all dc-locations in testDCLocations - c.Assert(allAllocatorLeaders, HasLen, dcLocationNum) + re.Len(allAllocatorLeaders, dcLocationNum) allocatorLeaderMemberIDs := make([]uint64, 0, dcLocationNum) for _, allocator := range allAllocatorLeaders { allocatorLeader, _ := allocator.(*tso.LocalTSOAllocator) @@ -106,62 +91,63 @@ func (s *testAllocatorSuite) TestAllocatorLeader(c *C) { // Filter out Global TSO Allocator allocators := server.GetTSOAllocatorManager().GetAllocators(tso.FilterDCLocation(tso.GlobalDCLocation)) if _, ok := dcLocationConfig[server.GetServer().Name()]; !ok { - c.Assert(allocators, HasLen, 0) + re.Empty(allocators) continue } - c.Assert(allocators, HasLen, dcLocationNum) + re.Len(allocators, dcLocationNum) for _, allocator := range allocators { allocatorFollower, _ := allocator.(*tso.LocalTSOAllocator) allocatorFollowerMemberID := allocatorFollower.GetAllocatorLeader().GetMemberId() - c.Assert( + re.True( slice.AnyOf( allocatorLeaderMemberIDs, - func(i int) bool { return allocatorLeaderMemberIDs[i] == allocatorFollowerMemberID }), - 
IsTrue) + func(i int) bool { return allocatorLeaderMemberIDs[i] == allocatorFollowerMemberID }, + ), + ) } } } -func (s *testAllocatorSuite) TestPriorityAndDifferentLocalTSO(c *C) { +func TestPriorityAndDifferentLocalTSO(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) - - cluster.WaitAllLeaders(c, dcLocationConfig) + cluster.WaitAllLeaders(re, dcLocationConfig) // Wait for all nodes becoming healthy. time.Sleep(time.Second * 5) // Join a new dc-location - pd4, err := cluster.Join(s.ctx, func(conf *config.Config, serverName string) { + pd4, err := cluster.Join(ctx, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = "dc-4" }) - c.Assert(err, IsNil) - err = pd4.Run() - c.Assert(err, IsNil) + re.NoError(err) + re.NoError(pd4.Run()) dcLocationConfig["pd4"] = "dc-4" cluster.CheckClusterDCLocation() - testutil.WaitUntil(c, func() bool { - leaderName := cluster.WaitAllocatorLeader("dc-4") - return leaderName != "" - }) + re.NotEqual("", cluster.WaitAllocatorLeader( + "dc-4", + tests.WithRetryTimes(90), tests.WithWaitInterval(time.Second), + )) // Scatter the Local TSO Allocators to different servers waitAllocatorPriorityCheck(cluster) - cluster.WaitAllLeaders(c, dcLocationConfig) + cluster.WaitAllLeaders(re, dcLocationConfig) // Before the priority is checked, we may have allocators typology like this: // pd1: dc-1, dc-2 and dc-3 allocator leader @@ -178,10 +164,9 @@ func (s *testAllocatorSuite) TestPriorityAndDifferentLocalTSO(c *C) { for serverName, dcLocation := range dcLocationConfig { go func(serName, dc string) { defer wg.Done() - testutil.WaitUntil(c, func() bool { - leaderName := cluster.WaitAllocatorLeader(dc) - return leaderName == serName - }, testutil.WithRetryTimes(12), testutil.WithSleepInterval(5*time.Second)) + testutil.Eventually(re, func() bool { + return cluster.WaitAllocatorLeader(dc) == serName + }, testutil.WithWaitFor(time.Second*90), testutil.WithSleepInterval(time.Second)) }(serverName, dcLocation) } wg.Wait() @@ -189,12 +174,12 @@ func (s *testAllocatorSuite) TestPriorityAndDifferentLocalTSO(c *C) { for serverName, server := range cluster.GetServers() { tsoAllocatorManager := server.GetTSOAllocatorManager() localAllocatorLeaders, err := tsoAllocatorManager.GetHoldingLocalAllocatorLeaders() - c.Assert(err, IsNil) + re.NoError(err) for _, localAllocatorLeader := range localAllocatorLeaders { - s.testTSOSuffix(c, cluster, tsoAllocatorManager, localAllocatorLeader.GetDCLocation()) + testTSOSuffix(re, cluster, tsoAllocatorManager, localAllocatorLeader.GetDCLocation()) } if serverName == cluster.GetLeader() { - s.testTSOSuffix(c, cluster, tsoAllocatorManager, tso.GlobalDCLocation) + testTSOSuffix(re, cluster, tsoAllocatorManager, tso.GlobalDCLocation) } } } @@ -211,29 +196,29 @@ func waitAllocatorPriorityCheck(cluster *tests.TestCluster) { wg.Wait() } -func (s 
*testAllocatorSuite) testTSOSuffix(c *C, cluster *tests.TestCluster, am *tso.AllocatorManager, dcLocation string) { +func testTSOSuffix(re *require.Assertions, cluster *tests.TestCluster, am *tso.AllocatorManager, dcLocation string) { suffixBits := am.GetSuffixBits() - c.Assert(suffixBits, Greater, 0) + re.Greater(suffixBits, 0) var suffix int64 // The suffix of a Global TSO will always be 0 if dcLocation != tso.GlobalDCLocation { suffixResp, err := etcdutil.EtcdKVGet( cluster.GetEtcdClient(), am.GetLocalTSOSuffixPath(dcLocation)) - c.Assert(err, IsNil) - c.Assert(suffixResp.Kvs, HasLen, 1) + re.NoError(err) + re.Len(suffixResp.Kvs, 1) suffix, err = strconv.ParseInt(string(suffixResp.Kvs[0].Value), 10, 64) - c.Assert(err, IsNil) - c.Assert(suffixBits, GreaterEqual, tso.CalSuffixBits(int32(suffix))) + re.NoError(err) + re.GreaterOrEqual(suffixBits, tso.CalSuffixBits(int32(suffix))) } allocator, err := am.GetAllocator(dcLocation) - c.Assert(err, IsNil) + re.NoError(err) var tso pdpb.Timestamp - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { tso, err = allocator.GenerateTSO(1) - c.Assert(err, IsNil) + re.NoError(err) return tso.GetPhysical() != 0 }) // Test whether the TSO has the right suffix - c.Assert(suffix, Equals, tso.Logical&((1<>timestamp.GetSuffixBits(), GreaterEqual, req.GetCount()) + re.Greater(timestamp.GetPhysical(), int64(0)) + re.GreaterOrEqual(uint32(timestamp.GetLogical())>>timestamp.GetSuffixBits(), req.GetCount()) return timestamp } -func testGetTimestamp(c *C, ctx context.Context, pdCli pdpb.PDClient, req *pdpb.TsoRequest) *pdpb.Timestamp { +func testGetTimestamp(re *require.Assertions, ctx context.Context, pdCli pdpb.PDClient, req *pdpb.TsoRequest) *pdpb.Timestamp { tsoClient, err := pdCli.Tso(ctx) - c.Assert(err, IsNil) + re.NoError(err) defer tsoClient.CloseSend() - err = tsoClient.Send(req) - c.Assert(err, IsNil) + re.NoError(tsoClient.Send(req)) resp, err := tsoClient.Recv() - c.Assert(err, IsNil) - return checkAndReturnTimestampResponse(c, req, resp) -} - -func Test(t *testing.T) { - TestingT(t) + re.NoError(err) + return checkAndReturnTimestampResponse(re, req, resp) } func TestMain(m *testing.M) { diff --git a/tests/server/tso/consistency_test.go b/tests/server/tso/consistency_test.go index 170a1b4e9a8..430160cd5ac 100644 --- a/tests/server/tso/consistency_test.go +++ b/tests/server/tso/consistency_test.go @@ -20,11 +20,13 @@ package tso_test import ( "context" "sync" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/pkg/tsoutil" @@ -33,9 +35,8 @@ import ( "github.com/tikv/pd/tests" ) -var _ = Suite(&testTSOConsistencySuite{}) - -type testTSOConsistencySuite struct { +type tsoConsistencyTestSuite struct { + suite.Suite ctx context.Context cancel context.CancelFunc @@ -46,42 +47,44 @@ type testTSOConsistencySuite struct { tsPool map[uint64]struct{} } -func (s *testTSOConsistencySuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - s.dcClientMap = make(map[string]pdpb.PDClient) - s.tsPool = make(map[uint64]struct{}) +func TestTSOConsistencyTestSuite(t *testing.T) { + suite.Run(t, new(tsoConsistencyTestSuite)) +} + +func (suite *tsoConsistencyTestSuite) SetupSuite() { + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.dcClientMap = make(map[string]pdpb.PDClient) + suite.tsPool = make(map[uint64]struct{}) } -func (s *testTSOConsistencySuite) TearDownSuite(c *C) { - s.cancel() +func (suite *tsoConsistencyTestSuite) TearDownSuite() { + suite.cancel() } // TestNormalGlobalTSO is used to test the normal way of global TSO generation. -func (s *testTSOConsistencySuite) TestNormalGlobalTSO(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func (suite *tsoConsistencyTestSuite) TestNormalGlobalTSO() { + cluster, err := tests.NewTestCluster(suite.ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(suite.Require(), leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(clusterID), Count: uint32(tsoCount), DcLocation: tso.GlobalDCLocation, } - s.requestGlobalTSOConcurrently(c, grpcPDClient, req) + suite.requestGlobalTSOConcurrently(grpcPDClient, req) // Test Global TSO after the leader change leaderServer.GetServer().GetMember().ResetLeader() cluster.WaitLeader() - s.requestGlobalTSOConcurrently(c, grpcPDClient, req) + suite.requestGlobalTSOConcurrently(grpcPDClient, req) } -func (s *testTSOConsistencySuite) requestGlobalTSOConcurrently(c *C, grpcPDClient pdpb.PDClient, req *pdpb.TsoRequest) { +func (suite *tsoConsistencyTestSuite) requestGlobalTSOConcurrently(grpcPDClient pdpb.PDClient, req *pdpb.TsoRequest) { var wg sync.WaitGroup wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -92,9 +95,9 @@ func (s *testTSOConsistencySuite) requestGlobalTSOConcurrently(c *C, grpcPDClien Logical: 0, } for j := 0; j < tsoRequestRound; j++ { - ts := s.testGetNormalGlobalTimestamp(c, grpcPDClient, req) + ts := suite.testGetNormalGlobalTimestamp(grpcPDClient, req) // Check whether the TSO fallbacks - c.Assert(tsoutil.CompareTimestamp(ts, last), Equals, 1) + suite.Equal(1, tsoutil.CompareTimestamp(ts, last)) last = ts time.Sleep(10 * time.Millisecond) } @@ -103,49 +106,47 @@ func (s *testTSOConsistencySuite) requestGlobalTSOConcurrently(c *C, grpcPDClien wg.Wait() } -func (s *testTSOConsistencySuite) testGetNormalGlobalTimestamp(c *C, pdCli pdpb.PDClient, req 
*pdpb.TsoRequest) *pdpb.Timestamp { +func (suite *tsoConsistencyTestSuite) testGetNormalGlobalTimestamp(pdCli pdpb.PDClient, req *pdpb.TsoRequest) *pdpb.Timestamp { ctx, cancel := context.WithCancel(context.Background()) defer cancel() tsoClient, err := pdCli.Tso(ctx) - c.Assert(err, IsNil) + suite.NoError(err) defer tsoClient.CloseSend() - err = tsoClient.Send(req) - c.Assert(err, IsNil) + suite.NoError(tsoClient.Send(req)) resp, err := tsoClient.Recv() - c.Assert(err, IsNil) - c.Assert(resp.GetCount(), Equals, req.GetCount()) + suite.NoError(err) + suite.Equal(req.GetCount(), resp.GetCount()) res := resp.GetTimestamp() - c.Assert(res.GetPhysical(), Greater, int64(0)) - c.Assert(uint32(res.GetLogical())>>res.GetSuffixBits(), GreaterEqual, req.GetCount()) + suite.Greater(res.GetPhysical(), int64(0)) + suite.GreaterOrEqual(uint32(res.GetLogical())>>res.GetSuffixBits(), req.GetCount()) return res } // TestSynchronizedGlobalTSO is used to test the synchronized way of global TSO generation. -func (s *testTSOConsistencySuite) TestSynchronizedGlobalTSO(c *C) { +func (suite *tsoConsistencyTestSuite) TestSynchronizedGlobalTSO() { dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(suite.ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) - cluster.WaitAllLeaders(c, dcLocationConfig) + re := suite.Require() + cluster.WaitAllLeaders(re, dcLocationConfig) - s.leaderServer = cluster.GetServer(cluster.GetLeader()) - c.Assert(s.leaderServer, NotNil) - s.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(c, s.leaderServer.GetAddr()) + suite.leaderServer = cluster.GetServer(cluster.GetLeader()) + suite.NotNil(suite.leaderServer) + suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClientWithTestify(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { - pdName := s.leaderServer.GetAllocatorLeader(dcLocation).GetName() - s.dcClientMap[dcLocation] = testutil.MustNewGrpcClient(c, cluster.GetServer(pdName).GetAddr()) + pdName := suite.leaderServer.GetAllocatorLeader(dcLocation).GetName() + suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) } ctx, cancel := context.WithCancel(context.Background()) @@ -155,14 +156,14 @@ func (s *testTSOConsistencySuite) TestSynchronizedGlobalTSO(c *C) { // Get some local TSOs first oldLocalTSOs := make([]*pdpb.Timestamp, 0, dcLocationNum) for _, dcLocation := range dcLocationConfig { - localTSO := s.getTimestampByDC(ctx, c, cluster, tsoCount, dcLocation) + localTSO := suite.getTimestampByDC(ctx, cluster, dcLocation) oldLocalTSOs = append(oldLocalTSOs, localTSO) - c.Assert(tsoutil.CompareTimestamp(maxGlobalTSO, localTSO), Equals, -1) + suite.Equal(-1, tsoutil.CompareTimestamp(maxGlobalTSO, localTSO)) } // Get a global TSO then - globalTSO := s.getTimestampByDC(ctx, c, cluster, tsoCount, tso.GlobalDCLocation) + globalTSO := suite.getTimestampByDC(ctx, cluster, tso.GlobalDCLocation) for _, oldLocalTSO := range oldLocalTSOs { - c.Assert(tsoutil.CompareTimestamp(globalTSO, 
oldLocalTSO), Equals, 1) + suite.Equal(1, tsoutil.CompareTimestamp(globalTSO, oldLocalTSO)) } if tsoutil.CompareTimestamp(maxGlobalTSO, globalTSO) < 0 { maxGlobalTSO = globalTSO @@ -170,153 +171,147 @@ func (s *testTSOConsistencySuite) TestSynchronizedGlobalTSO(c *C) { // Get some local TSOs again newLocalTSOs := make([]*pdpb.Timestamp, 0, dcLocationNum) for _, dcLocation := range dcLocationConfig { - newLocalTSOs = append(newLocalTSOs, s.getTimestampByDC(ctx, c, cluster, tsoCount, dcLocation)) + newLocalTSOs = append(newLocalTSOs, suite.getTimestampByDC(ctx, cluster, dcLocation)) } for _, newLocalTSO := range newLocalTSOs { - c.Assert(tsoutil.CompareTimestamp(maxGlobalTSO, newLocalTSO), Equals, -1) + suite.Equal(-1, tsoutil.CompareTimestamp(maxGlobalTSO, newLocalTSO)) } } } -func (s *testTSOConsistencySuite) getTimestampByDC(ctx context.Context, c *C, cluster *tests.TestCluster, n uint32, dcLocation string) *pdpb.Timestamp { +func (suite *tsoConsistencyTestSuite) getTimestampByDC(ctx context.Context, cluster *tests.TestCluster, dcLocation string) *pdpb.Timestamp { req := &pdpb.TsoRequest{ - Header: testutil.NewRequestHeader(s.leaderServer.GetClusterID()), - Count: n, + Header: testutil.NewRequestHeader(suite.leaderServer.GetClusterID()), + Count: tsoCount, DcLocation: dcLocation, } - pdClient, ok := s.dcClientMap[dcLocation] - c.Assert(ok, IsTrue) - forwardedHost := cluster.GetServer(s.leaderServer.GetAllocatorLeader(dcLocation).GetName()).GetAddr() + pdClient, ok := suite.dcClientMap[dcLocation] + suite.True(ok) + forwardedHost := cluster.GetServer(suite.leaderServer.GetAllocatorLeader(dcLocation).GetName()).GetAddr() ctx = grpcutil.BuildForwardContext(ctx, forwardedHost) tsoClient, err := pdClient.Tso(ctx) - c.Assert(err, IsNil) + suite.NoError(err) defer tsoClient.CloseSend() - err = tsoClient.Send(req) - c.Assert(err, IsNil) + suite.NoError(tsoClient.Send(req)) resp, err := tsoClient.Recv() - c.Assert(err, IsNil) - return checkAndReturnTimestampResponse(c, req, resp) + suite.NoError(err) + return checkAndReturnTimestampResponse(suite.Require(), req, resp) } -func (s *testTSOConsistencySuite) TestSynchronizedGlobalTSOOverflow(c *C) { +func (suite *tsoConsistencyTestSuite) TestSynchronizedGlobalTSOOverflow() { dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(suite.ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) - cluster.WaitAllLeaders(c, dcLocationConfig) + re := suite.Require() + cluster.WaitAllLeaders(re, dcLocationConfig) - s.leaderServer = cluster.GetServer(cluster.GetLeader()) - c.Assert(s.leaderServer, NotNil) - s.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(c, s.leaderServer.GetAddr()) + suite.leaderServer = cluster.GetServer(cluster.GetLeader()) + suite.NotNil(suite.leaderServer) + suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClientWithTestify(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { - pdName := s.leaderServer.GetAllocatorLeader(dcLocation).GetName() - s.dcClientMap[dcLocation] = 
testutil.MustNewGrpcClient(c, cluster.GetServer(pdName).GetAddr()) + pdName := suite.leaderServer.GetAllocatorLeader(dcLocation).GetName() + suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/globalTSOOverflow", `return(true)`), IsNil) - s.getTimestampByDC(ctx, c, cluster, tsoCount, tso.GlobalDCLocation) - failpoint.Disable("github.com/tikv/pd/server/tso/globalTSOOverflow") + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/globalTSOOverflow", `return(true)`)) + suite.getTimestampByDC(ctx, cluster, tso.GlobalDCLocation) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/globalTSOOverflow")) } -func (s *testTSOConsistencySuite) TestLocalAllocatorLeaderChange(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/mockLocalAllocatorLeaderChange", `return(true)`), IsNil) - defer failpoint.Disable("github.com/tikv/pd/server/mockLocalAllocatorLeaderChange") +func (suite *tsoConsistencyTestSuite) TestLocalAllocatorLeaderChange() { + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/mockLocalAllocatorLeaderChange", `return(true)`)) dcLocationConfig := map[string]string{ "pd1": "dc-1", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(suite.ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) - cluster.WaitAllLeaders(c, dcLocationConfig) + re := suite.Require() + cluster.WaitAllLeaders(re, dcLocationConfig) - s.leaderServer = cluster.GetServer(cluster.GetLeader()) - c.Assert(s.leaderServer, NotNil) - s.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(c, s.leaderServer.GetAddr()) + suite.leaderServer = cluster.GetServer(cluster.GetLeader()) + suite.NotNil(suite.leaderServer) + suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClientWithTestify(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { - pdName := s.leaderServer.GetAllocatorLeader(dcLocation).GetName() - s.dcClientMap[dcLocation] = testutil.MustNewGrpcClient(c, cluster.GetServer(pdName).GetAddr()) + pdName := suite.leaderServer.GetAllocatorLeader(dcLocation).GetName() + suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - s.getTimestampByDC(ctx, c, cluster, tsoCount, tso.GlobalDCLocation) + suite.getTimestampByDC(ctx, cluster, tso.GlobalDCLocation) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/mockLocalAllocatorLeaderChange")) } -func (s *testTSOConsistencySuite) TestLocalTSO(c *C) { +func (suite *tsoConsistencyTestSuite) TestLocalTSO() { dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(suite.ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true 
conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) - - cluster.WaitAllLeaders(c, dcLocationConfig) - s.testTSO(c, cluster, dcLocationConfig, nil) + cluster.WaitAllLeaders(suite.Require(), dcLocationConfig) + suite.testTSO(cluster, dcLocationConfig, nil) } -func (s *testTSOConsistencySuite) checkTSOUnique(tso *pdpb.Timestamp) bool { - s.tsPoolMutex.Lock() - defer s.tsPoolMutex.Unlock() +func (suite *tsoConsistencyTestSuite) checkTSOUnique(tso *pdpb.Timestamp) bool { + suite.tsPoolMutex.Lock() + defer suite.tsPoolMutex.Unlock() ts := tsoutil.GenerateTS(tso) - if _, exist := s.tsPool[ts]; exist { + if _, exist := suite.tsPool[ts]; exist { return false } - s.tsPool[ts] = struct{}{} + suite.tsPool[ts] = struct{}{} return true } -func (s *testTSOConsistencySuite) TestLocalTSOAfterMemberChanged(c *C) { +func (suite *tsoConsistencyTestSuite) TestLocalTSOAfterMemberChanged() { dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(suite.ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) - cluster.WaitAllLeaders(c, dcLocationConfig) + re := suite.Require() + cluster.WaitAllLeaders(re, dcLocationConfig) leaderServer := cluster.GetServer(cluster.GetLeader()) - leaderCli := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + leaderCli := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(cluster.GetCluster().GetId()), Count: tsoCount, @@ -324,40 +319,40 @@ func (s *testTSOConsistencySuite) TestLocalTSOAfterMemberChanged(c *C) { } ctx, cancel := context.WithCancel(context.Background()) ctx = grpcutil.BuildForwardContext(ctx, leaderServer.GetAddr()) - previousTS := testGetTimestamp(c, ctx, leaderCli, req) + previousTS := testGetTimestamp(re, ctx, leaderCli, req) cancel() // Wait for all nodes becoming healthy. time.Sleep(time.Second * 5) // Mock the situation that the system time of PD nodes in dc-4 is slower than others. 
- c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/systemTimeSlow", `return(true)`), IsNil) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/systemTimeSlow", `return(true)`)) // Join a new dc-location - pd4, err := cluster.Join(s.ctx, func(conf *config.Config, serverName string) { + pd4, err := cluster.Join(suite.ctx, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = "dc-4" }) - c.Assert(err, IsNil) - err = pd4.Run() - c.Assert(err, IsNil) + suite.NoError(err) + suite.NoError(pd4.Run()) dcLocationConfig["pd4"] = "dc-4" cluster.CheckClusterDCLocation() - testutil.WaitUntil(c, func() bool { - leaderName := cluster.WaitAllocatorLeader("dc-4") - return leaderName != "" - }) - s.testTSO(c, cluster, dcLocationConfig, previousTS) + re.NotEqual("", cluster.WaitAllocatorLeader( + "dc-4", + tests.WithRetryTimes(90), tests.WithWaitInterval(time.Second), + )) + suite.testTSO(cluster, dcLocationConfig, previousTS) - failpoint.Disable("github.com/tikv/pd/server/tso/systemTimeSlow") + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/systemTimeSlow")) } -func (s *testTSOConsistencySuite) testTSO(c *C, cluster *tests.TestCluster, dcLocationConfig map[string]string, previousTS *pdpb.Timestamp) { +func (suite *tsoConsistencyTestSuite) testTSO(cluster *tests.TestCluster, dcLocationConfig map[string]string, previousTS *pdpb.Timestamp) { + re := suite.Require() leaderServer := cluster.GetServer(cluster.GetLeader()) dcClientMap := make(map[string]pdpb.PDClient) for _, dcLocation := range dcLocationConfig { pdName := leaderServer.GetAllocatorLeader(dcLocation).GetName() - dcClientMap[dcLocation] = testutil.MustNewGrpcClient(c, cluster.GetServer(pdName).GetAddr()) + dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) } var wg sync.WaitGroup @@ -381,68 +376,49 @@ func (s *testTSOConsistencySuite) testTSO(c *C, cluster *tests.TestCluster, dcLo } ctx, cancel := context.WithCancel(context.Background()) ctx = grpcutil.BuildForwardContext(ctx, cluster.GetServer(leaderServer.GetAllocatorLeader(dcLocation).GetName()).GetAddr()) - ts := testGetTimestamp(c, ctx, dcClientMap[dcLocation], req) + ts := testGetTimestamp(re, ctx, dcClientMap[dcLocation], req) cancel() lastTS := lastList[dcLocation] // Check whether the TSO fallbacks - c.Assert(tsoutil.CompareTimestamp(ts, lastTS), Equals, 1) + suite.Equal(1, tsoutil.CompareTimestamp(ts, lastTS)) if previousTS != nil { // Because we have a Global TSO synchronization, even though the system time // of the PD nodes in dc-4 is slower, its TSO will still be big enough. 
- c.Assert(tsoutil.CompareTimestamp(ts, previousTS), Equals, 1) + suite.Equal(1, tsoutil.CompareTimestamp(ts, previousTS)) } lastList[dcLocation] = ts // Check whether the TSO is not unique - c.Assert(s.checkTSOUnique(ts), IsTrue) + suite.True(suite.checkTSOUnique(ts)) } time.Sleep(10 * time.Millisecond) } }() } wg.Wait() - - failpoint.Disable("github.com/tikv/pd/server/tso/systemTimeSlow") -} - -var _ = Suite(&testFallbackTSOConsistencySuite{}) - -type testFallbackTSOConsistencySuite struct { - ctx context.Context - cancel context.CancelFunc - cluster *tests.TestCluster - grpcPDClient pdpb.PDClient - server *tests.TestServer } -func (s *testFallbackTSOConsistencySuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/fallBackSync", `return(true)`), IsNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/fallBackUpdate", `return(true)`), IsNil) +func TestFallbackTSOConsistency(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + re.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/fallBackSync", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/fallBackUpdate", `return(true)`)) var err error - s.cluster, err = tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) + defer cluster.Destroy() - err = s.cluster.RunInitialServers() - c.Assert(err, IsNil) - s.cluster.WaitLeader() + re.NoError(cluster.RunInitialServers()) + cluster.WaitLeader() - s.server = s.cluster.GetServer(s.cluster.GetLeader()) - s.grpcPDClient = testutil.MustNewGrpcClient(c, s.server.GetAddr()) - svr := s.server.GetServer() + server := cluster.GetServer(cluster.GetLeader()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, server.GetAddr()) + svr := server.GetServer() svr.Close() - failpoint.Disable("github.com/tikv/pd/server/tso/fallBackSync") - failpoint.Disable("github.com/tikv/pd/server/tso/fallBackUpdate") - err = svr.Run() - c.Assert(err, IsNil) - s.cluster.WaitLeader() -} - -func (s *testFallbackTSOConsistencySuite) TearDownSuite(c *C) { - s.cancel() - s.cluster.Destroy() -} - -func (s *testFallbackTSOConsistencySuite) TestFallbackTSOConsistency(c *C) { + re.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/fallBackSync")) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/fallBackUpdate")) + re.NoError(svr.Run()) + cluster.WaitLeader() var wg sync.WaitGroup wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -453,8 +429,23 @@ func (s *testFallbackTSOConsistencySuite) TestFallbackTSOConsistency(c *C) { Logical: 0, } for j := 0; j < tsoRequestRound; j++ { - ts := s.testGetTSO(c, 10) - c.Assert(tsoutil.CompareTimestamp(ts, last), Equals, 1) + clusterID := server.GetClusterID() + req := &pdpb.TsoRequest{ + Header: testutil.NewRequestHeader(clusterID), + Count: 10, + DcLocation: tso.GlobalDCLocation, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tsoClient, err := grpcPDClient.Tso(ctx) + re.NoError(err) + defer tsoClient.CloseSend() + re.NoError(tsoClient.Send(req)) + resp, err := tsoClient.Recv() + re.NoError(err) + ts := checkAndReturnTimestampResponse(re, req, resp) + re.Equal(1, tsoutil.CompareTimestamp(ts, last)) last = ts time.Sleep(10 * time.Millisecond) } @@ -462,23 +453,3 @@ func (s *testFallbackTSOConsistencySuite) TestFallbackTSOConsistency(c *C) { } 
wg.Wait() } - -func (s *testFallbackTSOConsistencySuite) testGetTSO(c *C, n uint32) *pdpb.Timestamp { - clusterID := s.server.GetClusterID() - req := &pdpb.TsoRequest{ - Header: testutil.NewRequestHeader(clusterID), - Count: n, - DcLocation: tso.GlobalDCLocation, - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tsoClient, err := s.grpcPDClient.Tso(ctx) - c.Assert(err, IsNil) - defer tsoClient.CloseSend() - err = tsoClient.Send(req) - c.Assert(err, IsNil) - resp, err := tsoClient.Recv() - c.Assert(err, IsNil) - return checkAndReturnTimestampResponse(c, req, resp) -} diff --git a/tests/server/tso/global_tso_test.go b/tests/server/tso/global_tso_test.go index 1086751fa08..795841b6830 100644 --- a/tests/server/tso/global_tso_test.go +++ b/tests/server/tso/global_tso_test.go @@ -19,13 +19,13 @@ package tso_test import ( "context" - "strings" "sync" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/tso" @@ -41,32 +41,19 @@ import ( // which will coordinate and synchronize a TSO with other Local TSO Allocator // leaders. -var _ = Suite(&testNormalGlobalTSOSuite{}) - -type testNormalGlobalTSOSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testNormalGlobalTSOSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testNormalGlobalTSOSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *testNormalGlobalTSOSuite) TestConcurrentlyReset(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func TestConcurrentlyReset(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) - c.Assert(leader, NotNil) + re.NotNil(leader) var wg sync.WaitGroup wg.Add(2) @@ -84,41 +71,41 @@ func (s *testNormalGlobalTSOSuite) TestConcurrentlyReset(c *C) { wg.Wait() } -func (s *testNormalGlobalTSOSuite) TestZeroTSOCount(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func TestZeroTSOCount(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(clusterID), DcLocation: tso.GlobalDCLocation, } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() tsoClient, err := grpcPDClient.Tso(ctx) - c.Assert(err, IsNil) + re.NoError(err) defer tsoClient.CloseSend() - err = tsoClient.Send(req) - c.Assert(err, IsNil) + re.NoError(tsoClient.Send(req)) _, err = tsoClient.Recv() - c.Assert(err, NotNil) + re.Error(err) } -func (s *testNormalGlobalTSOSuite) 
TestRequestFollower(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestRequestFollower(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() var followerServer *tests.TestServer @@ -127,79 +114,74 @@ func (s *testNormalGlobalTSOSuite) TestRequestFollower(c *C) { followerServer = s } } - c.Assert(followerServer, NotNil) + re.NotNil(followerServer) - grpcPDClient := testutil.MustNewGrpcClient(c, followerServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, followerServer.GetAddr()) clusterID := followerServer.GetClusterID() req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(clusterID), Count: 1, DcLocation: tso.GlobalDCLocation, } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() ctx = grpcutil.BuildForwardContext(ctx, followerServer.GetAddr()) tsoClient, err := grpcPDClient.Tso(ctx) - c.Assert(err, IsNil) + re.NoError(err) defer tsoClient.CloseSend() start := time.Now() - err = tsoClient.Send(req) - c.Assert(err, IsNil) + re.NoError(tsoClient.Send(req)) _, err = tsoClient.Recv() - c.Assert(err, NotNil) - c.Assert(strings.Contains(err.Error(), "generate timestamp failed"), IsTrue) + re.Error(err) + re.Contains(err.Error(), "generate timestamp failed") // Requesting follower should fail fast, or the unavailable time will be // too long. - c.Assert(time.Since(start), Less, time.Second) + re.Less(time.Since(start), time.Second) } // In some cases, when a TSO request arrives, the SyncTimestamp may not finish yet. // This test is used to simulate this situation and verify that the retry mechanism. 
-func (s *testNormalGlobalTSOSuite) TestDelaySyncTimestamp(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestDelaySyncTimestamp(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() var leaderServer, nextLeaderServer *tests.TestServer leaderServer = cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer, NotNil) + re.NotNil(leaderServer) for _, s := range cluster.GetServers() { if s.GetConfig().Name != cluster.GetLeader() { nextLeaderServer = s } } - c.Assert(nextLeaderServer, NotNil) + re.NotNil(nextLeaderServer) - grpcPDClient := testutil.MustNewGrpcClient(c, nextLeaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, nextLeaderServer.GetAddr()) clusterID := nextLeaderServer.GetClusterID() req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(clusterID), Count: 1, DcLocation: tso.GlobalDCLocation, } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`)) // Make the old leader resign and wait for the new leader to get a lease leaderServer.ResignLeader() - c.Assert(nextLeaderServer.WaitLeader(), IsTrue) + re.True(nextLeaderServer.WaitLeader()) ctx = grpcutil.BuildForwardContext(ctx, nextLeaderServer.GetAddr()) tsoClient, err := grpcPDClient.Tso(ctx) - c.Assert(err, IsNil) + re.NoError(err) defer tsoClient.CloseSend() - err = tsoClient.Send(req) - c.Assert(err, IsNil) + re.NoError(tsoClient.Send(req)) resp, err := tsoClient.Recv() - c.Assert(err, IsNil) - c.Assert(checkAndReturnTimestampResponse(c, req, resp), NotNil) - failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp") + re.NoError(err) + re.NotNil(checkAndReturnTimestampResponse(re, req, resp)) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp")) } diff --git a/tests/server/tso/manager_test.go b/tests/server/tso/manager_test.go index 26fa07cc1d5..5ea8bc4be92 100644 --- a/tests/server/tso/manager_test.go +++ b/tests/server/tso/manager_test.go @@ -20,10 +20,11 @@ package tso_test import ( "context" "strconv" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" @@ -32,24 +33,12 @@ import ( "go.etcd.io/etcd/clientv3" ) -var _ = Suite(&testManagerSuite{}) - -type testManagerSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testManagerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testManagerSuite) TearDownSuite(c *C) { - s.cancel() -} - // TestClusterDCLocations will write different dc-locations to each server // and test whether we can get the whole dc-location config from each server. 
-func (s *testManagerSuite) TestClusterDCLocations(c *C) { +func TestClusterDCLocations(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() testCase := struct { dcLocationNumber int dcLocationConfig map[string]string @@ -65,17 +54,15 @@ func (s *testManagerSuite) TestClusterDCLocations(c *C) { }, } serverNumber := len(testCase.dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, serverNumber, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, serverNumber, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = testCase.dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) - - cluster.WaitAllLeaders(c, testCase.dcLocationConfig) + cluster.WaitAllLeaders(re, testCase.dcLocationConfig) serverNameMap := make(map[uint64]string) for _, server := range cluster.GetServers() { serverNameMap[server.GetServerID()] = server.GetServer().Name() @@ -86,21 +73,24 @@ func (s *testManagerSuite) TestClusterDCLocations(c *C) { for _, server := range cluster.GetServers() { obtainedServerNumber := 0 dcLocationMap := server.GetTSOAllocatorManager().GetClusterDCLocations() - c.Assert(err, IsNil) - c.Assert(dcLocationMap, HasLen, testCase.dcLocationNumber) + re.NoError(err) + re.Len(dcLocationMap, testCase.dcLocationNumber) for obtainedDCLocation, info := range dcLocationMap { obtainedServerNumber += len(info.ServerIDs) for _, serverID := range info.ServerIDs { expectedDCLocation, exist := testCase.dcLocationConfig[serverNameMap[serverID]] - c.Assert(exist, IsTrue) - c.Assert(obtainedDCLocation, Equals, expectedDCLocation) + re.True(exist) + re.Equal(expectedDCLocation, obtainedDCLocation) } } - c.Assert(obtainedServerNumber, Equals, serverNumber) + re.Equal(serverNumber, obtainedServerNumber) } } -func (s *testManagerSuite) TestLocalTSOSuffix(c *C) { +func TestLocalTSOSuffix(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() testCase := struct { dcLocations []string dcLocationConfig map[string]string @@ -116,44 +106,45 @@ func (s *testManagerSuite) TestLocalTSOSuffix(c *C) { }, } serverNumber := len(testCase.dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, serverNumber, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, serverNumber, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = testCase.dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) - cluster.WaitAllLeaders(c, testCase.dcLocationConfig) + cluster.WaitAllLeaders(re, testCase.dcLocationConfig) tsoAllocatorManager := cluster.GetServer("pd1").GetTSOAllocatorManager() for _, dcLocation := range testCase.dcLocations { suffixResp, err := etcdutil.EtcdKVGet( cluster.GetEtcdClient(), tsoAllocatorManager.GetLocalTSOSuffixPath(dcLocation)) - c.Assert(err, IsNil) - c.Assert(suffixResp.Kvs, HasLen, 1) + re.NoError(err) + re.Len(suffixResp.Kvs, 1) // Test the increment of the suffix allSuffixResp, err := etcdutil.EtcdKVGet( cluster.GetEtcdClient(), tsoAllocatorManager.GetLocalTSOSuffixPathPrefix(), clientv3.WithPrefix(), 
clientv3.WithSort(clientv3.SortByValue, clientv3.SortAscend)) - c.Assert(err, IsNil) - c.Assert(len(allSuffixResp.Kvs), Equals, len(testCase.dcLocations)) + re.NoError(err) + re.Equal(len(testCase.dcLocations), len(allSuffixResp.Kvs)) var lastSuffixNum int64 for _, kv := range allSuffixResp.Kvs { suffixNum, err := strconv.ParseInt(string(kv.Value), 10, 64) - c.Assert(err, IsNil) - c.Assert(suffixNum, Greater, lastSuffixNum) + re.NoError(err) + re.Greater(suffixNum, lastSuffixNum) lastSuffixNum = suffixNum } } } -func (s *testManagerSuite) TestNextLeaderKey(c *C) { +func TestNextLeaderKey(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() tso.PriorityCheck = 5 * time.Second defer func() { tso.PriorityCheck = 1 * time.Minute @@ -163,32 +154,31 @@ func (s *testManagerSuite) TestNextLeaderKey(c *C) { "pd2": "dc-1", } serverNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, serverNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, serverNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/injectNextLeaderKey", "return(true)"), IsNil) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/injectNextLeaderKey", "return(true)")) + re.NoError(cluster.RunInitialServers()) cluster.WaitLeader(tests.WithWaitInterval(5*time.Second), tests.WithRetryTimes(3)) // To speed up the test, we force to do the check cluster.CheckClusterDCLocation() originName := cluster.WaitAllocatorLeader("dc-1", tests.WithRetryTimes(5), tests.WithWaitInterval(5*time.Second)) - c.Assert(originName, Equals, "") - c.Assert(failpoint.Disable("github.com/tikv/pd/server/tso/injectNextLeaderKey"), IsNil) + re.Equal("", originName) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/injectNextLeaderKey")) cluster.CheckClusterDCLocation() originName = cluster.WaitAllocatorLeader("dc-1") - c.Assert(originName, Not(Equals), "") + re.NotEqual("", originName) for name, server := range cluster.GetServers() { if name == originName { continue } err := server.GetTSOAllocatorManager().TransferAllocatorForDCLocation("dc-1", server.GetServer().GetMember().ID()) - c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { + re.NoError(err) + testutil.Eventually(re, func() bool { cluster.CheckClusterDCLocation() currName := cluster.WaitAllocatorLeader("dc-1") return currName == name diff --git a/tests/server/tso/tso_test.go b/tests/server/tso/tso_test.go index 27bc53d5652..8cb6b6d837f 100644 --- a/tests/server/tso/tso_test.go +++ b/tests/server/tso/tso_test.go @@ -19,81 +19,66 @@ package tso_test import ( "context" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" ) -var _ = Suite(&testTSOSuite{}) - -type testTSOSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testTSOSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testTSOSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *testTSOSuite) TestLoadTimestamp(c *C) { +func TestLoadTimestamp(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) - - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) - cluster.WaitAllLeaders(c, dcLocationConfig) + cluster.WaitAllLeaders(re, dcLocationConfig) - lastTSMap := requestLocalTSOs(c, cluster, dcLocationConfig) + lastTSMap := requestLocalTSOs(re, cluster, dcLocationConfig) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/systemTimeSlow", `return(true)`), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/systemTimeSlow", `return(true)`)) // Reboot the cluster. - err = cluster.StopAll() - c.Assert(err, IsNil) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(cluster.StopAll()) + re.NoError(cluster.RunInitialServers()) - cluster.WaitAllLeaders(c, dcLocationConfig) + cluster.WaitAllLeaders(re, dcLocationConfig) // Re-request the Local TSOs. - newTSMap := requestLocalTSOs(c, cluster, dcLocationConfig) + newTSMap := requestLocalTSOs(re, cluster, dcLocationConfig) for dcLocation, newTS := range newTSMap { lastTS, ok := lastTSMap[dcLocation] - c.Assert(ok, IsTrue) + re.True(ok) // The new physical time of TSO should be larger even if the system time is slow. 
- c.Assert(newTS.GetPhysical()-lastTS.GetPhysical(), Greater, int64(0)) + re.Greater(newTS.GetPhysical()-lastTS.GetPhysical(), int64(0)) } failpoint.Disable("github.com/tikv/pd/server/tso/systemTimeSlow") } -func requestLocalTSOs(c *C, cluster *tests.TestCluster, dcLocationConfig map[string]string) map[string]*pdpb.Timestamp { +func requestLocalTSOs(re *require.Assertions, cluster *tests.TestCluster, dcLocationConfig map[string]string) map[string]*pdpb.Timestamp { dcClientMap := make(map[string]pdpb.PDClient) tsMap := make(map[string]*pdpb.Timestamp) leaderServer := cluster.GetServer(cluster.GetLeader()) for _, dcLocation := range dcLocationConfig { pdName := leaderServer.GetAllocatorLeader(dcLocation).GetName() - dcClientMap[dcLocation] = testutil.MustNewGrpcClient(c, cluster.GetServer(pdName).GetAddr()) + dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) } for _, dcLocation := range dcLocationConfig { req := &pdpb.TsoRequest{ @@ -103,7 +88,7 @@ func requestLocalTSOs(c *C, cluster *tests.TestCluster, dcLocationConfig map[str } ctx, cancel := context.WithCancel(context.Background()) ctx = grpcutil.BuildForwardContext(ctx, cluster.GetServer(leaderServer.GetAllocatorLeader(dcLocation).GetName()).GetAddr()) - tsMap[dcLocation] = testGetTimestamp(c, ctx, dcClientMap[dcLocation], req) + tsMap[dcLocation] = testGetTimestamp(re, ctx, dcClientMap[dcLocation], req) cancel() } return tsMap From e0f5b49af3c44417e0fcfee02dd1a3e48f17bb8f Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 21 Jun 2022 12:02:37 +0800 Subject: [PATCH 59/82] operator: migrate test framework to testify (#5191) ref tikv/pd#4813 Signed-off-by: lhy1024 Co-authored-by: Ti Chi Robot --- server/schedule/operator/builder_test.go | 237 +++++++++--------- server/schedule/operator/status_test.go | 17 +- .../schedule/operator/status_tracker_test.go | 122 ++++----- server/schedule/operator/step_test.go | 130 +++++----- 4 files changed, 258 insertions(+), 248 deletions(-) diff --git a/server/schedule/operator/builder_test.go b/server/schedule/operator/builder_test.go index 22fac017ec0..ed8e6b88ffe 100644 --- a/server/schedule/operator/builder_test.go +++ b/server/schedule/operator/builder_test.go @@ -16,132 +16,137 @@ package operator import ( "context" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testBuilderSuite{}) +type operatorBuilderTestSuite struct { + suite.Suite -type testBuilderSuite struct { cluster *mockcluster.Cluster ctx context.Context cancel context.CancelFunc } -func (s *testBuilderSuite) SetUpTest(c *C) { +func TestOperatorBuilderTestSuite(t *testing.T) { + suite.Run(t, new(operatorBuilderTestSuite)) +} + +func (suite *operatorBuilderTestSuite) SetupTest() { opts := config.NewTestOptions() - s.ctx, s.cancel = context.WithCancel(context.Background()) - s.cluster = mockcluster.NewCluster(s.ctx, opts) - s.cluster.SetLabelPropertyConfig(config.LabelPropertyConfig{ + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.cluster = mockcluster.NewCluster(suite.ctx, opts) + suite.cluster.SetLabelPropertyConfig(config.LabelPropertyConfig{ config.RejectLeader: {{Key: "noleader", Value: "true"}}, }) - s.cluster.SetLocationLabels([]string{"zone", "host"}) - s.cluster.AddLabelsStore(1, 0, map[string]string{"zone": "z1", "host": "h1"}) - s.cluster.AddLabelsStore(2, 0, map[string]string{"zone": "z1", "host": "h1"}) - s.cluster.AddLabelsStore(3, 0, map[string]string{"zone": "z1", "host": "h1"}) - s.cluster.AddLabelsStore(4, 0, map[string]string{"zone": "z1", "host": "h1"}) - s.cluster.AddLabelsStore(5, 0, map[string]string{"zone": "z1", "host": "h1"}) - s.cluster.AddLabelsStore(6, 0, map[string]string{"zone": "z1", "host": "h2"}) - s.cluster.AddLabelsStore(7, 0, map[string]string{"zone": "z1", "host": "h2"}) - s.cluster.AddLabelsStore(8, 0, map[string]string{"zone": "z2", "host": "h1"}) - s.cluster.AddLabelsStore(9, 0, map[string]string{"zone": "z2", "host": "h2"}) - s.cluster.AddLabelsStore(10, 0, map[string]string{"zone": "z3", "host": "h1", "noleader": "true"}) + suite.cluster.SetLocationLabels([]string{"zone", "host"}) + suite.cluster.AddLabelsStore(1, 0, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(2, 0, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(3, 0, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(4, 0, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(5, 0, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(6, 0, map[string]string{"zone": "z1", "host": "h2"}) + suite.cluster.AddLabelsStore(7, 0, map[string]string{"zone": "z1", "host": "h2"}) + suite.cluster.AddLabelsStore(8, 0, map[string]string{"zone": "z2", "host": "h1"}) + suite.cluster.AddLabelsStore(9, 0, map[string]string{"zone": "z2", "host": "h2"}) + suite.cluster.AddLabelsStore(10, 0, map[string]string{"zone": "z3", "host": "h1", "noleader": "true"}) } -func (s *testBuilderSuite) TearDownTest(c *C) { - s.cancel() +func (suite *operatorBuilderTestSuite) TearDownTest() { + suite.cancel() } -func (s *testBuilderSuite) TestNewBuilder(c *C) { +func (suite *operatorBuilderTestSuite) TestNewBuilder() { peers := []*metapb.Peer{{Id: 11, StoreId: 1}, {Id: 12, StoreId: 2, Role: metapb.PeerRole_Learner}} region := core.NewRegionInfo(&metapb.Region{Id: 42, Peers: peers}, peers[0]) - builder := NewBuilder("test", s.cluster, region) - c.Assert(builder.err, IsNil) - c.Assert(builder.originPeers, HasLen, 2) - c.Assert(builder.originPeers[1], DeepEquals, peers[0]) - 
c.Assert(builder.originPeers[2], DeepEquals, peers[1]) - c.Assert(builder.originLeaderStoreID, Equals, uint64(1)) - c.Assert(builder.targetPeers, HasLen, 2) - c.Assert(builder.targetPeers[1], DeepEquals, peers[0]) - c.Assert(builder.targetPeers[2], DeepEquals, peers[1]) + builder := NewBuilder("test", suite.cluster, region) + suite.NoError(builder.err) + suite.Len(builder.originPeers, 2) + suite.Equal(peers[0], builder.originPeers[1]) + suite.Equal(peers[1], builder.originPeers[2]) + suite.Equal(uint64(1), builder.originLeaderStoreID) + suite.Len(builder.targetPeers, 2) + suite.Equal(peers[0], builder.targetPeers[1]) + suite.Equal(peers[1], builder.targetPeers[2]) region = region.Clone(core.WithLeader(nil)) - builder = NewBuilder("test", s.cluster, region) - c.Assert(builder.err, NotNil) + builder = NewBuilder("test", suite.cluster, region) + suite.Error(builder.err) } -func (s *testBuilderSuite) newBuilder() *Builder { +func (suite *operatorBuilderTestSuite) newBuilder() *Builder { peers := []*metapb.Peer{ {Id: 11, StoreId: 1}, {Id: 12, StoreId: 2}, {Id: 13, StoreId: 3, Role: metapb.PeerRole_Learner}, } region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers}, peers[0]) - return NewBuilder("test", s.cluster, region) + return NewBuilder("test", suite.cluster, region) } -func (s *testBuilderSuite) TestRecord(c *C) { - c.Assert(s.newBuilder().AddPeer(&metapb.Peer{StoreId: 1}).err, NotNil) - c.Assert(s.newBuilder().AddPeer(&metapb.Peer{StoreId: 4}).err, IsNil) - c.Assert(s.newBuilder().PromoteLearner(1).err, NotNil) - c.Assert(s.newBuilder().PromoteLearner(3).err, IsNil) - c.Assert(s.newBuilder().SetLeader(1).SetLeader(2).err, IsNil) - c.Assert(s.newBuilder().SetLeader(3).err, NotNil) - c.Assert(s.newBuilder().RemovePeer(4).err, NotNil) - c.Assert(s.newBuilder().AddPeer(&metapb.Peer{StoreId: 4, Role: metapb.PeerRole_Learner}).RemovePeer(4).err, IsNil) - c.Assert(s.newBuilder().SetLeader(2).RemovePeer(2).err, NotNil) - c.Assert(s.newBuilder().PromoteLearner(4).err, NotNil) - c.Assert(s.newBuilder().SetLeader(4).err, NotNil) - c.Assert(s.newBuilder().SetPeers(map[uint64]*metapb.Peer{2: {Id: 2}}).err, NotNil) +func (suite *operatorBuilderTestSuite) TestRecord() { + suite.Error(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 1}).err) + suite.NoError(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 4}).err) + suite.Error(suite.newBuilder().PromoteLearner(1).err) + suite.NoError(suite.newBuilder().PromoteLearner(3).err) + suite.NoError(suite.newBuilder().SetLeader(1).SetLeader(2).err) + suite.Error(suite.newBuilder().SetLeader(3).err) + suite.Error(suite.newBuilder().RemovePeer(4).err) + suite.NoError(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 4, Role: metapb.PeerRole_Learner}).RemovePeer(4).err) + suite.Error(suite.newBuilder().SetLeader(2).RemovePeer(2).err) + suite.Error(suite.newBuilder().PromoteLearner(4).err) + suite.Error(suite.newBuilder().SetLeader(4).err) + suite.Error(suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{2: {Id: 2}}).err) m := map[uint64]*metapb.Peer{ 2: {StoreId: 2}, 3: {StoreId: 3, Role: metapb.PeerRole_Learner}, 4: {StoreId: 4}, } - builder := s.newBuilder().SetPeers(m).EnableLightWeight() - c.Assert(builder.targetPeers, HasLen, 3) - c.Assert(builder.targetPeers[2], DeepEquals, m[2]) - c.Assert(builder.targetPeers[3], DeepEquals, m[3]) - c.Assert(builder.targetPeers[4], DeepEquals, m[4]) - c.Assert(builder.targetLeaderStoreID, Equals, uint64(0)) - c.Assert(builder.lightWeight, IsTrue) + builder := suite.newBuilder().SetPeers(m).EnableLightWeight() + 
suite.Len(builder.targetPeers, 3) + suite.Equal(m[2], builder.targetPeers[2]) + suite.Equal(m[3], builder.targetPeers[3]) + suite.Equal(m[4], builder.targetPeers[4]) + suite.Equal(uint64(0), builder.targetLeaderStoreID) + suite.True(builder.lightWeight) } -func (s *testBuilderSuite) TestPrepareBuild(c *C) { +func (suite *operatorBuilderTestSuite) TestPrepareBuild() { // no voter. - _, err := s.newBuilder().SetPeers(map[uint64]*metapb.Peer{4: {StoreId: 4, Role: metapb.PeerRole_Learner}}).prepareBuild() - c.Assert(err, NotNil) + _, err := suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{4: {StoreId: 4, Role: metapb.PeerRole_Learner}}).prepareBuild() + suite.Error(err) // use joint consensus - builder := s.newBuilder().SetPeers(map[uint64]*metapb.Peer{ + builder := suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{ 1: {StoreId: 1, Role: metapb.PeerRole_Learner}, 3: {StoreId: 3}, 4: {StoreId: 4, Id: 14}, 5: {StoreId: 5, Role: metapb.PeerRole_Learner}, }) _, err = builder.prepareBuild() - c.Assert(err, IsNil) - c.Assert(builder.toAdd, HasLen, 2) - c.Assert(builder.toAdd[4].GetRole(), Not(Equals), metapb.PeerRole_Learner) - c.Assert(builder.toAdd[4].GetId(), Equals, uint64(14)) - c.Assert(builder.toAdd[5].GetRole(), Equals, metapb.PeerRole_Learner) - c.Assert(builder.toAdd[5].GetId(), Not(Equals), uint64(0)) - c.Assert(builder.toRemove, HasLen, 1) - c.Assert(builder.toRemove[2], NotNil) - c.Assert(builder.toPromote, HasLen, 1) - c.Assert(builder.toPromote[3], NotNil) - c.Assert(builder.toDemote, HasLen, 1) - c.Assert(builder.toDemote[1], NotNil) - c.Assert(builder.currentLeaderStoreID, Equals, uint64(1)) + suite.NoError(err) + suite.Len(builder.toAdd, 2) + suite.NotEqual(metapb.PeerRole_Learner, builder.toAdd[4].GetRole()) + suite.Equal(uint64(14), builder.toAdd[4].GetId()) + suite.Equal(metapb.PeerRole_Learner, builder.toAdd[5].GetRole()) + suite.NotEqual(uint64(0), builder.toAdd[5].GetId()) + suite.Len(builder.toRemove, 1) + suite.NotNil(builder.toRemove[2]) + suite.Len(builder.toPromote, 1) + suite.NotNil(builder.toPromote[3]) + suite.Len(builder.toDemote, 1) + suite.NotNil(builder.toDemote[1]) + suite.Equal(uint64(1), builder.currentLeaderStoreID) // do not use joint consensus - builder = s.newBuilder().SetPeers(map[uint64]*metapb.Peer{ + builder = suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{ 1: {StoreId: 1, Role: metapb.PeerRole_Learner}, 2: {StoreId: 2}, 3: {StoreId: 3}, @@ -150,22 +155,22 @@ func (s *testBuilderSuite) TestPrepareBuild(c *C) { }) builder.useJointConsensus = false _, err = builder.prepareBuild() - c.Assert(err, IsNil) - c.Assert(builder.toAdd, HasLen, 3) - c.Assert(builder.toAdd[1].GetRole(), Equals, metapb.PeerRole_Learner) - c.Assert(builder.toAdd[1].GetId(), Not(Equals), uint64(0)) - c.Assert(builder.toAdd[4].GetRole(), Not(Equals), metapb.PeerRole_Learner) - c.Assert(builder.toAdd[4].GetId(), Equals, uint64(14)) - c.Assert(builder.toAdd[5].GetRole(), Equals, metapb.PeerRole_Learner) - c.Assert(builder.toAdd[5].GetId(), Not(Equals), uint64(0)) - c.Assert(builder.toRemove, HasLen, 1) - c.Assert(builder.toRemove[1], NotNil) - c.Assert(builder.toPromote, HasLen, 1) - c.Assert(builder.toPromote[3], NotNil) - c.Assert(builder.currentLeaderStoreID, Equals, uint64(1)) + suite.NoError(err) + suite.Len(builder.toAdd, 3) + suite.Equal(metapb.PeerRole_Learner, builder.toAdd[1].GetRole()) + suite.NotEqual(uint64(0), builder.toAdd[1].GetId()) + suite.NotEqual(metapb.PeerRole_Learner, builder.toAdd[4].GetRole()) + suite.Equal(uint64(14), builder.toAdd[4].GetId()) + 
suite.Equal(metapb.PeerRole_Learner, builder.toAdd[5].GetRole()) + suite.NotEqual(uint64(0), builder.toAdd[5].GetId()) + suite.Len(builder.toRemove, 1) + suite.NotNil(builder.toRemove[1]) + suite.Len(builder.toPromote, 1) + suite.NotNil(builder.toPromote[3]) + suite.Equal(uint64(1), builder.currentLeaderStoreID) } -func (s *testBuilderSuite) TestBuild(c *C) { +func (suite *operatorBuilderTestSuite) TestBuild() { type testCase struct { name string useJointConsensus bool @@ -530,9 +535,9 @@ func (s *testBuilderSuite) TestBuild(c *C) { } for _, tc := range cases { - c.Log(tc.name) + suite.T().Log(tc.name) region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: tc.originPeers}, tc.originPeers[0]) - builder := NewBuilder("test", s.cluster, region) + builder := NewBuilder("test", suite.cluster, region) builder.useJointConsensus = tc.useJointConsensus m := make(map[uint64]*metapb.Peer) for _, p := range tc.targetPeers { @@ -541,71 +546,69 @@ func (s *testBuilderSuite) TestBuild(c *C) { builder.SetPeers(m).SetLeader(tc.targetPeers[0].GetStoreId()) op, err := builder.Build(0) if len(tc.steps) == 0 { - c.Assert(err, NotNil) + suite.Error(err) continue } - c.Assert(err, IsNil) - c.Assert(op.Kind(), Equals, tc.kind) - c.Assert(op.Len(), Equals, len(tc.steps)) + suite.NoError(err) + suite.Equal(tc.kind, op.Kind()) + suite.Len(tc.steps, op.Len()) for i := 0; i < op.Len(); i++ { switch step := op.Step(i).(type) { case TransferLeader: - c.Assert(step.FromStore, Equals, tc.steps[i].(TransferLeader).FromStore) - c.Assert(step.ToStore, Equals, tc.steps[i].(TransferLeader).ToStore) + suite.Equal(tc.steps[i].(TransferLeader).FromStore, step.FromStore) + suite.Equal(tc.steps[i].(TransferLeader).ToStore, step.ToStore) case AddPeer: - c.Assert(step.ToStore, Equals, tc.steps[i].(AddPeer).ToStore) + suite.Equal(tc.steps[i].(AddPeer).ToStore, step.ToStore) case RemovePeer: - c.Assert(step.FromStore, Equals, tc.steps[i].(RemovePeer).FromStore) + suite.Equal(tc.steps[i].(RemovePeer).FromStore, step.FromStore) case AddLearner: - c.Assert(step.ToStore, Equals, tc.steps[i].(AddLearner).ToStore) + suite.Equal(tc.steps[i].(AddLearner).ToStore, step.ToStore) case PromoteLearner: - c.Assert(step.ToStore, Equals, tc.steps[i].(PromoteLearner).ToStore) + suite.Equal(tc.steps[i].(PromoteLearner).ToStore, step.ToStore) case ChangePeerV2Enter: - c.Assert(len(step.PromoteLearners), Equals, len(tc.steps[i].(ChangePeerV2Enter).PromoteLearners)) - c.Assert(len(step.DemoteVoters), Equals, len(tc.steps[i].(ChangePeerV2Enter).DemoteVoters)) + suite.Len(tc.steps[i].(ChangePeerV2Enter).PromoteLearners, len(step.PromoteLearners)) + suite.Len(tc.steps[i].(ChangePeerV2Enter).DemoteVoters, len(step.DemoteVoters)) for j, p := range tc.steps[i].(ChangePeerV2Enter).PromoteLearners { - c.Assert(step.PromoteLearners[j].ToStore, Equals, p.ToStore) + suite.Equal(p.ToStore, step.PromoteLearners[j].ToStore) } for j, d := range tc.steps[i].(ChangePeerV2Enter).DemoteVoters { - c.Assert(step.DemoteVoters[j].ToStore, Equals, d.ToStore) + suite.Equal(d.ToStore, step.DemoteVoters[j].ToStore) } case ChangePeerV2Leave: - c.Assert(len(step.PromoteLearners), Equals, len(tc.steps[i].(ChangePeerV2Leave).PromoteLearners)) - c.Assert(len(step.DemoteVoters), Equals, len(tc.steps[i].(ChangePeerV2Leave).DemoteVoters)) + suite.Len(tc.steps[i].(ChangePeerV2Leave).PromoteLearners, len(step.PromoteLearners)) + suite.Len(tc.steps[i].(ChangePeerV2Leave).DemoteVoters, len(step.DemoteVoters)) for j, p := range tc.steps[i].(ChangePeerV2Leave).PromoteLearners { - 
c.Assert(step.PromoteLearners[j].ToStore, Equals, p.ToStore) + suite.Equal(p.ToStore, step.PromoteLearners[j].ToStore) } for j, d := range tc.steps[i].(ChangePeerV2Leave).DemoteVoters { - c.Assert(step.DemoteVoters[j].ToStore, Equals, d.ToStore) + suite.Equal(d.ToStore, step.DemoteVoters[j].ToStore) } } } } } -// Test for not set unhealthy peer as target for promote learner and transfer leader -func (s *testBuilderSuite) TestTargetUnhealthyPeer(c *C) { +func (suite *operatorBuilderTestSuite) TestTargetUnhealthyPeer() { p := &metapb.Peer{Id: 2, StoreId: 2, Role: metapb.PeerRole_Learner} region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithPendingPeers([]*metapb.Peer{p})) - builder := NewBuilder("test", s.cluster, region) + builder := NewBuilder("test", suite.cluster, region) builder.PromoteLearner(2) - c.Assert(builder.err, NotNil) + suite.Error(builder.err) region = core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithDownPeers([]*pdpb.PeerStats{{Peer: p}})) - builder = NewBuilder("test", s.cluster, region) + builder = NewBuilder("test", suite.cluster, region) builder.PromoteLearner(2) - c.Assert(builder.err, NotNil) - + suite.Error(builder.err) p = &metapb.Peer{Id: 2, StoreId: 2, Role: metapb.PeerRole_Voter} region = core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithPendingPeers([]*metapb.Peer{p})) - builder = NewBuilder("test", s.cluster, region) + builder = NewBuilder("test", suite.cluster, region) builder.SetLeader(2) - c.Assert(builder.err, NotNil) + suite.Error(builder.err) region = core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithDownPeers([]*pdpb.PeerStats{{Peer: p}})) - builder = NewBuilder("test", s.cluster, region) + builder = NewBuilder("test", suite.cluster, region) builder.SetLeader(2) - c.Assert(builder.err, NotNil) + suite.Error(builder.err) } diff --git a/server/schedule/operator/status_test.go b/server/schedule/operator/status_test.go index 42502e1e096..6bdf2710657 100644 --- a/server/schedule/operator/status_test.go +++ b/server/schedule/operator/status_test.go @@ -15,21 +15,20 @@ package operator import ( - . "github.com/pingcap/check" -) - -var _ = Suite(&testOpStatusSuite{}) + "testing" -type testOpStatusSuite struct{} + "github.com/stretchr/testify/require" +) -func (s *testOpStatusSuite) TestIsEndStatus(c *C) { +func TestIsEndStatus(t *testing.T) { + re := require.New(t) for st := OpStatus(0); st < firstEndStatus; st++ { - c.Assert(IsEndStatus(st), IsFalse) + re.False(IsEndStatus(st)) } for st := firstEndStatus; st < statusCount; st++ { - c.Assert(IsEndStatus(st), IsTrue) + re.True(IsEndStatus(st)) } for st := statusCount; st < statusCount+100; st++ { - c.Assert(IsEndStatus(st), IsFalse) + re.False(IsEndStatus(st)) } } diff --git a/server/schedule/operator/status_tracker_test.go b/server/schedule/operator/status_tracker_test.go index 8ada8b386f2..d4441b0e7b6 100644 --- a/server/schedule/operator/status_tracker_test.go +++ b/server/schedule/operator/status_tracker_test.go @@ -15,64 +15,64 @@ package operator import ( + "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testOpStatusTrackerSuite{}) - -type testOpStatusTrackerSuite struct{} - -func (s *testOpStatusTrackerSuite) TestCreate(c *C) { +func TestCreate(t *testing.T) { + re := require.New(t) before := time.Now() trk := NewOpStatusTracker() - c.Assert(trk.Status(), Equals, CREATED) - c.Assert(trk.ReachTime(), DeepEquals, trk.ReachTimeOf(CREATED)) - checkTimeOrder(c, before, trk.ReachTime(), time.Now()) - checkReachTime(c, &trk, CREATED) + re.Equal(CREATED, trk.Status()) + re.Equal(trk.ReachTimeOf(CREATED), trk.ReachTime()) + checkTimeOrder(re, before, trk.ReachTime(), time.Now()) + checkReachTime(re, &trk, CREATED) } -func (s *testOpStatusTrackerSuite) TestNonEndTrans(c *C) { +func TestNonEndTrans(t *testing.T) { + re := require.New(t) { trk := NewOpStatusTracker() - checkInvalidTrans(c, &trk, SUCCESS, REPLACED, TIMEOUT) - checkValidTrans(c, &trk, STARTED) - checkInvalidTrans(c, &trk, EXPIRED) - checkValidTrans(c, &trk, SUCCESS) - checkReachTime(c, &trk, CREATED, STARTED, SUCCESS) + checkInvalidTrans(re, &trk, SUCCESS, REPLACED, TIMEOUT) + checkValidTrans(re, &trk, STARTED) + checkInvalidTrans(re, &trk, EXPIRED) + checkValidTrans(re, &trk, SUCCESS) + checkReachTime(re, &trk, CREATED, STARTED, SUCCESS) } { trk := NewOpStatusTracker() - checkValidTrans(c, &trk, CANCELED) - checkReachTime(c, &trk, CREATED, CANCELED) + checkValidTrans(re, &trk, CANCELED) + checkReachTime(re, &trk, CREATED, CANCELED) } { trk := NewOpStatusTracker() - checkValidTrans(c, &trk, STARTED) - checkValidTrans(c, &trk, CANCELED) - checkReachTime(c, &trk, CREATED, STARTED, CANCELED) + checkValidTrans(re, &trk, STARTED) + checkValidTrans(re, &trk, CANCELED) + checkReachTime(re, &trk, CREATED, STARTED, CANCELED) } { trk := NewOpStatusTracker() - checkValidTrans(c, &trk, STARTED) - checkValidTrans(c, &trk, REPLACED) - checkReachTime(c, &trk, CREATED, STARTED, REPLACED) + checkValidTrans(re, &trk, STARTED) + checkValidTrans(re, &trk, REPLACED) + checkReachTime(re, &trk, CREATED, STARTED, REPLACED) } { trk := NewOpStatusTracker() - checkValidTrans(c, &trk, EXPIRED) - checkReachTime(c, &trk, CREATED, EXPIRED) + checkValidTrans(re, &trk, EXPIRED) + checkReachTime(re, &trk, CREATED, EXPIRED) } { trk := NewOpStatusTracker() - checkValidTrans(c, &trk, STARTED) - checkValidTrans(c, &trk, TIMEOUT) - checkReachTime(c, &trk, CREATED, STARTED, TIMEOUT) + checkValidTrans(re, &trk, STARTED) + checkValidTrans(re, &trk, TIMEOUT) + checkReachTime(re, &trk, CREATED, STARTED, TIMEOUT) } } -func (s *testOpStatusTrackerSuite) TestEndStatusTrans(c *C) { +func TestEndStatusTrans(t *testing.T) { + re := require.New(t) allStatus := make([]OpStatus, 0, statusCount) for st := OpStatus(0); st < statusCount; st++ { allStatus = append(allStatus, st) @@ -80,41 +80,43 @@ func (s *testOpStatusTrackerSuite) TestEndStatusTrans(c *C) { for from := firstEndStatus; from < statusCount; from++ { trk := NewOpStatusTracker() trk.current = from - c.Assert(trk.IsEnd(), IsTrue) - checkInvalidTrans(c, &trk, allStatus...) + re.True(trk.IsEnd()) + checkInvalidTrans(re, &trk, allStatus...) 
} } -func (s *testOpStatusTrackerSuite) TestCheckExpired(c *C) { +func TestCheckExpired(t *testing.T) { + re := require.New(t) { // Not expired before := time.Now() trk := NewOpStatusTracker() after := time.Now() - c.Assert(trk.CheckExpired(10*time.Second), IsFalse) - c.Assert(trk.Status(), Equals, CREATED) - checkTimeOrder(c, before, trk.ReachTime(), after) + re.False(trk.CheckExpired(10 * time.Second)) + re.Equal(CREATED, trk.Status()) + checkTimeOrder(re, before, trk.ReachTime(), after) } { // Expired but status not changed trk := NewOpStatusTracker() trk.setTime(CREATED, time.Now().Add(-10*time.Second)) - c.Assert(trk.CheckExpired(5*time.Second), IsTrue) - c.Assert(trk.Status(), Equals, EXPIRED) + re.True(trk.CheckExpired(5 * time.Second)) + re.Equal(EXPIRED, trk.Status()) } { // Expired and status changed trk := NewOpStatusTracker() before := time.Now() - c.Assert(trk.To(EXPIRED), IsTrue) + re.True(trk.To(EXPIRED)) after := time.Now() - c.Assert(trk.CheckExpired(0), IsTrue) - c.Assert(trk.Status(), Equals, EXPIRED) - checkTimeOrder(c, before, trk.ReachTime(), after) + re.True(trk.CheckExpired(0)) + re.Equal(EXPIRED, trk.Status()) + checkTimeOrder(re, before, trk.ReachTime(), after) } } -func (s *testOpStatusTrackerSuite) TestCheckStepTimeout(c *C) { +func TestCheckStepTimeout(t *testing.T) { + re := require.New(t) testdata := []struct { step OpStep start time.Time @@ -133,45 +135,45 @@ func (s *testOpStatusTrackerSuite) TestCheckStepTimeout(c *C) { // Timeout and status changed trk := NewOpStatusTracker() trk.To(STARTED) - c.Assert(trk.CheckStepTimeout(v.start, v.step, 0), Equals, v.status == TIMEOUT) - c.Assert(trk.Status(), Equals, v.status) + re.Equal(v.status == TIMEOUT, trk.CheckStepTimeout(v.start, v.step, 0)) + re.Equal(v.status, trk.Status()) } } -func checkTimeOrder(c *C, t1, t2, t3 time.Time) { - c.Assert(t1.Before(t2), IsTrue) - c.Assert(t3.After(t2), IsTrue) +func checkTimeOrder(re *require.Assertions, t1, t2, t3 time.Time) { + re.True(t1.Before(t2)) + re.True(t3.After(t2)) } -func checkValidTrans(c *C, trk *OpStatusTracker, st OpStatus) { +func checkValidTrans(re *require.Assertions, trk *OpStatusTracker, st OpStatus) { before := time.Now() - c.Assert(trk.To(st), IsTrue) - c.Assert(trk.Status(), Equals, st) - c.Assert(trk.ReachTime(), DeepEquals, trk.ReachTimeOf(st)) - checkTimeOrder(c, before, trk.ReachTime(), time.Now()) + re.True(trk.To(st)) + re.Equal(st, trk.Status()) + re.Equal(trk.ReachTimeOf(st), trk.ReachTime()) + checkTimeOrder(re, before, trk.ReachTime(), time.Now()) } -func checkInvalidTrans(c *C, trk *OpStatusTracker, sts ...OpStatus) { +func checkInvalidTrans(re *require.Assertions, trk *OpStatusTracker, sts ...OpStatus) { origin := trk.Status() originTime := trk.ReachTime() sts = append(sts, statusCount, statusCount+1, statusCount+10) for _, st := range sts { - c.Assert(trk.To(st), IsFalse) - c.Assert(trk.Status(), Equals, origin) - c.Assert(trk.ReachTime(), DeepEquals, originTime) + re.False(trk.To(st)) + re.Equal(origin, trk.Status()) + re.Equal(originTime, trk.ReachTime()) } } -func checkReachTime(c *C, trk *OpStatusTracker, reached ...OpStatus) { +func checkReachTime(re *require.Assertions, trk *OpStatusTracker, reached ...OpStatus) { reachedMap := make(map[OpStatus]struct{}, len(reached)) for _, st := range reached { - c.Assert(trk.ReachTimeOf(st).IsZero(), IsFalse) + re.False(trk.ReachTimeOf(st).IsZero()) reachedMap[st] = struct{}{} } for st := OpStatus(0); st <= statusCount+10; st++ { if _, ok := reachedMap[st]; ok { continue } - 
c.Assert(trk.ReachTimeOf(st).IsZero(), IsTrue) + re.True(trk.ReachTimeOf(st).IsZero()) } } diff --git a/server/schedule/operator/step_test.go b/server/schedule/operator/step_test.go index f4bd9865b25..aa2f18c7220 100644 --- a/server/schedule/operator/step_test.go +++ b/server/schedule/operator/step_test.go @@ -16,38 +16,43 @@ package operator import ( "context" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" ) -type testStepSuite struct { +type operatorStepTestSuite struct { + suite.Suite + cluster *mockcluster.Cluster } -var _ = Suite(&testStepSuite{}) +func TestOperatorStepTestSuite(t *testing.T) { + suite.Run(t, new(operatorStepTestSuite)) +} type testCase struct { - Peers []*metapb.Peer // first is leader - ConfVerChanged uint64 - IsFinish bool - CheckInProgres Checker + Peers []*metapb.Peer // first is leader + ConfVerChanged uint64 + IsFinish bool + CheckInProgress func(err error, msgAndArgs ...interface{}) bool } -func (s *testStepSuite) SetUpTest(c *C) { - s.cluster = mockcluster.NewCluster(context.Background(), config.NewTestOptions()) +func (suite *operatorStepTestSuite) SetupTest() { + suite.cluster = mockcluster.NewCluster(context.Background(), config.NewTestOptions()) for i := 1; i <= 10; i++ { - s.cluster.PutStoreWithLabels(uint64(i)) + suite.cluster.PutStoreWithLabels(uint64(i)) } - s.cluster.SetStoreDown(8) - s.cluster.SetStoreDown(9) - s.cluster.SetStoreDown(10) + suite.cluster.SetStoreDown(8) + suite.cluster.SetStoreDown(9) + suite.cluster.SetStoreDown(10) } -func (s *testStepSuite) TestTransferLeader(c *C) { +func (suite *operatorStepTestSuite) TestTransferLeader() { step := TransferLeader{FromStore: 1, ToStore: 2} cases := []testCase{ { @@ -58,7 +63,7 @@ func (s *testStepSuite) TestTransferLeader(c *C) { }, 0, false, - IsNil, + suite.NoError, }, { []*metapb.Peer{ @@ -68,7 +73,7 @@ func (s *testStepSuite) TestTransferLeader(c *C) { }, 0, true, - IsNil, + suite.NoError, }, { []*metapb.Peer{ @@ -78,10 +83,10 @@ func (s *testStepSuite) TestTransferLeader(c *C) { }, 0, false, - IsNil, + suite.NoError, }, } - s.check(c, step, "transfer leader from store 1 to store 2", cases) + suite.check(step, "transfer leader from store 1 to store 2", cases) step = TransferLeader{FromStore: 1, ToStore: 9} // 9 is down cases = []testCase{ @@ -93,13 +98,13 @@ func (s *testStepSuite) TestTransferLeader(c *C) { }, 0, false, - NotNil, + suite.Error, }, } - s.check(c, step, "transfer leader from store 1 to store 9", cases) + suite.check(step, "transfer leader from store 1 to store 9", cases) } -func (s *testStepSuite) TestAddPeer(c *C) { +func (suite *operatorStepTestSuite) TestAddPeer() { step := AddPeer{ToStore: 2, PeerID: 2} cases := []testCase{ { @@ -108,7 +113,7 @@ func (s *testStepSuite) TestAddPeer(c *C) { }, 0, false, - IsNil, + suite.NoError, }, { []*metapb.Peer{ @@ -117,10 +122,10 @@ func (s *testStepSuite) TestAddPeer(c *C) { }, 1, true, - IsNil, + suite.NoError, }, } - s.check(c, step, "add peer 2 on store 2", cases) + suite.check(step, "add peer 2 on store 2", cases) step = AddPeer{ToStore: 9, PeerID: 9} cases = []testCase{ @@ -130,13 +135,13 @@ func (s *testStepSuite) TestAddPeer(c *C) { }, 0, false, - NotNil, + suite.Error, }, } - s.check(c, step, "add peer 9 on store 9", cases) + suite.check(step, "add peer 9 on store 9", cases) } -func (s *testStepSuite) TestAddLearner(c *C) { +func (suite 
*operatorStepTestSuite) TestAddLearner() { step := AddLearner{ToStore: 2, PeerID: 2} cases := []testCase{ { @@ -145,7 +150,7 @@ func (s *testStepSuite) TestAddLearner(c *C) { }, 0, false, - IsNil, + suite.NoError, }, { []*metapb.Peer{ @@ -154,10 +159,10 @@ func (s *testStepSuite) TestAddLearner(c *C) { }, 1, true, - IsNil, + suite.NoError, }, } - s.check(c, step, "add learner peer 2 on store 2", cases) + suite.check(step, "add learner peer 2 on store 2", cases) step = AddLearner{ToStore: 9, PeerID: 9} cases = []testCase{ @@ -167,13 +172,13 @@ func (s *testStepSuite) TestAddLearner(c *C) { }, 0, false, - NotNil, + suite.Error, }, } - s.check(c, step, "add learner peer 9 on store 9", cases) + suite.check(step, "add learner peer 9 on store 9", cases) } -func (s *testStepSuite) TestChangePeerV2Enter(c *C) { +func (suite *operatorStepTestSuite) TestChangePeerV2Enter() { cpe := ChangePeerV2Enter{ PromoteLearners: []PromoteLearner{{PeerID: 3, ToStore: 3}, {PeerID: 4, ToStore: 4}}, DemoteVoters: []DemoteVoter{{PeerID: 1, ToStore: 1}, {PeerID: 2, ToStore: 2}}, @@ -188,7 +193,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 0, false, - IsNil, + suite.NoError, }, { // after step []*metapb.Peer{ @@ -199,7 +204,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 4, true, - IsNil, + suite.NoError, }, { // miss peer id []*metapb.Peer{ @@ -210,7 +215,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // miss store id []*metapb.Peer{ @@ -221,7 +226,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // miss peer id []*metapb.Peer{ @@ -232,7 +237,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // change is not atomic []*metapb.Peer{ @@ -243,7 +248,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // change is not atomic []*metapb.Peer{ @@ -254,7 +259,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // there are other peers in the joint state []*metapb.Peer{ @@ -266,7 +271,7 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 4, true, - NotNil, + suite.Error, }, { // there are other peers in the joint state []*metapb.Peer{ @@ -279,16 +284,16 @@ func (s *testStepSuite) TestChangePeerV2Enter(c *C) { }, 0, false, - NotNil, + suite.Error, }, } desc := "use joint consensus, " + "promote learner peer 3 on store 3 to voter, promote learner peer 4 on store 4 to voter, " + "demote voter peer 1 on store 1 to learner, demote voter peer 2 on store 2 to learner" - s.check(c, cpe, desc, cases) + suite.check(cpe, desc, cases) } -func (s *testStepSuite) TestChangePeerV2Leave(c *C) { +func (suite *operatorStepTestSuite) TestChangePeerV2Leave() { cpl := ChangePeerV2Leave{ PromoteLearners: []PromoteLearner{{PeerID: 3, ToStore: 3}, {PeerID: 4, ToStore: 4}}, DemoteVoters: []DemoteVoter{{PeerID: 1, ToStore: 1}, {PeerID: 2, ToStore: 2}}, @@ -303,7 +308,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 0, false, - IsNil, + suite.NoError, }, { // after step []*metapb.Peer{ @@ -314,7 +319,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 4, true, - IsNil, + suite.NoError, }, { // miss peer id []*metapb.Peer{ @@ -325,7 +330,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // miss store id []*metapb.Peer{ @@ -336,7 +341,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c 
*C) { }, 0, false, - NotNil, + suite.Error, }, { // miss peer id []*metapb.Peer{ @@ -347,7 +352,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // change is not atomic []*metapb.Peer{ @@ -358,7 +363,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // change is not atomic []*metapb.Peer{ @@ -369,7 +374,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // there are other peers in the joint state []*metapb.Peer{ @@ -381,7 +386,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 0, false, - NotNil, + suite.Error, }, { // there are other peers in the joint state []*metapb.Peer{ @@ -394,7 +399,7 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 4, false, - NotNil, + suite.Error, }, { // demote leader []*metapb.Peer{ @@ -405,21 +410,22 @@ func (s *testStepSuite) TestChangePeerV2Leave(c *C) { }, 0, false, - NotNil, + suite.Error, }, } desc := "leave joint state, " + "promote learner peer 3 on store 3 to voter, promote learner peer 4 on store 4 to voter, " + "demote voter peer 1 on store 1 to learner, demote voter peer 2 on store 2 to learner" - s.check(c, cpl, desc, cases) + suite.check(cpl, desc, cases) } -func (s *testStepSuite) check(c *C, step OpStep, desc string, cases []testCase) { - c.Assert(step.String(), Equals, desc) - for _, tc := range cases { - region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: tc.Peers}, tc.Peers[0]) - c.Assert(step.ConfVerChanged(region), Equals, tc.ConfVerChanged) - c.Assert(step.IsFinish(region), Equals, tc.IsFinish) - c.Assert(step.CheckInProgress(s.cluster, region), tc.CheckInProgres) +func (suite *operatorStepTestSuite) check(step OpStep, desc string, cases []testCase) { + suite.Equal(desc, step.String()) + for _, testCase := range cases { + region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: testCase.Peers}, testCase.Peers[0]) + suite.Equal(testCase.ConfVerChanged, step.ConfVerChanged(region)) + suite.Equal(testCase.IsFinish, step.IsFinish(region)) + err := step.CheckInProgress(suite.cluster, region) + testCase.CheckInProgress(err) } } From 01b8f34a4034d4c36fc0fb56e0fc7892bb6ecfa9 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 21 Jun 2022 12:18:36 +0800 Subject: [PATCH 60/82] *: update the swagger dependency (#5183) close tikv/pd#5160 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- Makefile | 6 +- go.mod | 6 +- go.sum | 63 +++-- server/api/admin.go | 34 +-- server/api/checker.go | 34 +-- server/api/cluster.go | 22 +- server/api/config.go | 194 +++++++------- server/api/health.go | 14 +- server/api/hot_status.go | 46 ++-- server/api/label.go | 26 +- server/api/log.go | 20 +- server/api/member.go | 104 ++++---- server/api/min_resolved_ts.go | 12 +- server/api/operator.go | 80 +++--- server/api/plugin.go | 36 +-- server/api/pprof.go | 62 ++--- server/api/region.go | 390 ++++++++++++++--------------- server/api/region_label.go | 126 +++++----- server/api/replication_mode.go | 10 +- server/api/router.go | 18 +- server/api/rule.go | 294 +++++++++++----------- server/api/scheduler.go | 66 ++--- server/api/service_gc_safepoint.go | 28 +-- server/api/service_middleware.go | 44 ++-- server/api/stats.go | 14 +- server/api/status.go | 8 +- server/api/store.go | 218 ++++++++-------- server/api/trend.go | 16 +- server/api/tso.go | 22 +- server/api/unsafe_operation.go | 20 +- server/api/version.go | 8 +- tests/client/go.sum | 59 +++-- 32 files changed, 1078 
insertions(+), 1022 deletions(-) diff --git a/Makefile b/Makefile index 2afd99c0734..7ac146e39eb 100644 --- a/Makefile +++ b/Makefile @@ -112,10 +112,8 @@ docker-image: #### Build utils ### swagger-spec: install-tools - go mod vendor - swag init --parseVendor --generalInfo server/api/router.go --exclude vendor/github.com/pingcap/tidb-dashboard --output docs/swagger - go mod tidy - rm -rf vendor + swag init --parseDependency --parseInternal --parseDepth 1 --dir server --generalInfo api/router.go --output docs/swagger + swag fmt --dir server dashboard-ui: ./scripts/embed-dashboard-ui.sh diff --git a/go.mod b/go.mod index 0a662f0f16b..6c2a3c4380b 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.0 github.com/swaggo/http-swagger v0.0.0-20200308142732-58ac5e232fba - github.com/swaggo/swag v1.6.6-0.20200529100950-7c765ddd0476 + github.com/swaggo/swag v1.8.3 github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965 github.com/tidwall/gjson v1.9.3 // indirect github.com/unrolled/render v1.0.1 @@ -50,9 +50,9 @@ require ( go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738 go.uber.org/goleak v1.1.12 go.uber.org/zap v1.19.1 - golang.org/x/text v0.3.3 + golang.org/x/text v0.3.7 golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 - golang.org/x/tools v0.1.5 + golang.org/x/tools v0.1.10 google.golang.org/grpc v1.26.0 gotest.tools/gotestsum v1.7.0 ) diff --git a/go.sum b/go.sum index ad3415e8116..54862125e1a 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -136,20 +137,24 @@ github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/go-openapi/jsonpointer v0.17.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M= github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= -github.com/go-openapi/jsonpointer v0.19.3 h1:gihV7YNZK1iK6Tgwwsxo2rJbD1GTbdm72325Bq8FI3w= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonreference v0.17.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= github.com/go-openapi/jsonreference v0.19.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc= -github.com/go-openapi/jsonreference v0.19.3 h1:5cxNfTy0UVC3X8JL5ymxzyoUZmo8iZb+jeTWn7tUa8o= github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL98+wF9xc8zWvFonSJ8= 
+github.com/go-openapi/jsonreference v0.19.6 h1:UBIxjkht+AWIgYzCDSv2GN+E/togfwXUJFRTWhl2Jjs= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= github.com/go-openapi/spec v0.19.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI= -github.com/go-openapi/spec v0.19.4 h1:ixzUSnHTd6hCemgtAJgluaTSGYpLNpJY4mA2DIkdOAo= github.com/go-openapi/spec v0.19.4/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= +github.com/go-openapi/spec v0.20.4 h1:O8hJrt0UMnhHcluhIdUgCLRWyM2x7QkBXRvOs7m+O1M= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= @@ -283,6 +288,8 @@ github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9q github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/joomcode/errorx v1.0.1 h1:CalpDWz14ZHd68fIqluJasJosAewpz2TFaJALrUxjrk= github.com/joomcode/errorx v1.0.1/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -298,13 +305,13 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod 
h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -316,8 +323,9 @@ github.com/lucasb-eyer/go-colorful v1.0.3/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e h1:hB2xlXdHp/pmPZq0y3QnmWAArdw9PqbmotexnWx/FU8= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= @@ -365,6 +373,8 @@ github.com/montanaflynn/stats v0.5.0/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFW github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/oleiade/reflections v1.0.1 h1:D1XO3LVEYroYskEsoSiGItp9RUxG6jWnCVvrqH0HHQM= github.com/oleiade/reflections v1.0.1/go.mod h1:rdFxbxq4QXVZWj0F+e9jqjDkc7dbp97vkRixKo2JR60= @@ -501,8 +511,9 @@ github.com/swaggo/http-swagger v0.0.0-20200308142732-58ac5e232fba h1:lUPlXKqgbqT github.com/swaggo/http-swagger v0.0.0-20200308142732-58ac5e232fba/go.mod h1:O1lAbCgAAX/KZ80LM/OXwtWFI/5TvZlwxSg8Cq08PV0= github.com/swaggo/swag v1.5.1/go.mod h1:1Bl9F/ZBpVWh22nY0zmYyASPO1lI/zIwRDrpZU+tv8Y= github.com/swaggo/swag v1.6.3/go.mod h1:wcc83tB4Mb2aNiL/HP4MFeQdpHUrca+Rp/DRNgWAUio= -github.com/swaggo/swag v1.6.6-0.20200529100950-7c765ddd0476 h1:UjnSXdNPIG+5FJ6xLQODEdk7gSnJlMldu3sPAxxCO+4= github.com/swaggo/swag v1.6.6-0.20200529100950-7c765ddd0476/go.mod h1:xDhTyuFIujYiN3DKWC/H/83xcfHp+UE/IzWWampG7Zc= +github.com/swaggo/swag v1.8.3 h1:3pZSSCQ//gAH88lfmxM3Cd1+JCsxV8Md6f36b9hrZ5s= +github.com/swaggo/swag v1.8.3/go.mod h1:jMLeXOOmYyjk8PvHTsXBdrubsNd9gUJTTCzL5iBnseg= github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965 h1:1oFLiOyVl+W7bnBzGhf7BbIv9loSFQcieWWYIjLqcAw= github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965/go.mod h1:9OrXJhf154huy1nPWmuSrkgjPUtUNhA+Zmy+6AESzuA= github.com/thoas/go-funk v0.8.0 h1:JP9tKSvnpFVclYgDM0Is7FD9M4fhPvqA0s0BsXmzSRQ= @@ -535,8 +546,9 @@ github.com/unrolled/render v1.0.1 h1:VDDnQQVfBMsOsp3VaCJszSO0nkBIVEYoPWeRThk9spY github.com/unrolled/render v1.0.1/go.mod h1:gN9T0NhL4Bfbwu8ann7Ry/TGHYfosul+J0obPf6NBdM= github.com/urfave/cli v1.20.0 h1:fDqGv3UG/4jbVl/QkFwEdddtEDjh/5Ov6X+0B/3bPaw= github.com/urfave/cli 
v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= -github.com/urfave/cli/v2 v2.1.1 h1:Qt8FeAtxE/vfdrLmR3rxR6JRE0RoVmbXu8+6kZtYU4k= github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= +github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M= +github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/negroni v0.3.0 h1:PaXOb61mWeZJxc1Ji2xJjpVg9QfPo0rrB+lHyBxGNSU= github.com/urfave/negroni v0.3.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= @@ -549,6 +561,7 @@ github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1: github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= @@ -600,8 +613,9 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200204104054-c9f3fb736b72/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/image v0.0.0-20200119044424-58c23975cae1 h1:5h3ngYt7+vXCDZCup/HkCQgW5XwmSvR/nA2JmJ0RErg= golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -615,8 +629,9 @@ golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKG golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod 
h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -637,8 +652,12 @@ golang.org/x/net v0.0.0-20191002035440-2ec189313ef0/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA= +golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be h1:vEDujvNQGv4jgYKudGeI/+DAX4Jffq6hpD55MmoEvKs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -677,16 +696,23 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210217105451-b926d437f341/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac h1:oN6lz7iLW/YC7un8pq+9bOLyXrprv2+DKfkJY+2LJJw= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 
h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= @@ -720,8 +746,9 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210112230658-8b4aab62c064/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -748,8 +775,9 @@ google.golang.org/grpc v1.26.0 h1:2dTRdpdFEEhJYQD8EMLB61nnrzSCTbG38PhqdhvOltg= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= @@ -770,6 +798,7 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.0.6 h1:mA0XRPjIKi4bkE9nv+NKs6qj6QWOchqUSdWOcpd3x1E= diff --git 
a/server/api/admin.go b/server/api/admin.go index 2954874d7fd..1fa63c8ad9a 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -38,13 +38,13 @@ func newAdminHandler(svr *server.Server, rd *render.Render) *adminHandler { } } -// @Tags admin -// @Summary Drop a specific region from cache. -// @Param id path integer true "Region Id" -// @Produce json -// @Success 200 {string} string "The region is removed from server cache." -// @Failure 400 {string} string "The input is invalid." -// @Router /admin/cache/region/{id} [delete] +// @Tags admin +// @Summary Drop a specific region from cache. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {string} string "The region is removed from server cache." +// @Failure 400 {string} string "The input is invalid." +// @Router /admin/cache/region/{id} [delete] func (h *adminHandler) DeleteRegionCache(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -59,16 +59,16 @@ func (h *adminHandler) DeleteRegionCache(w http.ResponseWriter, r *http.Request) } // FIXME: details of input json body params -// @Tags admin -// @Summary Reset the ts. -// @Accept json -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "Reset ts successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 403 {string} string "Reset ts is forbidden." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /admin/reset-ts [post] +// @Tags admin +// @Summary Reset the ts. +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Reset ts successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 403 {string} string "Reset ts is forbidden." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /admin/reset-ts [post] func (h *adminHandler) ResetTS(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() var input map[string]interface{} diff --git a/server/api/checker.go b/server/api/checker.go index 9a01ad9c83f..a62cedcf74c 100644 --- a/server/api/checker.go +++ b/server/api/checker.go @@ -36,16 +36,16 @@ func newCheckerHandler(svr *server.Server, r *render.Render) *checkerHandler { } // FIXME: details of input json body params -// @Tags checker -// @Summary Pause or resume region merge. -// @Accept json -// @Param name path string true "The name of the checker." -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "Pause or resume the scheduler successfully." -// @Failure 400 {string} string "Bad format request." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /checker/{name} [post] +// @Tags checker +// @Summary Pause or resume region merge. +// @Accept json +// @Param name path string true "The name of the checker." +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Pause or resume the scheduler successfully." +// @Failure 400 {string} string "Bad format request." +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /checker/{name} [post] func (c *checkerHandler) PauseOrResumeChecker(w http.ResponseWriter, r *http.Request) { var input map[string]int if err := apiutil.ReadJSONRespondError(c.r, w, r.Body, &input); err != nil { @@ -74,13 +74,13 @@ func (c *checkerHandler) PauseOrResumeChecker(w http.ResponseWriter, r *http.Req } // FIXME: details of input json body params -// @Tags checker -// @Summary Get if checker is paused -// @Param name path string true "The name of the scheduler." -// @Produce json -// @Success 200 {string} string "Pause or resume the scheduler successfully." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /checker/{name} [get] +// @Tags checker +// @Summary Get if checker is paused +// @Param name path string true "The name of the scheduler." +// @Produce json +// @Success 200 {string} string "Pause or resume the scheduler successfully." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /checker/{name} [get] func (c *checkerHandler) GetCheckerStatus(w http.ResponseWriter, r *http.Request) { name := mux.Vars(r)["name"] isPaused, err := c.IsCheckerPaused(name) diff --git a/server/api/cluster.go b/server/api/cluster.go index f7ff6251353..fcf972d56a7 100644 --- a/server/api/cluster.go +++ b/server/api/cluster.go @@ -33,21 +33,21 @@ func newClusterHandler(svr *server.Server, rd *render.Render) *clusterHandler { } } -// @Tags cluster -// @Summary Get cluster info. -// @Produce json -// @Success 200 {object} metapb.Cluster -// @Router /cluster [get] +// @Tags cluster +// @Summary Get cluster info. +// @Produce json +// @Success 200 {object} metapb.Cluster +// @Router /cluster [get] func (h *clusterHandler) GetCluster(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetCluster()) } -// @Tags cluster -// @Summary Get cluster status. -// @Produce json -// @Success 200 {object} cluster.Status -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /cluster/status [get] +// @Tags cluster +// @Summary Get cluster status. +// @Produce json +// @Success 200 {object} cluster.Status +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /cluster/status [get] func (h *clusterHandler) GetClusterStatus(w http.ResponseWriter, r *http.Request) { status, err := h.svr.GetClusterStatus() if err != nil { diff --git a/server/api/config.go b/server/api/config.go index 7ed3a9de56c..61cd27cd595 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -48,23 +48,23 @@ func newConfHandler(svr *server.Server, rd *render.Render) *confHandler { } } -// @Tags config -// @Summary Get full config. -// @Produce json -// @Success 200 {object} config.Config -// @Router /config [get] +// @Tags config +// @Summary Get full config. +// @Produce json +// @Success 200 {object} config.Config +// @Router /config [get] func (h *confHandler) GetConfig(w http.ResponseWriter, r *http.Request) { cfg := h.svr.GetConfig() cfg.Schedule.MaxMergeRegionKeys = cfg.Schedule.GetMaxMergeRegionKeys() h.rd.JSON(w, http.StatusOK, cfg) } -// @Tags config -// @Summary Get default config. -// @Produce json -// @Success 200 {object} config.Config -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/default [get] +// @Tags config +// @Summary Get default config. +// @Produce json +// @Success 200 {object} config.Config +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /config/default [get] func (h *confHandler) GetDefaultConfig(w http.ResponseWriter, r *http.Request) { config := config.NewConfig() err := config.Adjust(nil, false) @@ -76,16 +76,16 @@ func (h *confHandler) GetDefaultConfig(w http.ResponseWriter, r *http.Request) { } // FIXME: details of input json body params -// @Tags config -// @Summary Update a config item. -// @Accept json -// @Param ttlSecond query integer false "ttl". ttl param is only for BR and lightning now. Don't use it. -// @Param body body object false "json params" -// @Produce json -// @Success 200 {string} string "The config is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config [post] +// @Tags config +// @Summary Update a config item. +// @Accept json +// @Param ttlSecond query integer false "ttl param is only for BR and lightning now. Don't use it." +// @Param body body object false "json params" +// @Produce json +// @Success 200 {string} string "The config is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config [post] func (h *confHandler) SetConfig(w http.ResponseWriter, r *http.Request) { cfg := h.svr.GetConfig() data, err := io.ReadAll(r.Body) @@ -272,27 +272,27 @@ func getConfigMap(cfg map[string]interface{}, key []string, value interface{}) m return cfg } -// @Tags config -// @Summary Get schedule config. -// @Produce json -// @Success 200 {object} config.ScheduleConfig -// @Router /config/schedule [get] +// @Tags config +// @Summary Get schedule config. +// @Produce json +// @Success 200 {object} config.ScheduleConfig +// @Router /config/schedule [get] func (h *confHandler) GetScheduleConfig(w http.ResponseWriter, r *http.Request) { cfg := h.svr.GetScheduleConfig() cfg.MaxMergeRegionKeys = cfg.GetMaxMergeRegionKeys() h.rd.JSON(w, http.StatusOK, cfg) } -// @Tags config -// @Summary Update a schedule config item. -// @Accept json -// @Param body body object string "json params" -// @Produce json -// @Success 200 {string} string "The config is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Failure 503 {string} string "PD server has no leader." -// @Router /config/schedule [post] +// @Tags config +// @Summary Update a schedule config item. +// @Accept json +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string "The config is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Failure 503 {string} string "PD server has no leader." +// @Router /config/schedule [post] func (h *confHandler) SetScheduleConfig(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) r.Body.Close() @@ -335,25 +335,25 @@ func (h *confHandler) SetScheduleConfig(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, "The config is updated.") } -// @Tags config -// @Summary Get replication config. -// @Produce json -// @Success 200 {object} config.ReplicationConfig -// @Router /config/replicate [get] +// @Tags config +// @Summary Get replication config. 
+// @Produce json +// @Success 200 {object} config.ReplicationConfig +// @Router /config/replicate [get] func (h *confHandler) GetReplicationConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetReplicationConfig()) } -// @Tags config -// @Summary Update a replication config item. -// @Accept json -// @Param body body object string "json params" -// @Produce json -// @Success 200 {string} string "The config is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Failure 503 {string} string "PD server has no leader." -// @Router /config/replicate [post] +// @Tags config +// @Summary Update a replication config item. +// @Accept json +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string "The config is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Failure 503 {string} string "PD server has no leader." +// @Router /config/replicate [post] func (h *confHandler) SetReplicationConfig(w http.ResponseWriter, r *http.Request) { config := h.svr.GetReplicationConfig() if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &config); err != nil { @@ -367,24 +367,24 @@ func (h *confHandler) SetReplicationConfig(w http.ResponseWriter, r *http.Reques h.rd.JSON(w, http.StatusOK, "The config is updated.") } -// @Tags config -// @Summary Get label property config. -// @Produce json -// @Success 200 {object} config.LabelPropertyConfig -// @Router /config/label-property [get] +// @Tags config +// @Summary Get label property config. +// @Produce json +// @Success 200 {object} config.LabelPropertyConfig +// @Router /config/label-property [get] func (h *confHandler) GetLabelPropertyConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetLabelProperty()) } -// @Tags config -// @Summary Update label property config item. -// @Accept json -// @Param body body object string "json params" -// @Produce json -// @Success 200 {string} string "The config is updated." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Failure 503 {string} string "PD server has no leader." -// @Router /config/label-property [post] +// @Tags config +// @Summary Update label property config item. +// @Accept json +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string "The config is updated." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Failure 503 {string} string "PD server has no leader." +// @Router /config/label-property [post] func (h *confHandler) SetLabelPropertyConfig(w http.ResponseWriter, r *http.Request) { input := make(map[string]string) if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { @@ -407,24 +407,24 @@ func (h *confHandler) SetLabelPropertyConfig(w http.ResponseWriter, r *http.Requ h.rd.JSON(w, http.StatusOK, "The config is updated.") } -// @Tags config -// @Summary Get cluster version. -// @Produce json -// @Success 200 {object} semver.Version -// @Router /config/cluster-version [get] +// @Tags config +// @Summary Get cluster version. 
+// @Produce json +// @Success 200 {object} semver.Version +// @Router /config/cluster-version [get] func (h *confHandler) GetClusterVersion(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetClusterVersion()) } -// @Tags config -// @Summary Update cluster version. -// @Accept json -// @Param body body object string "json params" -// @Produce json -// @Success 200 {string} string "The cluster version is updated." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Failure 503 {string} string "PD server has no leader." -// @Router /config/cluster-version [post] +// @Tags config +// @Summary Update cluster version. +// @Accept json +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string "The cluster version is updated." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Failure 503 {string} string "PD server has no leader." +// @Router /config/cluster-version [post] func (h *confHandler) SetClusterVersion(w http.ResponseWriter, r *http.Request) { input := make(map[string]string) if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { @@ -444,23 +444,23 @@ func (h *confHandler) SetClusterVersion(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, "The cluster version is updated.") } -// @Tags config -// @Summary Get replication mode config. -// @Produce json -// @Success 200 {object} config.ReplicationModeConfig -// @Router /config/replication-mode [get] +// @Tags config +// @Summary Get replication mode config. +// @Produce json +// @Success 200 {object} config.ReplicationModeConfig +// @Router /config/replication-mode [get] func (h *confHandler) GetReplicationModeConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetReplicationModeConfig()) } -// @Tags config -// @Summary Set replication mode config. -// @Accept json -// @Param body body object string "json params" -// @Produce json -// @Success 200 {string} string "The replication mode config is updated." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/replication-mode [post] +// @Tags config +// @Summary Set replication mode config. +// @Accept json +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string "The replication mode config is updated." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/replication-mode [post] func (h *confHandler) SetReplicationModeConfig(w http.ResponseWriter, r *http.Request) { config := h.svr.GetReplicationModeConfig() if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &config); err != nil { @@ -474,11 +474,11 @@ func (h *confHandler) SetReplicationModeConfig(w http.ResponseWriter, r *http.Re h.rd.JSON(w, http.StatusOK, "The replication mode config is updated.") } -// @Tags config -// @Summary Get PD server config. -// @Produce json -// @Success 200 {object} config.PDServerConfig -// @Router /config/pd-server [get] +// @Tags config +// @Summary Get PD server config. 
+// @Produce json +// @Success 200 {object} config.PDServerConfig +// @Router /config/pd-server [get] func (h *confHandler) GetPDServerConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetPDServerConfig()) } diff --git a/server/api/health.go b/server/api/health.go index 982a663e934..fbbc4a3672f 100644 --- a/server/api/health.go +++ b/server/api/health.go @@ -43,11 +43,11 @@ func newHealthHandler(svr *server.Server, rd *render.Render) *healthHandler { } } -// @Summary Health status of PD servers. -// @Produce json -// @Success 200 {array} Health -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /health [get] +// @Summary Health status of PD servers. +// @Produce json +// @Success 200 {array} Health +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /health [get] func (h *healthHandler) GetHealthStatus(w http.ResponseWriter, r *http.Request) { client := h.svr.GetClient() members, err := cluster.GetMembers(client) @@ -73,6 +73,6 @@ func (h *healthHandler) GetHealthStatus(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, healths) } -// @Summary Ping PD servers. -// @Router /ping [get] +// @Summary Ping PD servers. +// @Router /ping [get] func (h *healthHandler) Ping(w http.ResponseWriter, r *http.Request) {} diff --git a/server/api/hot_status.go b/server/api/hot_status.go index cbd537c7ac7..1b04638a94d 100644 --- a/server/api/hot_status.go +++ b/server/api/hot_status.go @@ -62,11 +62,11 @@ func newHotStatusHandler(handler *server.Handler, rd *render.Render) *hotStatusH } } -// @Tags hotspot -// @Summary List the hot write regions. -// @Produce json -// @Success 200 {object} statistics.StoreHotPeersInfos -// @Router /hotspot/regions/write [get] +// @Tags hotspot +// @Summary List the hot write regions. +// @Produce json +// @Success 200 {object} statistics.StoreHotPeersInfos +// @Router /hotspot/regions/write [get] func (h *hotStatusHandler) GetHotWriteRegions(w http.ResponseWriter, r *http.Request) { storeIDs := r.URL.Query()["store_id"] if len(storeIDs) < 1 { @@ -98,11 +98,11 @@ func (h *hotStatusHandler) GetHotWriteRegions(w http.ResponseWriter, r *http.Req h.rd.JSON(w, http.StatusOK, rc.GetHotWriteRegions(ids...)) } -// @Tags hotspot -// @Summary List the hot read regions. -// @Produce json -// @Success 200 {object} statistics.StoreHotPeersInfos -// @Router /hotspot/regions/read [get] +// @Tags hotspot +// @Summary List the hot read regions. +// @Produce json +// @Success 200 {object} statistics.StoreHotPeersInfos +// @Router /hotspot/regions/read [get] func (h *hotStatusHandler) GetHotReadRegions(w http.ResponseWriter, r *http.Request) { storeIDs := r.URL.Query()["store_id"] if len(storeIDs) < 1 { @@ -134,11 +134,11 @@ func (h *hotStatusHandler) GetHotReadRegions(w http.ResponseWriter, r *http.Requ h.rd.JSON(w, http.StatusOK, rc.GetHotReadRegions(ids...)) } -// @Tags hotspot -// @Summary List the hot stores. -// @Produce json -// @Success 200 {object} HotStoreStats -// @Router /hotspot/stores [get] +// @Tags hotspot +// @Summary List the hot stores. 
+// @Produce json +// @Success 200 {object} HotStoreStats +// @Router /hotspot/stores [get] func (h *hotStatusHandler) GetHotStores(w http.ResponseWriter, r *http.Request) { stats := HotStoreStats{ BytesWriteStats: make(map[uint64]float64), @@ -169,14 +169,14 @@ func (h *hotStatusHandler) GetHotStores(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, stats) } -// @Tags hotspot -// @Summary List the history hot regions. -// @Accept json -// @Produce json -// @Success 200 {object} storage.HistoryHotRegions -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /hotspot/regions/history [get] +// @Tags hotspot +// @Summary List the history hot regions. +// @Accept json +// @Produce json +// @Success 200 {object} storage.HistoryHotRegions +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /hotspot/regions/history [get] func (h *hotStatusHandler) GetHistoryHotRegions(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) r.Body.Close() diff --git a/server/api/label.go b/server/api/label.go index f9cbc49c882..abaad02a4e3 100644 --- a/server/api/label.go +++ b/server/api/label.go @@ -37,11 +37,11 @@ func newLabelsHandler(svr *server.Server, rd *render.Render) *labelsHandler { } } -// @Tags label -// @Summary List all label values. -// @Produce json -// @Success 200 {array} metapb.StoreLabel -// @Router /labels [get] +// @Tags label +// @Summary List all label values. +// @Produce json +// @Success 200 {array} metapb.StoreLabel +// @Router /labels [get] func (h *labelsHandler) GetLabels(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) var labels []*metapb.StoreLabel @@ -59,14 +59,14 @@ func (h *labelsHandler) GetLabels(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, labels) } -// @Tags label -// @Summary List stores that have specific label values. -// @Param name query string true "name of store label filter" -// @Param value query string true "value of store label filter" -// @Produce json -// @Success 200 {object} StoresInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /labels/stores [get] +// @Tags label +// @Summary List stores that have specific label values. +// @Param name query string true "name of store label filter" +// @Param value query string true "value of store label filter" +// @Produce json +// @Success 200 {object} StoresInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /labels/stores [get] func (h *labelsHandler) GetStoresByLabel(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) name := r.URL.Query().Get("name") diff --git a/server/api/log.go b/server/api/log.go index 793338aba4c..ed7a07e5279 100644 --- a/server/api/log.go +++ b/server/api/log.go @@ -37,16 +37,16 @@ func newLogHandler(svr *server.Server, rd *render.Render) *logHandler { } } -// @Tags admin -// @Summary Set log level. -// @Accept json -// @Param level body string true "json params" -// @Produce json -// @Success 200 {string} string "The log level is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Failure 503 {string} string "PD server has no leader." -// @Router /admin/log [post] +// @Tags admin +// @Summary Set log level. 
+// @Accept json +// @Param level body string true "json params" +// @Produce json +// @Success 200 {string} string "The log level is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Failure 503 {string} string "PD server has no leader." +// @Router /admin/log [post] func (h *logHandler) SetLogLevel(w http.ResponseWriter, r *http.Request) { var level string data, err := io.ReadAll(r.Body) diff --git a/server/api/member.go b/server/api/member.go index a6c5b7156f3..eaf743c0493 100644 --- a/server/api/member.go +++ b/server/api/member.go @@ -45,12 +45,12 @@ func newMemberHandler(svr *server.Server, rd *render.Render) *memberHandler { } } -// @Tags member -// @Summary List all PD servers in the cluster. -// @Produce json -// @Success 200 {object} pdpb.GetMembersResponse -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /members [get] +// @Tags member +// @Summary List all PD servers in the cluster. +// @Produce json +// @Success 200 {object} pdpb.GetMembersResponse +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /members [get] func (h *memberHandler) GetMembers(w http.ResponseWriter, r *http.Request) { members, err := getMembers(h.svr) if err != nil { @@ -107,15 +107,15 @@ func getMembers(svr *server.Server) (*pdpb.GetMembersResponse, error) { return members, nil } -// @Tags member -// @Summary Remove a PD server from the cluster. -// @Param name path string true "PD server name" -// @Produce json -// @Success 200 {string} string "The PD server is successfully removed." -// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The member does not exist." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /members/name/{name} [delete] +// @Tags member +// @Summary Remove a PD server from the cluster. +// @Param name path string true "PD server name" +// @Produce json +// @Success 200 {string} string "The PD server is successfully removed." +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The member does not exist." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /members/name/{name} [delete] func (h *memberHandler) DeleteMemberByName(w http.ResponseWriter, r *http.Request) { client := h.svr.GetClient() @@ -161,14 +161,14 @@ func (h *memberHandler) DeleteMemberByName(w http.ResponseWriter, r *http.Reques h.rd.JSON(w, http.StatusOK, fmt.Sprintf("removed, pd: %s", name)) } -// @Tags member -// @Summary Remove a PD server from the cluster. -// @Param id path integer true "PD server Id" -// @Produce json -// @Success 200 {string} string "The PD server is successfully removed." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /members/id/{id} [delete] +// @Tags member +// @Summary Remove a PD server from the cluster. +// @Param id path integer true "PD server Id" +// @Produce json +// @Success 200 {string} string "The PD server is successfully removed." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /members/id/{id} [delete] func (h *memberHandler) DeleteMemberByID(w http.ResponseWriter, r *http.Request) { idStr := mux.Vars(r)["id"] id, err := strconv.ParseUint(idStr, 10, 64) @@ -201,17 +201,17 @@ func (h *memberHandler) DeleteMemberByID(w http.ResponseWriter, r *http.Request) } // FIXME: details of input json body params -// @Tags member -// @Summary Set leader priority of a PD member. -// @Accept json -// @Param name path string true "PD server name" -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "The leader priority is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The member does not exist." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /members/name/{name} [post] +// @Tags member +// @Summary Set leader priority of a PD member. +// @Accept json +// @Param name path string true "PD server name" +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "The leader priority is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The member does not exist." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /members/name/{name} [post] func (h *memberHandler) SetMemberPropertyByName(w http.ResponseWriter, r *http.Request) { members, membersErr := getMembers(h.svr) if membersErr != nil { @@ -265,21 +265,21 @@ func newLeaderHandler(svr *server.Server, rd *render.Render) *leaderHandler { } } -// @Tags leader -// @Summary Get the leader PD server of the cluster. -// @Produce json -// @Success 200 {object} pdpb.Member -// @Router /leader [get] +// @Tags leader +// @Summary Get the leader PD server of the cluster. +// @Produce json +// @Success 200 {object} pdpb.Member +// @Router /leader [get] func (h *leaderHandler) GetLeader(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetLeader()) } -// @Tags leader -// @Summary Transfer etcd leadership to another PD server. -// @Produce json -// @Success 200 {string} string "The resign command is submitted." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /leader/resign [post] +// @Tags leader +// @Summary Transfer etcd leadership to another PD server. +// @Produce json +// @Success 200 {string} string "The resign command is submitted." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /leader/resign [post] func (h *leaderHandler) ResignLeader(w http.ResponseWriter, r *http.Request) { err := h.svr.GetMember().ResignEtcdLeader(h.svr.Context(), h.svr.Name(), "") if err != nil { @@ -290,13 +290,13 @@ func (h *leaderHandler) ResignLeader(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, "The resign command is submitted.") } -// @Tags leader -// @Summary Transfer etcd leadership to the specific PD server. -// @Param nextLeader path string true "PD server that transfer leader to" -// @Produce json -// @Success 200 {string} string "The transfer command is submitted." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /leader/transfer/{nextLeader} [post] +// @Tags leader +// @Summary Transfer etcd leadership to the specific PD server. +// @Param nextLeader path string true "PD server that transfer leader to" +// @Produce json +// @Success 200 {string} string "The transfer command is submitted." 
+// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /leader/transfer/{nextLeader} [post] func (h *leaderHandler) TransferLeader(w http.ResponseWriter, r *http.Request) { err := h.svr.GetMember().ResignEtcdLeader(h.svr.Context(), h.svr.Name(), mux.Vars(r)["next_leader"]) if err != nil { diff --git a/server/api/min_resolved_ts.go b/server/api/min_resolved_ts.go index c717f0a3b42..c367aabdd1f 100644 --- a/server/api/min_resolved_ts.go +++ b/server/api/min_resolved_ts.go @@ -41,12 +41,12 @@ type minResolvedTS struct { PersistInterval typeutil.Duration `json:"persist_interval,omitempty"` } -// @Tags min_resolved_ts -// @Summary Get cluster-level min resolved ts. -// @Produce json -// @Success 200 {array} minResolvedTS -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /min-resolved-ts [get] +// @Tags min_resolved_ts +// @Summary Get cluster-level min resolved ts. +// @Produce json +// @Success 200 {array} minResolvedTS +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /min-resolved-ts [get] func (h *minResolvedTSHandler) GetMinResolvedTS(w http.ResponseWriter, r *http.Request) { c := h.svr.GetRaftCluster() value := c.GetMinResolvedTS() diff --git a/server/api/operator.go b/server/api/operator.go index d41c04f8292..dc0af9f8eda 100644 --- a/server/api/operator.go +++ b/server/api/operator.go @@ -40,14 +40,14 @@ func newOperatorHandler(handler *server.Handler, r *render.Render) *operatorHand } } -// @Tags operator -// @Summary Get a Region's pending operator. -// @Param region_id path int true "A Region's Id" -// @Produce json -// @Success 200 {object} schedule.OperatorWithStatus -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /operators/{region_id} [get] +// @Tags operator +// @Summary Get a Region's pending operator. +// @Param region_id path int true "A Region's Id" +// @Produce json +// @Success 200 {object} schedule.OperatorWithStatus +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /operators/{region_id} [get] func (h *operatorHandler) GetOperatorsByRegion(w http.ResponseWriter, r *http.Request) { id := mux.Vars(r)["region_id"] @@ -66,13 +66,13 @@ func (h *operatorHandler) GetOperatorsByRegion(w http.ResponseWriter, r *http.Re h.r.JSON(w, http.StatusOK, op) } -// @Tags operator -// @Summary List pending operators. -// @Param kind query string false "Specify the operator kind." Enums(admin, leader, region) -// @Produce json -// @Success 200 {array} operator.Operator -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /operators [get] +// @Tags operator +// @Summary List pending operators. +// @Param kind query string false "Specify the operator kind." Enums(admin, leader, region) +// @Produce json +// @Success 200 {array} operator.Operator +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /operators [get] func (h *operatorHandler) GetOperators(w http.ResponseWriter, r *http.Request) { var ( results []*operator.Operator @@ -111,15 +111,15 @@ func (h *operatorHandler) GetOperators(w http.ResponseWriter, r *http.Request) { } // FIXME: details of input json body params -// @Tags operator -// @Summary Create an operator. 
-// @Accept json -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "The operator is created." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /operators [post] +// @Tags operator +// @Summary Create an operator. +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "The operator is created." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /operators [post] func (h *operatorHandler) CreateOperator(w http.ResponseWriter, r *http.Request) { var input map[string]interface{} if err := apiutil.ReadJSONRespondError(h.r, w, r.Body, &input); err != nil { @@ -320,14 +320,14 @@ func (h *operatorHandler) CreateOperator(w http.ResponseWriter, r *http.Request) h.r.JSON(w, http.StatusOK, "The operator is created.") } -// @Tags operator -// @Summary Cancel a Region's pending operator. -// @Param region_id path int true "A Region's Id" -// @Produce json -// @Success 200 {string} string "The pending operator is canceled." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /operators/{region_id} [delete] +// @Tags operator +// @Summary Cancel a Region's pending operator. +// @Param region_id path int true "A Region's Id" +// @Produce json +// @Success 200 {string} string "The pending operator is canceled." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /operators/{region_id} [delete] func (h *operatorHandler) DeleteOperatorByRegion(w http.ResponseWriter, r *http.Request) { id := mux.Vars(r)["region_id"] @@ -345,14 +345,14 @@ func (h *operatorHandler) DeleteOperatorByRegion(w http.ResponseWriter, r *http. h.r.JSON(w, http.StatusOK, "The pending operator is canceled.") } -// @Tags operator -// @Summary lists the finished operators since the given timestamp in second. -// @Param from query integer false "From Unix timestamp" -// @Produce json -// @Success 200 {object} []operator.OpRecord -// @Failure 400 {string} string "The request is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /operators/records [get] +// @Tags operator +// @Summary lists the finished operators since the given timestamp in second. +// @Param from query integer false "From Unix timestamp" +// @Produce json +// @Success 200 {object} []operator.OpRecord +// @Failure 400 {string} string "The request is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /operators/records [get] func (h *operatorHandler) GetOperatorRecords(w http.ResponseWriter, r *http.Request) { var from time.Time if fromStr := r.URL.Query()["from"]; len(fromStr) > 0 { diff --git a/server/api/plugin.go b/server/api/plugin.go index 3e1372ba8f5..16894304e9b 100644 --- a/server/api/plugin.go +++ b/server/api/plugin.go @@ -38,29 +38,29 @@ func newPluginHandler(handler *server.Handler, rd *render.Render) *pluginHandler } // FIXME: details of input json body params -// @Tags plugin -// @Summary Load plugin. -// @Accept json -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "Load plugin success." -// @Failure 400 {string} string "The input is invalid." 
-// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /plugin [post] +// @Tags plugin +// @Summary Load plugin. +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Load plugin success." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /plugin [post] func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) { h.processPluginCommand(w, r, cluster.PluginLoad) } // FIXME: details of input json body params -// @Tags plugin -// @Summary Unload plugin. -// @Accept json -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "Load/Unload plugin successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /plugin [delete] +// @Tags plugin +// @Summary Unload plugin. +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Load/Unload plugin successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /plugin [delete] func (h *pluginHandler) UnloadPlugin(w http.ResponseWriter, r *http.Request) { h.processPluginCommand(w, r, cluster.PluginUnload) } diff --git a/server/api/pprof.go b/server/api/pprof.go index 9dd371badb1..0c180dda24c 100644 --- a/server/api/pprof.go +++ b/server/api/pprof.go @@ -47,10 +47,10 @@ func newPprofHandler(svr *server.Server, rd *render.Render) *pprofHandler { } } -// @Tags debug -// @Summary debug zip of PD servers. -// @Produce application/octet-stream -// @Router /debug/pprof/zip [get] +// @Tags debug +// @Summary debug zip of PD servers. +// @Produce application/octet-stream +// @Router /debug/pprof/zip [get] func (h *pprofHandler) PProfZip(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="pd_debug"`+time.Now().Format("20060102_150405")+".zip")) @@ -145,65 +145,65 @@ func (h *pprofHandler) PProfZip(w http.ResponseWriter, r *http.Request) { } } -// @Tags debug -// @Summary debug profile of PD servers. -// @Router /debug/pprof/profile [get] +// @Tags debug +// @Summary debug profile of PD servers. +// @Router /debug/pprof/profile [get] func (h *pprofHandler) PProfProfile(w http.ResponseWriter, r *http.Request) { pp.Profile(w, r) } -// @Tags debug -// @Summary debug trace of PD servers. -// @Router /debug/pprof/trace [get] +// @Tags debug +// @Summary debug trace of PD servers. +// @Router /debug/pprof/trace [get] func (h *pprofHandler) PProfTrace(w http.ResponseWriter, r *http.Request) { pp.Trace(w, r) } -// @Tags debug -// @Summary debug symbol of PD servers. -// @Router /debug/pprof/symbol [get] +// @Tags debug +// @Summary debug symbol of PD servers. +// @Router /debug/pprof/symbol [get] func (h *pprofHandler) PProfSymbol(w http.ResponseWriter, r *http.Request) { pp.Symbol(w, r) } -// @Tags debug -// @Summary debug heap of PD servers. -// @Router /debug/pprof/heap [get] +// @Tags debug +// @Summary debug heap of PD servers. +// @Router /debug/pprof/heap [get] func (h *pprofHandler) PProfHeap(w http.ResponseWriter, r *http.Request) { pp.Handler("heap").ServeHTTP(w, r) } -// @Tags debug -// @Summary debug mutex of PD servers. 
-// @Router /debug/pprof/mutex [get] +// @Tags debug +// @Summary debug mutex of PD servers. +// @Router /debug/pprof/mutex [get] func (h *pprofHandler) PProfMutex(w http.ResponseWriter, r *http.Request) { pp.Handler("mutex").ServeHTTP(w, r) } -// @Tags debug -// @Summary debug allocs of PD servers. -// @Router /debug/pprof/allocs [get] +// @Tags debug +// @Summary debug allocs of PD servers. +// @Router /debug/pprof/allocs [get] func (h *pprofHandler) PProfAllocs(w http.ResponseWriter, r *http.Request) { pp.Handler("allocs").ServeHTTP(w, r) } -// @Tags debug -// @Summary debug block of PD servers. -// @Router /debug/pprof/block [get] +// @Tags debug +// @Summary debug block of PD servers. +// @Router /debug/pprof/block [get] func (h *pprofHandler) PProfBlock(w http.ResponseWriter, r *http.Request) { pp.Handler("block").ServeHTTP(w, r) } -// @Tags debug -// @Summary debug goroutine of PD servers. -// @Router /debug/pprof/goroutine [get] +// @Tags debug +// @Summary debug goroutine of PD servers. +// @Router /debug/pprof/goroutine [get] func (h *pprofHandler) PProfGoroutine(w http.ResponseWriter, r *http.Request) { pp.Handler("goroutine").ServeHTTP(w, r) } -// @Tags debug -// @Summary debug threadcreate of PD servers. -// @Router /debug/pprof/threadcreate [get] +// @Tags debug +// @Summary debug threadcreate of PD servers. +// @Router /debug/pprof/threadcreate [get] func (h *pprofHandler) PProfThreadcreate(w http.ResponseWriter, r *http.Request) { pp.Handler("threadcreate").ServeHTTP(w, r) } diff --git a/server/api/region.go b/server/api/region.go index fa25ca1bd17..a4ae3fe0df9 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -209,13 +209,13 @@ func newRegionHandler(svr *server.Server, rd *render.Render) *regionHandler { } } -// @Tags region -// @Summary Search for a region by region ID. -// @Param id path integer true "Region Id" -// @Produce json -// @Success 200 {object} RegionInfo -// @Failure 400 {string} string "The input is invalid." -// @Router /region/id/{id} [get] +// @Tags region +// @Summary Search for a region by region ID. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {object} RegionInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /region/id/{id} [get] func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) @@ -231,12 +231,12 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, NewRegionInfo(regionInfo)) } -// @Tags region -// @Summary Search for a region by a key. GetRegion is named to be consistent with gRPC -// @Param key path string true "Region key" -// @Produce json -// @Success 200 {object} RegionInfo -// @Router /region/key/{key} [get] +// @Tags region +// @Summary Search for a region by a key. GetRegion is named to be consistent with gRPC +// @Param key path string true "Region key" +// @Produce json +// @Success 200 {object} RegionInfo +// @Router /region/key/{key} [get] func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -250,14 +250,14 @@ func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, NewRegionInfo(regionInfo)) } -// @Tags region -// @Summary Check if regions in the given key ranges are replicated. Returns 'REPLICATED', 'INPROGRESS', or 'PENDING'. 'PENDING' means that there is at least one region pending for scheduling. 
Similarly, 'INPROGRESS' means there is at least one region in scheduling. -// @Param startKey query string true "Regions start key, hex encoded" -// @Param endKey query string true "Regions end key, hex encoded" -// @Produce plain -// @Success 200 {string} string "INPROGRESS" -// @Failure 400 {string} string "The input is invalid." -// @Router /regions/replicated [get] +// @Tags region +// @Summary Check if regions in the given key ranges are replicated. Returns 'REPLICATED', 'INPROGRESS', or 'PENDING'. 'PENDING' means that there is at least one region pending for scheduling. Similarly, 'INPROGRESS' means there is at least one region in scheduling. +// @Param startKey query string true "Regions start key, hex encoded" +// @Param endKey query string true "Regions end key, hex encoded" +// @Produce plain +// @Success 200 {string} string "INPROGRESS" +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/replicated [get] func (h *regionsHandler) CheckRegionsReplicated(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) @@ -318,11 +318,11 @@ func convertToAPIRegions(regions []*core.RegionInfo) *RegionsInfo { } } -// @Tags region -// @Summary List all regions in the cluster. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Router /regions [get] +// @Tags region +// @Summary List all regions in the cluster. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Router /regions [get] func (h *regionsHandler) GetRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) regions := rc.GetRegions() @@ -330,15 +330,15 @@ func (h *regionsHandler) GetRegions(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List regions in a given range [startKey, endKey). -// @Param key query string true "Region range start key" -// @Param endkey query string true "Region range end key" -// @Param limit query integer false "Limit count" default(16) -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 400 {string} string "The input is invalid." -// @Router /regions/key [get] +// @Tags region +// @Summary List regions in a given range [startKey, endKey). +// @Param key query string true "Region range start key" +// @Param endkey query string true "Region range end key" +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/key [get] func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) startKey := r.URL.Query().Get("key") @@ -361,24 +361,24 @@ func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary Get count of regions. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Router /regions/count [get] +// @Tags region +// @Summary Get count of regions. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Router /regions/count [get] func (h *regionsHandler) GetRegionCount(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) count := rc.GetRegionCount() h.rd.JSON(w, http.StatusOK, &RegionsInfo{Count: count}) } -// @Tags region -// @Summary List all regions of a specific store. -// @Param id path integer true "Store Id" -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 400 {string} string "The input is invalid." 
-// @Router /regions/store/{id} [get] +// @Tags region +// @Summary List all regions of a specific store. +// @Param id path integer true "Store Id" +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/store/{id} [get] func (h *regionsHandler) GetStoreRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) @@ -393,12 +393,12 @@ func (h *regionsHandler) GetStoreRegions(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that miss peer. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/miss-peer [get] +// @Tags region +// @Summary List all regions that miss peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/miss-peer [get] func (h *regionsHandler) GetMissPeerRegions(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() regions, err := handler.GetRegionsByType(statistics.MissPeer) @@ -410,12 +410,12 @@ func (h *regionsHandler) GetMissPeerRegions(w http.ResponseWriter, r *http.Reque h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that has extra peer. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/extra-peer [get] +// @Tags region +// @Summary List all regions that has extra peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/extra-peer [get] func (h *regionsHandler) GetExtraPeerRegions(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() regions, err := handler.GetRegionsByType(statistics.ExtraPeer) @@ -427,12 +427,12 @@ func (h *regionsHandler) GetExtraPeerRegions(w http.ResponseWriter, r *http.Requ h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that has pending peer. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/pending-peer [get] +// @Tags region +// @Summary List all regions that has pending peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/pending-peer [get] func (h *regionsHandler) GetPendingPeerRegions(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() regions, err := handler.GetRegionsByType(statistics.PendingPeer) @@ -444,12 +444,12 @@ func (h *regionsHandler) GetPendingPeerRegions(w http.ResponseWriter, r *http.Re h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that has down peer. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/down-peer [get] +// @Tags region +// @Summary List all regions that has down peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /regions/check/down-peer [get] func (h *regionsHandler) GetDownPeerRegions(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() regions, err := handler.GetRegionsByType(statistics.DownPeer) @@ -461,12 +461,12 @@ func (h *regionsHandler) GetDownPeerRegions(w http.ResponseWriter, r *http.Reque h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that has learner peer. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/learner-peer [get] +// @Tags region +// @Summary List all regions that has learner peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/learner-peer [get] func (h *regionsHandler) GetLearnerPeerRegions(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() regions, err := handler.GetRegionsByType(statistics.LearnerPeer) @@ -478,12 +478,12 @@ func (h *regionsHandler) GetLearnerPeerRegions(w http.ResponseWriter, r *http.Re h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that has offline peer. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/offline-peer [get] +// @Tags region +// @Summary List all regions that has offline peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/offline-peer [get] func (h *regionsHandler) GetOfflinePeerRegions(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() regions, err := handler.GetOfflinePeer(statistics.OfflinePeer) @@ -495,12 +495,12 @@ func (h *regionsHandler) GetOfflinePeerRegions(w http.ResponseWriter, r *http.Re h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that are oversized. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/oversized-region [get] +// @Tags region +// @Summary List all regions that are oversized. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/oversized-region [get] func (h *regionsHandler) GetOverSizedRegions(w http.ResponseWriter, r *http.Request) { handler := h.svr.GetHandler() regions, err := handler.GetRegionsByType(statistics.OversizedRegion) @@ -512,12 +512,12 @@ func (h *regionsHandler) GetOverSizedRegions(w http.ResponseWriter, r *http.Requ h.rd.JSON(w, http.StatusOK, regionsInfo) } -// @Tags region -// @Summary List all regions that are undersized. -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /regions/check/undersized-region [get] +// @Tags region +// @Summary List all regions that are undersized. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /regions/check/undersized-region [get]
 func (h *regionsHandler) GetUndersizedRegions(w http.ResponseWriter, r *http.Request) {
 	handler := h.svr.GetHandler()
 	regions, err := handler.GetRegionsByType(statistics.UndersizedRegion)
@@ -529,12 +529,12 @@ func (h *regionsHandler) GetUndersizedRegions(w http.ResponseWriter, r *http.Req
 	h.rd.JSON(w, http.StatusOK, regionsInfo)
 }
 
-// @Tags region
-// @Summary List all empty regions.
-// @Produce json
-// @Success 200 {object} RegionsInfo
-// @Failure 500 {string} string "PD server failed to proceed the request."
-// @Router /regions/check/empty-region [get]
+// @Tags region
+// @Summary List all empty regions.
+// @Produce json
+// @Success 200 {object} RegionsInfo
+// @Failure 500 {string} string "PD server failed to proceed the request."
+// @Router /regions/check/empty-region [get]
 func (h *regionsHandler) GetEmptyRegions(w http.ResponseWriter, r *http.Request) {
 	handler := h.svr.GetHandler()
 	regions, err := handler.GetRegionsByType(statistics.EmptyRegion)
@@ -566,13 +566,13 @@ func (hist histSlice) Less(i, j int) bool {
 	return hist[i].Start < hist[j].Start
 }
 
-// @Tags region
-// @Summary Get size of histogram.
-// @Param bound query integer false "Size bound of region histogram" minimum(1)
-// @Produce json
-// @Success 200 {array} histItem
-// @Failure 400 {string} string "The input is invalid."
-// @Router /regions/check/hist-size [get]
+// @Tags region
+// @Summary Get size of histogram.
+// @Param bound query integer false "Size bound of region histogram" minimum(1)
+// @Produce json
+// @Success 200 {array} histItem
+// @Failure 400 {string} string "The input is invalid."
+// @Router /regions/check/hist-size [get]
 func (h *regionsHandler) GetSizeHistogram(w http.ResponseWriter, r *http.Request) {
 	bound := minRegionHistogramSize
 	bound, err := calBound(bound, r)
@@ -590,13 +590,13 @@ func (h *regionsHandler) GetSizeHistogram(w http.ResponseWriter, r *http.Request
 	h.rd.JSON(w, http.StatusOK, histItems)
 }
 
-// @Tags region
-// @Summary Get keys of histogram.
-// @Param bound query integer false "Key bound of region histogram" minimum(1000)
-// @Produce json
-// @Success 200 {array} histItem
-// @Failure 400 {string} string "The input is invalid."
-// @Router /regions/check/hist-keys [get]
+// @Tags region
+// @Summary Get keys of histogram.
+// @Param bound query integer false "Key bound of region histogram" minimum(1000)
+// @Produce json
+// @Success 200 {array} histItem
+// @Failure 400 {string} string "The input is invalid."
+// @Router /regions/check/hist-keys [get]
 func (h *regionsHandler) GetKeysHistogram(w http.ResponseWriter, r *http.Request) {
 	bound := minRegionHistogramKeys
 	bound, err := calBound(bound, r)
@@ -649,24 +649,24 @@ func calHist(bound int, list *[]int64) *[]*histItem {
 	return &histItems
 }
 
-// @Tags region
-// @Summary List all range holes whitout any region info.
-// @Produce json
-// @Success 200 {object} [][]string
-// @Router /regions/range-holes [get]
+// @Tags region
+// @Summary List all range holes without any region info.
+// @Produce json
+// @Success 200 {object} [][]string
+// @Router /regions/range-holes [get]
 func (h *regionsHandler) GetRangeHoles(w http.ResponseWriter, r *http.Request) {
 	rc := getCluster(r)
 	h.rd.JSON(w, http.StatusOK, rc.GetRangeHoles())
 }
 
-// @Tags region
-// @Summary List sibling regions of a specific region.
-// @Param id path integer true "Region Id"
-// @Produce json
-// @Success 200 {object} RegionsInfo
-// @Failure 400 {string} string "The input is invalid."
-// @Failure 404 {string} string "The region does not exist." -// @Router /regions/sibling/{id} [get] +// @Tags region +// @Summary List sibling regions of a specific region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Router /regions/sibling/{id} [get] func (h *regionsHandler) GetRegionSiblings(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) @@ -694,89 +694,89 @@ const ( minRegionHistogramKeys = 1000 ) -// @Tags region -// @Summary List regions with the highest write flow. -// @Param limit query integer false "Limit count" default(16) -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 400 {string} string "The input is invalid." -// @Router /regions/writeflow [get] +// @Tags region +// @Summary List regions with the highest write flow. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/writeflow [get] func (h *regionsHandler) GetTopWriteFlowRegions(w http.ResponseWriter, r *http.Request) { h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { return a.GetBytesWritten() < b.GetBytesWritten() }) } -// @Tags region -// @Summary List regions with the highest read flow. -// @Param limit query integer false "Limit count" default(16) -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 400 {string} string "The input is invalid." -// @Router /regions/readflow [get] +// @Tags region +// @Summary List regions with the highest read flow. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/readflow [get] func (h *regionsHandler) GetTopReadFlowRegions(w http.ResponseWriter, r *http.Request) { h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { return a.GetBytesRead() < b.GetBytesRead() }) } -// @Tags region -// @Summary List regions with the largest conf version. -// @Param limit query integer false "Limit count" default(16) -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 400 {string} string "The input is invalid." -// @Router /regions/confver [get] +// @Tags region +// @Summary List regions with the largest conf version. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/confver [get] func (h *regionsHandler) GetTopConfVerRegions(w http.ResponseWriter, r *http.Request) { h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { return a.GetMeta().GetRegionEpoch().GetConfVer() < b.GetMeta().GetRegionEpoch().GetConfVer() }) } -// @Tags region -// @Summary List regions with the largest version. -// @Param limit query integer false "Limit count" default(16) -// @Produce json -// @Success 200 {object} RegionsInfo -// @Failure 400 {string} string "The input is invalid." -// @Router /regions/version [get] +// @Tags region +// @Summary List regions with the largest version. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." 
+// @Router /regions/version [get]
 func (h *regionsHandler) GetTopVersionRegions(w http.ResponseWriter, r *http.Request) {
 	h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool {
 		return a.GetMeta().GetRegionEpoch().GetVersion() < b.GetMeta().GetRegionEpoch().GetVersion()
 	})
 }
 
-// @Tags region
-// @Summary List regions with the largest size.
-// @Param limit query integer false "Limit count" default(16)
-// @Produce json
-// @Success 200 {object} RegionsInfo
-// @Failure 400 {string} string "The input is invalid."
-// @Router /regions/size [get]
+// @Tags region
+// @Summary List regions with the largest size.
+// @Param limit query integer false "Limit count" default(16)
+// @Produce json
+// @Success 200 {object} RegionsInfo
+// @Failure 400 {string} string "The input is invalid."
+// @Router /regions/size [get]
 func (h *regionsHandler) GetTopSizeRegions(w http.ResponseWriter, r *http.Request) {
 	h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool {
 		return a.GetApproximateSize() < b.GetApproximateSize()
 	})
 }
 
-// @Tags region
-// @Summary List regions with the largest keys.
-// @Param limit query integer false "Limit count" default(16)
-// @Produce json
-// @Success 200 {object} RegionsInfo
-// @Failure 400 {string} string "The input is invalid."
-// @Router /regions/keys [get]
+// @Tags region
+// @Summary List regions with the largest keys.
+// @Param limit query integer false "Limit count" default(16)
+// @Produce json
+// @Success 200 {object} RegionsInfo
+// @Failure 400 {string} string "The input is invalid."
+// @Router /regions/keys [get]
 func (h *regionsHandler) GetTopKeysRegions(w http.ResponseWriter, r *http.Request) {
 	h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool {
 		return a.GetApproximateKeys() < b.GetApproximateKeys()
 	})
 }
 
-// @Tags region
-// @Summary Accelerate regions scheduling a in given range, only receive hex format for keys
-// @Accept json
-// @Param body body object true "json params"
-// @Param limit query integer false "Limit count" default(256)
-// @Produce json
-// @Success 200 {string} string "Accelerate regions scheduling in a given range [startKey, endKey)"
-// @Failure 400 {string} string "The input is invalid."
-// @Router /regions/accelerate-schedule [post]
+// @Tags region
+// @Summary Accelerate regions scheduling in a given range, only receive hex format for keys
+// @Accept json
+// @Param body body object true "json params"
+// @Param limit query integer false "Limit count" default(256)
+// @Produce json
+// @Success 200 {string} string "Accelerate regions scheduling in a given range [startKey, endKey)"
+// @Failure 400 {string} string "The input is invalid."
+// @Router /regions/accelerate-schedule [post]
 func (h *regionsHandler) AccelerateRegionsScheduleInRange(w http.ResponseWriter, r *http.Request) {
 	rc := getCluster(r)
 	var input map[string]interface{}
@@ -838,14 +838,14 @@ func (h *regionsHandler) GetTopNRegions(w http.ResponseWriter, r *http.Request,
 	h.rd.JSON(w, http.StatusOK, regionsInfo)
 }
 
-// @Tags region
-// @Summary Scatter regions by given key ranges or regions id distributed by given group with given retry limit
-// @Accept json
-// @Param body body object true "json params"
-// @Produce json
-// @Success 200 {string} string "Scatter regions by given key ranges or regions id distributed by given group with given retry limit"
-// @Failure 400 {string} string "The input is invalid."
-// @Router /regions/scatter [post] +// @Tags region +// @Summary Scatter regions by given key ranges or regions id distributed by given group with given retry limit +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Scatter regions by given key ranges or regions id distributed by given group with given retry limit" +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/scatter [post] func (h *regionsHandler) ScatterRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) var input map[string]interface{} @@ -919,14 +919,14 @@ func (h *regionsHandler) ScatterRegions(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, &s) } -// @Tags region -// @Summary Split regions with given split keys -// @Accept json -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "Split regions with given split keys" -// @Failure 400 {string} string "The input is invalid." -// @Router /regions/split [post] +// @Tags region +// @Summary Split regions with given split keys +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Split regions with given split keys" +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/split [post] func (h *regionsHandler) SplitRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) var input map[string]interface{} diff --git a/server/api/region_label.go b/server/api/region_label.go index 6eae4206914..539f4126c1f 100644 --- a/server/api/region_label.go +++ b/server/api/region_label.go @@ -39,26 +39,26 @@ func newRegionLabelHandler(s *server.Server, rd *render.Render) *regionLabelHand } } -// @Tags region_label -// @Summary List all label rules of cluster. -// @Produce json -// @Success 200 {array} labeler.LabelRule -// @Router /config/region-label/rules [get] +// @Tags region_label +// @Summary List all label rules of cluster. +// @Produce json +// @Success 200 {array} labeler.LabelRule +// @Router /config/region-label/rules [get] func (h *regionLabelHandler) GetAllRegionLabelRules(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) rules := cluster.GetRegionLabeler().GetAllLabelRules() h.rd.JSON(w, http.StatusOK, rules) } -// @Tags region_label -// @Summary Update region label rules in batch. -// @Accept json -// @Param patch body labeler.LabelRulePatch true "Patch to update rules" -// @Produce json -// @Success 200 {string} string "Update region label rules successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/region-label/rules [patch] +// @Tags region_label +// @Summary Update region label rules in batch. +// @Accept json +// @Param patch body labeler.LabelRulePatch true "Patch to update rules" +// @Produce json +// @Success 200 {string} string "Update region label rules successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /config/region-label/rules [patch] func (h *regionLabelHandler) PatchRegionLabelRules(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) var patch labeler.LabelRulePatch @@ -76,14 +76,14 @@ func (h *regionLabelHandler) PatchRegionLabelRules(w http.ResponseWriter, r *htt h.rd.JSON(w, http.StatusOK, "Update region label rules successfully.") } -// @Tags region_label -// @Summary Get label rules of cluster by ids. -// @Param body body []string true "IDs of query rules" -// @Produce json -// @Success 200 {array} labeler.LabelRule -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/region-label/rule/ids [get] +// @Tags region_label +// @Summary Get label rules of cluster by ids. +// @Param body body []string true "IDs of query rules" +// @Produce json +// @Success 200 {array} labeler.LabelRule +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/region-label/rule/ids [get] func (h *regionLabelHandler) GetRegionLabelRulesByIDs(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) var ids []string @@ -98,13 +98,13 @@ func (h *regionLabelHandler) GetRegionLabelRulesByIDs(w http.ResponseWriter, r * h.rd.JSON(w, http.StatusOK, rules) } -// @Tags region_label -// @Summary Get label rule of cluster by id. -// @Param id path string true "Rule Id" -// @Produce json -// @Success 200 {object} labeler.LabelRule -// @Failure 404 {string} string "The rule does not exist." -// @Router /config/region-label/rule/{id} [get] +// @Tags region_label +// @Summary Get label rule of cluster by id. +// @Param id path string true "Rule Id" +// @Produce json +// @Success 200 {object} labeler.LabelRule +// @Failure 404 {string} string "The rule does not exist." +// @Router /config/region-label/rule/{id} [get] func (h *regionLabelHandler) GetRegionLabelRuleByID(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) id, err := url.PathUnescape(mux.Vars(r)["id"]) @@ -120,14 +120,14 @@ func (h *regionLabelHandler) GetRegionLabelRuleByID(w http.ResponseWriter, r *ht h.rd.JSON(w, http.StatusOK, rule) } -// @Tags region_label -// @Summary Delete label rule of cluster by id. -// @Param id path string true "Rule Id" -// @Produce json -// @Success 200 {string} string "Delete rule successfully." -// @Failure 404 {string} string "The rule does not exist." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/region-label/rule/{id} [delete] +// @Tags region_label +// @Summary Delete label rule of cluster by id. +// @Param id path string true "Rule Id" +// @Produce json +// @Success 200 {string} string "Delete rule successfully." +// @Failure 404 {string} string "The rule does not exist." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/region-label/rule/{id} [delete] func (h *regionLabelHandler) DeleteRegionLabelRule(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) id, err := url.PathUnescape(mux.Vars(r)["id"]) @@ -147,15 +147,15 @@ func (h *regionLabelHandler) DeleteRegionLabelRule(w http.ResponseWriter, r *htt h.rd.Text(w, http.StatusOK, "Delete rule successfully.") } -// @Tags region_label -// @Summary Update region label rule of cluster. 
-// @Accept json -// @Param rule body labeler.LabelRule true "Parameters of label rule" -// @Produce json -// @Success 200 {string} string "Update rule successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/region-label/rule [post] +// @Tags region_label +// @Summary Update region label rule of cluster. +// @Accept json +// @Param rule body labeler.LabelRule true "Parameters of label rule" +// @Produce json +// @Success 200 {string} string "Update rule successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/region-label/rule [post] func (h *regionLabelHandler) SetRegionLabelRule(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) var rule labeler.LabelRule @@ -173,15 +173,15 @@ func (h *regionLabelHandler) SetRegionLabelRule(w http.ResponseWriter, r *http.R h.rd.JSON(w, http.StatusOK, "Update region label rule successfully.") } -// @Tags region_label -// @Summary Get label of a region. -// @Param id path integer true "Region Id" -// @Param key path string true "Label key" -// @Produce json -// @Success 200 {string} string -// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The region does not exist." -// @Router /region/id/{id}/label/{key} [get] +// @Tags region_label +// @Summary Get label of a region. +// @Param id path integer true "Region Id" +// @Param key path string true "Label key" +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Router /region/id/{id}/label/{key} [get] func (h *regionLabelHandler) GetRegionLabelByKey(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) regionID, labelKey := mux.Vars(r)["id"], mux.Vars(r)["key"] @@ -199,14 +199,14 @@ func (h *regionLabelHandler) GetRegionLabelByKey(w http.ResponseWriter, r *http. h.rd.JSON(w, http.StatusOK, labelValue) } -// @Tags region_label -// @Summary Get labels of a region. -// @Param id path integer true "Region Id" -// @Produce json -// @Success 200 {string} string -// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The region does not exist." -// @Router /region/id/{id}/labels [get] +// @Tags region_label +// @Summary Get labels of a region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." 
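`GetRegionLabelRuleByID` calls `url.PathUnescape` on the `{id}` path variable, so clients should escape rule ids that contain reserved characters, and `GetRegionLabelByKey` reads a single label value by region id and key. A small sketch, assuming a local PD and illustrative ids:

package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1" // assumed local PD address

	// Rule ids may contain reserved characters; the handler unescapes the {id}
	// path variable, so escape it on the client side.
	id := url.PathEscape("example/rule-id") // illustrative rule id
	resp, err := http.Get(base + "/config/region-label/rule/" + id)
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("rule:", resp.Status, string(body))

	// Fetch the label value for key "zone" on region 2 (illustrative values).
	resp, err = http.Get(base + "/region/id/2/label/zone")
	if err != nil {
		panic(err)
	}
	body, _ = io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("label:", resp.Status, string(body))
}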
+// @Router /region/id/{id}/labels [get] func (h *regionLabelHandler) GetRegionLabels(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) regionID, err := strconv.ParseUint(mux.Vars(r)["id"], 10, 64) diff --git a/server/api/replication_mode.go b/server/api/replication_mode.go index 9ba7050dc24..4fe2ef5da09 100644 --- a/server/api/replication_mode.go +++ b/server/api/replication_mode.go @@ -33,11 +33,11 @@ func newReplicationModeHandler(svr *server.Server, rd *render.Render) *replicati } } -// @Tags replication_mode -// @Summary Get status of replication mode -// @Produce json -// @Success 200 {object} replication.HTTPReplicationStatus -// @Router /replication_mode/status [get] +// @Tags replication_mode +// @Summary Get status of replication mode +// @Produce json +// @Success 200 {object} replication.HTTPReplicationStatus +// @Router /replication_mode/status [get] func (h *replicationModeHandler) GetReplicationModeStatus(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, getCluster(r).GetReplicationMode().GetReplicationStatusHTTP()) } diff --git a/server/api/router.go b/server/api/router.go index 3e8061fd74e..155a33845f2 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -72,15 +72,15 @@ func getFunctionName(f interface{}) string { } // The returned function is used as a lazy router to avoid the data race problem. -// @title Placement Driver Core API -// @version 1.0 -// @description This is placement driver. -// @contact.name Placement Driver Support -// @contact.url https://github.com/tikv/pd/issues -// @contact.email info@pingcap.com -// @license.name Apache 2.0 -// @license.url http://www.apache.org/licenses/LICENSE-2.0.html -// @BasePath /pd/api/v1 +// @title Placement Driver Core API +// @version 1.0 +// @description This is placement driver. +// @contact.name Placement Driver Support +// @contact.url https://github.com/tikv/pd/issues +// @contact.email info@pingcap.com +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html +// @BasePath /pd/api/v1 func createRouter(prefix string, svr *server.Server) *mux.Router { serviceMiddle := newServiceMiddlewareBuilder(svr) registerPrefix := func(router *mux.Router, prefixPath string, diff --git a/server/api/rule.go b/server/api/rule.go index f0eb43128f2..4148cfcb668 100644 --- a/server/api/rule.go +++ b/server/api/rule.go @@ -44,12 +44,12 @@ func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { } } -// @Tags rule -// @Summary List all rules of cluster. -// @Produce json -// @Success 200 {array} placement.Rule -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/rules [get] +// @Tags rule +// @Summary List all rules of cluster. +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rules [get] func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -60,15 +60,15 @@ func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, rules) } -// @Tags rule -// @Summary Set all rules for the cluster. If there is an error, modifications are promised to be rollback in memory, but may fail to rollback disk. You probably want to request again to make rules in memory/disk consistent. 
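The general annotations attached to `createRouter` (notably `@BasePath /pd/api/v1`) mean every `@Router` value in this patch is relative to that prefix. A minimal check against the replication mode status endpoint, assuming a local PD:

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// @Router /replication_mode/status resolves under the @BasePath /pd/api/v1.
	resp, err := http.Get("http://127.0.0.1:2379/pd/api/v1/replication_mode/status")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body))
}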
-// @Produce json -// @Param rules body []placement.Rule true "Parameters of rules" -// @Success 200 {string} string "Update rules successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/rules [get] +// @Tags rule +// @Summary Set all rules for the cluster. If there is an error, modifications are promised to be rollback in memory, but may fail to rollback disk. You probably want to request again to make rules in memory/disk consistent. +// @Produce json +// @Param rules body []placement.Rule true "Parameters of rules" +// @Success 200 {string} string "Update rules successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules [get] func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -97,13 +97,13 @@ func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, "Update rules successfully.") } -// @Tags rule -// @Summary List all rules of cluster by group. -// @Param group path string true "The name of group" -// @Produce json -// @Success 200 {array} placement.Rule -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/rules/group/{group} [get] +// @Tags rule +// @Summary List all rules of cluster by group. +// @Param group path string true "The name of group" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rules/group/{group} [get] func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -115,15 +115,15 @@ func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, rules) } -// @Tags rule -// @Summary List all rules of cluster by region. -// @Param region path string true "The name of region" -// @Produce json -// @Success 200 {array} placement.Rule -// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The region does not exist." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/rules/region/{region} [get] +// @Tags rule +// @Summary List all rules of cluster by region. +// @Param region path string true "The name of region" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rules/region/{region} [get] func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -145,14 +145,14 @@ func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, rules) } -// @Tags rule -// @Summary List all rules of cluster by key. -// @Param key path string true "The name of key" -// @Produce json -// @Success 200 {array} placement.Rule -// @Failure 400 {string} string "The input is invalid." 
-// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/rules/key/{key} [get] +// @Tags rule +// @Summary List all rules of cluster by key. +// @Param key path string true "The name of key" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rules/key/{key} [get] func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -169,15 +169,15 @@ func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, rules) } -// @Tags rule -// @Summary Get rule of cluster by group and id. -// @Param group path string true "The name of group" -// @Param id path string true "Rule Id" -// @Produce json -// @Success 200 {object} placement.Rule -// @Failure 404 {string} string "The rule does not exist." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/rule/{group}/{id} [get] +// @Tags rule +// @Summary Get rule of cluster by group and id. +// @Param group path string true "The name of group" +// @Param id path string true "Rule Id" +// @Produce json +// @Success 200 {object} placement.Rule +// @Failure 404 {string} string "The rule does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rule/{group}/{id} [get] func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -193,16 +193,16 @@ func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request h.rd.JSON(w, http.StatusOK, rule) } -// @Tags rule -// @Summary Update rule of cluster. -// @Accept json -// @Param rule body placement.Rule true "Parameters of rule" -// @Produce json -// @Success 200 {string} string "Update rule successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/rule [post] +// @Tags rule +// @Summary Update rule of cluster. +// @Accept json +// @Param rule body placement.Rule true "Parameters of rule" +// @Produce json +// @Success 200 {string} string "Update rule successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rule [post] func (h *ruleHandler) SetRule(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -247,15 +247,15 @@ func (h *ruleHandler) syncReplicateConfigWithDefaultRule(rule *placement.Rule) e return nil } -// @Tags rule -// @Summary Delete rule of cluster. -// @Param group path string true "The name of group" -// @Param id path string true "Rule Id" -// @Produce json -// @Success 200 {string} string "Delete rule successfully." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/rule/{group}/{id} [delete] +// @Tags rule +// @Summary Delete rule of cluster. 
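`SetRule` decodes a single `placement.Rule` from the body and `GetRuleByGroupAndID` reads it back at `/config/rule/{group}/{id}`. A hedged round trip is sketched below; the JSON field names (`group_id`, `id`, `role`, `count`) are assumed from the placement package and are not shown in this diff.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1" // assumed local PD address

	// Create or update one rule (field names are assumptions for illustration).
	rule := []byte(`{"group_id": "pd", "id": "example-rule", "role": "voter", "count": 3}`)
	resp, err := http.Post(base+"/config/rule", "application/json", bytes.NewReader(rule))
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("set rule:", resp.Status)

	// Read the rule back by group and id.
	resp, err = http.Get(base + "/config/rule/pd/example-rule")
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("get rule:", resp.Status, string(body))
}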
+// @Param group path string true "The name of group" +// @Param id path string true "Rule Id" +// @Produce json +// @Success 200 {string} string "Delete rule successfully." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rule/{group}/{id} [delete] func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -275,15 +275,15 @@ func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, "Delete rule successfully.") } -// @Tags rule -// @Summary Batch operations for the cluster. Operations should be independent(different ID). If there is an error, modifications are promised to be rollback in memory, but may fail to rollback disk. You probably want to request again to make rules in memory/disk consistent. -// @Produce json -// @Param operations body []placement.RuleOp true "Parameters of rule operations" -// @Success 200 {string} string "Batch operations successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/rules/batch [post] +// @Tags rule +// @Summary Batch operations for the cluster. Operations should be independent(different ID). If there is an error, modifications are promised to be rollback in memory, but may fail to rollback disk. You probably want to request again to make rules in memory/disk consistent. +// @Produce json +// @Param operations body []placement.RuleOp true "Parameters of rule operations" +// @Success 200 {string} string "Batch operations successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules/batch [post] func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -306,14 +306,14 @@ func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, "Batch operations successfully.") } -// @Tags rule -// @Summary Get rule group config by group id. -// @Param id path string true "Group Id" -// @Produce json -// @Success 200 {object} placement.RuleGroup -// @Failure 404 {string} string "The RuleGroup does not exist." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/rule_group/{id} [get] +// @Tags rule +// @Summary Get rule group config by group id. +// @Param id path string true "Group Id" +// @Produce json +// @Success 200 {object} placement.RuleGroup +// @Failure 404 {string} string "The RuleGroup does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rule_group/{id} [get] func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -329,16 +329,16 @@ func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, group) } -// @Tags rule -// @Summary Update rule group config. 
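`BatchRules` takes a list of `placement.RuleOp` operations that must touch distinct rule IDs. The sketch below assumes an `action` field with `set`/`del` values alongside the rule fields; those names come from the placement package, not from this diff, so treat them as illustrative.

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Two independent operations: set one rule, delete another (assumed field names).
	ops := []byte(`[
		{"action": "set", "group_id": "pd", "id": "rule-a", "role": "voter", "count": 3},
		{"action": "del", "group_id": "pd", "id": "rule-b"}
	]`)
	resp, err := http.Post("http://127.0.0.1:2379/pd/api/v1/config/rules/batch", "application/json", bytes.NewReader(ops))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("batch:", resp.Status)
}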
-// @Accept json -// @Param rule body placement.RuleGroup true "Parameters of rule group" -// @Produce json -// @Success 200 {string} string "Update rule group config successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/rule_group [post] +// @Tags rule +// @Summary Update rule group config. +// @Accept json +// @Param rule body placement.RuleGroup true "Parameters of rule group" +// @Produce json +// @Success 200 {string} string "Update rule group config successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rule_group [post] func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -359,14 +359,14 @@ func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, "Update rule group successfully.") } -// @Tags rule -// @Summary Delete rule group config. -// @Param id path string true "Group Id" -// @Produce json -// @Success 200 {string} string "Delete rule group config successfully." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/rule_group/{id} [delete] +// @Tags rule +// @Summary Delete rule group config. +// @Param id path string true "Group Id" +// @Produce json +// @Success 200 {string} string "Delete rule group config successfully." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rule_group/{id} [delete] func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -385,12 +385,12 @@ func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, "Delete rule group successfully.") } -// @Tags rule -// @Summary List all rule group configs. -// @Produce json -// @Success 200 {array} placement.RuleGroup -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/rule_groups [get] +// @Tags rule +// @Summary List all rule group configs. +// @Produce json +// @Success 200 {array} placement.RuleGroup +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rule_groups [get] func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -401,12 +401,12 @@ func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, ruleGroups) } -// @Tags rule -// @Summary List all rules and groups configuration. -// @Produce json -// @Success 200 {array} placement.GroupBundle -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/placement-rule [get] +// @Tags rule +// @Summary List all rules and groups configuration. +// @Produce json +// @Success 200 {array} placement.GroupBundle +// @Failure 412 {string} string "Placement rules feature is disabled." 
+// @Router /config/placement-rule [get] func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -417,15 +417,15 @@ func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, bundles) } -// @Tags rule -// @Summary Update all rules and groups configuration. -// @Param partial query bool false "if partially update rules" default(false) -// @Produce json -// @Success 200 {string} string "Update rules and groups successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/placement-rule [post] +// @Tags rule +// @Summary Update all rules and groups configuration. +// @Param partial query bool false "if partially update rules" default(false) +// @Produce json +// @Success 200 {string} string "Update rules and groups successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/placement-rule [post] func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -449,13 +449,13 @@ func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusOK, "Update rules and groups successfully.") } -// @Tags rule -// @Summary Get group config and all rules belong to the group. -// @Param group path string true "The name of group" -// @Produce json -// @Success 200 {object} placement.GroupBundle -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/placement-rule/{group} [get] +// @Tags rule +// @Summary Get group config and all rules belong to the group. +// @Param group path string true "The name of group" +// @Produce json +// @Success 200 {object} placement.GroupBundle +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/placement-rule/{group} [get] func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -466,15 +466,15 @@ func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Req h.rd.JSON(w, http.StatusOK, group) } -// @Tags rule -// @Summary Get group config and all rules belong to the group. -// @Param group path string true "The name or name pattern of group" -// @Param regexp query bool false "Use regular expression" default(false) -// @Produce plain -// @Success 200 {string} string "Delete group and rules successfully." -// @Failure 400 {string} string "Bad request." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Router /config/placement-rule [delete] +// @Tags rule +// @Summary Get group config and all rules belong to the group. +// @Param group path string true "The name or name pattern of group" +// @Param regexp query bool false "Use regular expression" default(false) +// @Produce plain +// @Success 200 {string} string "Delete group and rules successfully." +// @Failure 400 {string} string "Bad request." +// @Failure 412 {string} string "Placement rules feature is disabled." 
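`GetPlacementRules` and `SetPlacementRules` expose the whole rules-and-groups configuration as one document, and the `partial` query flag controls whether a POST replaces everything or only the groups it mentions. Fetching the bundles and posting the same bytes back avoids guessing the `GroupBundle` JSON shape; only the PD address is an assumption here.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1" // assumed local PD address

	// Dump the current rules and groups configuration.
	resp, err := http.Get(base + "/config/placement-rule")
	if err != nil {
		panic(err)
	}
	bundles, _ := io.ReadAll(resp.Body)
	resp.Body.Close()

	// Post it back as a partial update (a no-op round trip).
	resp, err = http.Post(base+"/config/placement-rule?partial=true", "application/json", bytes.NewReader(bundles))
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("placement-rule round trip:", resp.Status)
}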
+// @Router /config/placement-rule [delete] func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { @@ -495,14 +495,14 @@ func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http. h.rd.JSON(w, http.StatusOK, "Delete group and rules successfully.") } -// @Tags rule -// @Summary Update group and all rules belong to it. -// @Produce json -// @Success 200 {string} string "Update group and rules successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 412 {string} string "Placement rules feature is disabled." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/placement-rule/{group} [post] +// @Tags rule +// @Summary Update group and all rules belong to it. +// @Produce json +// @Success 200 {string} string "Update group and rules successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/placement-rule/{group} [post] func (h *ruleHandler) SetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) if !cluster.GetOpts().IsPlacementRulesEnabled() { diff --git a/server/api/scheduler.go b/server/api/scheduler.go index 5faa01c764b..9b690a93249 100644 --- a/server/api/scheduler.go +++ b/server/api/scheduler.go @@ -51,12 +51,12 @@ type schedulerPausedPeriod struct { ResumeAt time.Time `json:"resume_at"` } -// @Tags scheduler -// @Summary List all created schedulers by status. -// @Produce json -// @Success 200 {array} string -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /schedulers [get] +// @Tags scheduler +// @Summary List all created schedulers by status. +// @Produce json +// @Success 200 {array} string +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /schedulers [get] func (h *schedulerHandler) GetSchedulers(w http.ResponseWriter, r *http.Request) { schedulers, err := h.Handler.GetSchedulers() if err != nil { @@ -128,15 +128,15 @@ func (h *schedulerHandler) GetSchedulers(w http.ResponseWriter, r *http.Request) } // FIXME: details of input json body params -// @Tags scheduler -// @Summary Create a scheduler. -// @Accept json -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "The scheduler is created." -// @Failure 400 {string} string "Bad format request." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /schedulers [post] +// @Tags scheduler +// @Summary Create a scheduler. +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "The scheduler is created." +// @Failure 400 {string} string "Bad format request." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /schedulers [post] func (h *schedulerHandler) CreateScheduler(w http.ResponseWriter, r *http.Request) { var input map[string]interface{} if err := apiutil.ReadJSONRespondError(h.r, w, r.Body, &input); err != nil { @@ -269,14 +269,14 @@ func (h *schedulerHandler) addEvictOrGrant(w http.ResponseWriter, input map[stri } } -// @Tags scheduler -// @Summary Delete a scheduler. -// @Param name path string true "The name of the scheduler." 
-// @Produce json -// @Success 200 {string} string "The scheduler is removed." -// @Failure 404 {string} string "The scheduler is not found." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /schedulers/{name} [delete] +// @Tags scheduler +// @Summary Delete a scheduler. +// @Param name path string true "The name of the scheduler." +// @Produce json +// @Success 200 {string} string "The scheduler is removed." +// @Failure 404 {string} string "The scheduler is not found." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /schedulers/{name} [delete] func (h *schedulerHandler) DeleteScheduler(w http.ResponseWriter, r *http.Request) { name := mux.Vars(r)["name"] switch { @@ -316,16 +316,16 @@ func (h *schedulerHandler) redirectSchedulerDelete(w http.ResponseWriter, name, } // FIXME: details of input json body params -// @Tags scheduler -// @Summary Pause or resume a scheduler. -// @Accept json -// @Param name path string true "The name of the scheduler." -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "Pause or resume the scheduler successfully." -// @Failure 400 {string} string "Bad format request." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /schedulers/{name} [post] +// @Tags scheduler +// @Summary Pause or resume a scheduler. +// @Accept json +// @Param name path string true "The name of the scheduler." +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Pause or resume the scheduler successfully." +// @Failure 400 {string} string "Bad format request." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /schedulers/{name} [post] func (h *schedulerHandler) PauseOrResumeScheduler(w http.ResponseWriter, r *http.Request) { var input map[string]int64 if err := apiutil.ReadJSONRespondError(h.r, w, r.Body, &input); err != nil { diff --git a/server/api/service_gc_safepoint.go b/server/api/service_gc_safepoint.go index 40c3aff1076..9df3700a30b 100644 --- a/server/api/service_gc_safepoint.go +++ b/server/api/service_gc_safepoint.go @@ -41,12 +41,12 @@ type listServiceGCSafepoint struct { GCSafePoint uint64 `json:"gc_safe_point"` } -// @Tags service_gc_safepoint -// @Summary Get all service GC safepoint. -// @Produce json -// @Success 200 {array} listServiceGCSafepoint -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /gc/safepoint [get] +// @Tags service_gc_safepoint +// @Summary Get all service GC safepoint. +// @Produce json +// @Success 200 {array} listServiceGCSafepoint +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /gc/safepoint [get] func (h *serviceGCSafepointHandler) GetGCSafePoint(w http.ResponseWriter, r *http.Request) { storage := h.svr.GetStorage() gcSafepoint, err := storage.LoadGCSafePoint() @@ -66,14 +66,14 @@ func (h *serviceGCSafepointHandler) GetGCSafePoint(w http.ResponseWriter, r *htt h.rd.JSON(w, http.StatusOK, list) } -// @Tags service_gc_safepoint -// @Summary Delete a service GC safepoint. -// @Param service_id path string true "Service ID" -// @Produce json -// @Success 200 {string} string "Delete service GC safepoint successfully." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /gc/safepoint/{service_id} [delete] -// @Tags rule +// @Tags service_gc_safepoint +// @Summary Delete a service GC safepoint. 
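`PauseOrResumeScheduler` decodes a `map[string]int64` body, so pausing or resuming a scheduler is one small POST to `/schedulers/{name}`. In the sketch below the `delay` key (pause duration in seconds, `0` to resume) is an assumption; the route, the method, and the int64 map come from the handler in this diff.

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func pauseOrResume(name string, delaySeconds int64) error {
	// The handler reads map[string]int64; the "delay" key is assumed for illustration.
	body := []byte(fmt.Sprintf(`{"delay": %d}`, delaySeconds))
	resp, err := http.Post("http://127.0.0.1:2379/pd/api/v1/schedulers/"+name, "application/json", bytes.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	fmt.Println(name, resp.Status)
	return nil
}

func main() {
	_ = pauseOrResume("balance-region-scheduler", 600) // pause for 10 minutes
	_ = pauseOrResume("balance-region-scheduler", 0)   // resume
}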
+// @Param service_id path string true "Service ID" +// @Produce json +// @Success 200 {string} string "Delete service GC safepoint successfully." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /gc/safepoint/{service_id} [delete] +// @Tags rule func (h *serviceGCSafepointHandler) DeleteGCSafePoint(w http.ResponseWriter, r *http.Request) { storage := h.svr.GetStorage() serviceID := mux.Vars(r)["service_id"] diff --git a/server/api/service_middleware.go b/server/api/service_middleware.go index 426399a1d6e..c4489e93fa1 100644 --- a/server/api/service_middleware.go +++ b/server/api/service_middleware.go @@ -45,24 +45,24 @@ func newServiceMiddlewareHandler(svr *server.Server, rd *render.Render) *service } } -// @Tags service_middleware -// @Summary Get Service Middleware config. -// @Produce json -// @Success 200 {object} config.Config -// @Router /service-middleware/config [get] +// @Tags service_middleware +// @Summary Get Service Middleware config. +// @Produce json +// @Success 200 {object} config.Config +// @Router /service-middleware/config [get] func (h *serviceMiddlewareHandler) GetServiceMiddlewareConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetServiceMiddlewareConfig()) } -// @Tags service_middleware -// @Summary Update some service-middleware's config items. -// @Accept json -// @Param body body object false "json params" -// @Produce json -// @Success 200 {string} string "The config is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /service-middleware/config [post] +// @Tags service_middleware +// @Summary Update some service-middleware's config items. +// @Accept json +// @Param body body object false "json params" +// @Produce json +// @Success 200 {string} string "The config is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /service-middleware/config [post] func (h *serviceMiddlewareHandler) SetServiceMiddlewareConfig(w http.ResponseWriter, r *http.Request) { cfg := h.svr.GetServiceMiddlewareConfig() data, err := io.ReadAll(r.Body) @@ -131,14 +131,14 @@ func (h *serviceMiddlewareHandler) updateAudit(config *config.ServiceMiddlewareC return err } -// @Tags service_middleware -// @Summary update ratelimit config -// @Param body body object string "json params" -// @Produce json -// @Success 200 {string} string -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "config item not found" -// @Router /service-middleware/config/rate-limit [POST] +// @Tags service_middleware +// @Summary update ratelimit config +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." 
+// @Failure 500 {string} string "config item not found" +// @Router /service-middleware/config/rate-limit [POST] func (h *serviceMiddlewareHandler) SetRatelimitConfig(w http.ResponseWriter, r *http.Request) { var input map[string]interface{} if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { diff --git a/server/api/stats.go b/server/api/stats.go index d0a14e9a051..9a5983f43ba 100644 --- a/server/api/stats.go +++ b/server/api/stats.go @@ -33,13 +33,13 @@ func newStatsHandler(svr *server.Server, rd *render.Render) *statsHandler { } } -// @Tags stats -// @Summary Get region statistics of a specified range. -// @Param start_key query string true "Start key" -// @Param end_key query string true "End key" -// @Produce json -// @Success 200 {object} statistics.RegionStats -// @Router /stats/region [get] +// @Tags stats +// @Summary Get region statistics of a specified range. +// @Param start_key query string true "Start key" +// @Param end_key query string true "End key" +// @Produce json +// @Success 200 {object} statistics.RegionStats +// @Router /stats/region [get] func (h *statsHandler) GetRegionStatus(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) startKey, endKey := r.URL.Query().Get("start_key"), r.URL.Query().Get("end_key") diff --git a/server/api/status.go b/server/api/status.go index 1f318fbb546..12a25554f46 100644 --- a/server/api/status.go +++ b/server/api/status.go @@ -42,10 +42,10 @@ func newStatusHandler(svr *server.Server, rd *render.Render) *statusHandler { } } -// @Summary Get the build info of PD server. -// @Produce json -// @Success 200 {object} status -// @Router /status [get] +// @Summary Get the build info of PD server. +// @Produce json +// @Success 200 {object} status +// @Router /status [get] func (h *statusHandler) GetPDStatus(w http.ResponseWriter, r *http.Request) { version := status{ BuildTS: versioninfo.PDBuildTS, diff --git a/server/api/store.go b/server/api/store.go index 27aa7b59655..9ee784eaa29 100644 --- a/server/api/store.go +++ b/server/api/store.go @@ -140,15 +140,15 @@ func newStoreHandler(handler *server.Handler, rd *render.Render) *storeHandler { } } -// @Tags store -// @Summary Get a store's information. -// @Param id path integer true "Store Id" -// @Produce json -// @Success 200 {object} StoreInfo -// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The store does not exist." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /store/{id} [get] +// @Tags store +// @Summary Get a store's information. +// @Param id path integer true "Store Id" +// @Produce json +// @Success 200 {object} StoreInfo +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The store does not exist." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /store/{id} [get] func (h *storeHandler) GetStore(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -168,17 +168,17 @@ func (h *storeHandler) GetStore(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, storeInfo) } -// @Tags store -// @Summary Take down a store from the cluster. -// @Param id path integer true "Store Id" -// @Param force query string true "force" Enums(true, false), when force is true it means the store is physically destroyed and can never up gain -// @Produce json -// @Success 200 {string} string "The store is set as Offline." 
-// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The store does not exist." -// @Failure 410 {string} string "The store has already been removed." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /store/{id} [delete] +// @Tags store +// @Summary Take down a store from the cluster. +// @Param id path integer true "Store Id" +// @Param force query string true "force" Enums(true, false) +// @Produce json +// @Success 200 {string} string "The store is set as Offline." +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The store does not exist." +// @Failure 410 {string} string "The store has already been removed." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /store/{id} [delete] func (h *storeHandler) DeleteStore(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -199,16 +199,16 @@ func (h *storeHandler) DeleteStore(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, "The store is set as Offline.") } -// @Tags store -// @Summary Set the store's state. -// @Param id path integer true "Store Id" -// @Param state query string true "state" Enums(Up, Offline) -// @Produce json -// @Success 200 {string} string "The store's state is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The store does not exist." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /store/{id}/state [post] +// @Tags store +// @Summary Set the store's state. +// @Param id path integer true "Store Id" +// @Param state query string true "state" Enums(Up, Offline) +// @Produce json +// @Success 200 {string} string "The store's state is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The store does not exist." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /store/{id}/state [post] func (h *storeHandler) SetStoreState(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -252,15 +252,15 @@ func (h *storeHandler) responseStoreErr(w http.ResponseWriter, err error, storeI } // FIXME: details of input json body params -// @Tags store -// @Summary Set the store's label. -// @Param id path integer true "Store Id" -// @Param body body object true "Labels in json format" -// @Produce json -// @Success 200 {string} string "The store's label is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /store/{id}/label [post] +// @Tags store +// @Summary Set the store's label. +// @Param id path integer true "Store Id" +// @Param body body object true "Labels in json format" +// @Produce json +// @Success 200 {string} string "The store's label is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /store/{id}/label [post] func (h *storeHandler) SetStoreLabel(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -298,15 +298,15 @@ func (h *storeHandler) SetStoreLabel(w http.ResponseWriter, r *http.Request) { } // FIXME: details of input json body params -// @Tags store -// @Summary Set the store's leader/region weight. 
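The store endpoints above distinguish taking a store down (`DELETE /store/{id}`, optionally with `force=true` for a store that is physically destroyed) from an explicit state change (`POST /store/{id}/state?state=Up|Offline`). A hedged sketch against a local PD, using store id 1 as an illustrative value:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1" // assumed local PD address

	// Mark store 1 as Offline through the state endpoint.
	resp, err := http.Post(base+"/store/1/state?state=Offline", "application/json", nil)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("set state:", resp.Status)

	// Take the store down; force=true means it is physically destroyed and will never come up again.
	req, err := http.NewRequest(http.MethodDelete, base+"/store/1?force=true", nil)
	if err != nil {
		panic(err)
	}
	resp, err = http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("delete store:", resp.Status)
}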
-// @Param id path integer true "Store Id" -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "The store's label is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /store/{id}/weight [post] +// @Tags store +// @Summary Set the store's leader/region weight. +// @Param id path integer true "Store Id" +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "The store's label is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /store/{id}/weight [post] func (h *storeHandler) SetStoreWeight(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -351,16 +351,16 @@ func (h *storeHandler) SetStoreWeight(w http.ResponseWriter, r *http.Request) { } // FIXME: details of input json body params -// @Tags store -// @Summary Set the store's limit. -// @Param ttlSecond query integer false "ttl". ttl param is only for BR and lightning now. Don't use it. -// @Param id path integer true "Store Id" -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "The store's label is updated." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /store/{id}/limit [post] +// @Tags store +// @Summary Set the store's limit. +// @Param ttlSecond query integer false "ttl param is only for BR and lightning now. Don't use it." +// @Param id path integer true "Store Id" +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "The store's label is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /store/{id}/limit [post] func (h *storeHandler) SetStoreLimit(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) @@ -435,12 +435,12 @@ func newStoresHandler(handler *server.Handler, rd *render.Render) *storesHandler } } -// @Tags store -// @Summary Remove tombstone records in the cluster. -// @Produce json -// @Success 200 {string} string "Remove tombstone successfully." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /stores/remove-tombstone [delete] +// @Tags store +// @Summary Remove tombstone records in the cluster. +// @Produce json +// @Success 200 {string} string "Remove tombstone successfully." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /stores/remove-tombstone [delete] func (h *storesHandler) RemoveTombStone(w http.ResponseWriter, r *http.Request) { err := getCluster(r).RemoveTombStoneRecords() if err != nil { @@ -452,16 +452,16 @@ func (h *storesHandler) RemoveTombStone(w http.ResponseWriter, r *http.Request) } // FIXME: details of input json body params -// @Tags store -// @Summary Set limit of all stores in the cluster. -// @Accept json -// @Param ttlSecond query integer false "ttl". ttl param is only for BR and lightning now. Don't use it. -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "Set store limit successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." 
-// @Router /stores/limit [post] +// @Tags store +// @Summary Set limit of all stores in the cluster. +// @Accept json +// @Param ttlSecond query integer false "ttl param is only for BR and lightning now. Don't use it." +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Set store limit successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /stores/limit [post] func (h *storesHandler) SetAllStoresLimit(w http.ResponseWriter, r *http.Request) { var input map[string]interface{} if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { @@ -535,13 +535,13 @@ func (h *storesHandler) SetAllStoresLimit(w http.ResponseWriter, r *http.Request } // FIXME: details of output json body -// @Tags store -// @Summary Get limit of all stores in the cluster. -// @Param include_tombstone query bool false "include Tombstone" default(false) -// @Produce json -// @Success 200 {object} string -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /stores/limit [get] +// @Tags store +// @Summary Get limit of all stores in the cluster. +// @Param include_tombstone query bool false "include Tombstone" default(false) +// @Produce json +// @Success 200 {object} string +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /stores/limit [get] func (h *storesHandler) GetAllStoresLimit(w http.ResponseWriter, r *http.Request) { limits := h.GetScheduleConfig().StoreLimit includeTombstone := false @@ -569,15 +569,15 @@ func (h *storesHandler) GetAllStoresLimit(w http.ResponseWriter, r *http.Request h.rd.JSON(w, http.StatusOK, limits) } -// @Tags store -// @Summary Set limit scene in the cluster. -// @Accept json -// @Param body body storelimit.Scene true "Store limit scene" -// @Produce json -// @Success 200 {string} string "Set store limit scene successfully." -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /stores/limit/scene [post] +// @Tags store +// @Summary Set limit scene in the cluster. +// @Accept json +// @Param body body storelimit.Scene true "Store limit scene" +// @Produce json +// @Success 200 {string} string "Set store limit scene successfully." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /stores/limit/scene [post] func (h *storesHandler) SetStoreLimitScene(w http.ResponseWriter, r *http.Request) { typeName := r.URL.Query().Get("type") typeValue, err := parseStoreLimitType(typeName) @@ -593,11 +593,11 @@ func (h *storesHandler) SetStoreLimitScene(w http.ResponseWriter, r *http.Reques h.rd.JSON(w, http.StatusOK, "Set store limit scene successfully.") } -// @Tags store -// @Summary Get limit scene in the cluster. -// @Produce json -// @Success 200 {string} string "Get store limit scene successfully." -// @Router /stores/limit/scene [get] +// @Tags store +// @Summary Get limit scene in the cluster. +// @Produce json +// @Success 200 {string} string "Get store limit scene successfully." 
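`GetAllStoresLimit` honors an `include_tombstone` query flag, and `SetAllStoresLimit` additionally accepts a `ttlSecond` query parameter that the annotation reserves for BR and Lightning. The query parameters come from the annotations above; the `rate` body field in the sketch is an assumption.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1" // assumed local PD address

	// Read the limits of every store, including tombstones.
	resp, err := http.Get(base + "/stores/limit?include_tombstone=true")
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("limits:", resp.Status, string(body))

	// Set a cluster-wide limit (the "rate" field name is assumed for illustration).
	resp, err = http.Post(base+"/stores/limit", "application/json", bytes.NewReader([]byte(`{"rate": 15}`)))
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("set limits:", resp.Status)
}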
+// @Router /stores/limit/scene [get] func (h *storesHandler) GetStoreLimitScene(w http.ResponseWriter, r *http.Request) { typeName := r.URL.Query().Get("type") typeValue, err := parseStoreLimitType(typeName) @@ -618,13 +618,13 @@ type Progress struct { LeftSeconds float64 `json:"left_seconds"` } -// @Tags stores -// @Summary Get store progress in the cluster. -// @Produce json -// @Success 200 {object} Progress -// @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /stores/progress [get] +// @Tags stores +// @Summary Get store progress in the cluster. +// @Produce json +// @Success 200 {object} Progress +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /stores/progress [get] func (h *storesHandler) GetStoresProgress(w http.ResponseWriter, r *http.Request) { if v := r.URL.Query().Get("id"); v != "" { storeID, err := strconv.ParseUint(v, 10, 64) @@ -668,13 +668,13 @@ func (h *storesHandler) GetStoresProgress(w http.ResponseWriter, r *http.Request h.rd.JSON(w, http.StatusBadRequest, "need query parameters") } -// @Tags store -// @Summary Get stores in the cluster. -// @Param state query array true "Specify accepted store states." -// @Produce json -// @Success 200 {object} StoresInfo -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /stores [get] +// @Tags store +// @Summary Get stores in the cluster. +// @Param state query array true "Specify accepted store states." +// @Produce json +// @Success 200 {object} StoresInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /stores [get] func (h *storesHandler) GetStores(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) stores := rc.GetMetaStores() diff --git a/server/api/trend.go b/server/api/trend.go index 537167705fc..e100b31cbd9 100644 --- a/server/api/trend.go +++ b/server/api/trend.go @@ -80,14 +80,14 @@ func newTrendHandler(s *server.Server, rd *render.Render) *trendHandler { } } -// @Tags trend -// @Summary Get the growth and changes of data in the most recent period of time. -// @Param from query integer false "From Unix timestamp" -// @Produce json -// @Success 200 {object} Trend -// @Failure 400 {string} string "The request is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /trend [get] +// @Tags trend +// @Summary Get the growth and changes of data in the most recent period of time. +// @Param from query integer false "From Unix timestamp" +// @Produce json +// @Success 200 {object} Trend +// @Failure 400 {string} string "The request is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /trend [get] func (h *trendHandler) GetTrend(w http.ResponseWriter, r *http.Request) { var from time.Time if fromStr := r.URL.Query()["from"]; len(fromStr) > 0 { diff --git a/server/api/tso.go b/server/api/tso.go index df7db500899..cbe30887b73 100644 --- a/server/api/tso.go +++ b/server/api/tso.go @@ -35,17 +35,17 @@ func newTSOHandler(svr *server.Server, rd *render.Render) *tsoHandler { } } -// @Tags tso -// @Summary Transfer Local TSO Allocator -// @Accept json -// @Param name path string true "PD server name" -// @Param body body object true "json params" -// @Produce json -// @Success 200 {string} string "The transfer command is submitted." 
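`GetStoresProgress` looks up the `id` query parameter itself (visible in the handler above), and `GetTrend` accepts an optional `from` Unix timestamp. A small sketch, assuming a local PD and store id 1:

package main

import (
	"fmt"
	"io"
	"net/http"
	"time"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1" // assumed local PD address

	// Progress of the current action for store 1.
	resp, err := http.Get(base + "/stores/progress?id=1")
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("progress:", resp.Status, string(body))

	// Data growth and changes over the last hour.
	from := time.Now().Add(-time.Hour).Unix()
	resp, err = http.Get(fmt.Sprintf("%s/trend?from=%d", base, from))
	if err != nil {
		panic(err)
	}
	body, _ = io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("trend:", resp.Status, string(body))
}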
-// @Failure 400 {string} string "The input is invalid." -// @Failure 404 {string} string "The member does not exist." -// @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /tso/allocator/transfer/{name} [post] +// @Tags tso +// @Summary Transfer Local TSO Allocator +// @Accept json +// @Param name path string true "PD server name" +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "The transfer command is submitted." +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The member does not exist." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /tso/allocator/transfer/{name} [post] func (h *tsoHandler) TransferLocalTSOAllocator(w http.ResponseWriter, r *http.Request) { members, membersErr := getMembers(h.svr) if membersErr != nil { diff --git a/server/api/unsafe_operation.go b/server/api/unsafe_operation.go index 83912c120a0..c45771619b0 100644 --- a/server/api/unsafe_operation.go +++ b/server/api/unsafe_operation.go @@ -35,15 +35,15 @@ func newUnsafeOperationHandler(svr *server.Server, rd *render.Render) *unsafeOpe } } -// @Tags unsafe -// @Summary Remove failed stores unsafely. -// @Accept json -// @Param body body object true "json params" -// @Produce json +// @Tags unsafe +// @Summary Remove failed stores unsafely. +// @Accept json +// @Param body body object true "json params" +// @Produce json // Success 200 {string} string "Request has been accepted." // Failure 400 {string} string "The input is invalid." // Failure 500 {string} string "PD server failed to proceed the request." -// @Router /admin/unsafe/remove-failed-stores [POST] +// @Router /admin/unsafe/remove-failed-stores [POST] func (h *unsafeOperationHandler) RemoveFailedStores(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) var input map[string]interface{} @@ -72,11 +72,11 @@ func (h *unsafeOperationHandler) RemoveFailedStores(w http.ResponseWriter, r *ht h.rd.JSON(w, http.StatusOK, "Request has been accepted.") } -// @Tags unsafe -// @Summary Show the current status of failed stores removal. -// @Produce json +// @Tags unsafe +// @Summary Show the current status of failed stores removal. +// @Produce json // Success 200 {object} []StageOutput -// @Router /admin/unsafe/remove-failed-stores/show [GET] +// @Router /admin/unsafe/remove-failed-stores/show [GET] func (h *unsafeOperationHandler) GetFailedStoresRemovalStatus(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) h.rd.JSON(w, http.StatusOK, rc.GetUnsafeRecoveryController().Show()) diff --git a/server/api/version.go b/server/api/version.go index 2a1d7e36fbd..38e5a12d8c2 100644 --- a/server/api/version.go +++ b/server/api/version.go @@ -39,10 +39,10 @@ func newVersionHandler(rd *render.Render) *versionHandler { } } -// @Summary Get the version of PD server. -// @Produce json -// @Success 200 {object} version -// @Router /version [get] +// @Summary Get the version of PD server. 
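`RemoveFailedStores` kicks off unsafe recovery for stores that are permanently lost, and the companion `show` endpoint reports the removal status. The `stores` body field below is an assumption drawn from the summary rather than from this diff; the routes and the uppercase `POST`/`GET` verbs are as annotated.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1" // assumed local PD address

	// Declare stores 4 and 5 as permanently failed (the "stores" field name is assumed).
	payload := []byte(`{"stores": [4, 5]}`)
	resp, err := http.Post(base+"/admin/unsafe/remove-failed-stores", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("remove-failed-stores:", resp.Status)

	// Check the current status of the removal.
	resp, err = http.Get(base + "/admin/unsafe/remove-failed-stores/show")
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println("show:", resp.Status, string(body))
}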
+// @Produce json +// @Success 200 {object} version +// @Router /version [get] func (h *versionHandler) GetVersion(w http.ResponseWriter, r *http.Request) { version := &version{ Version: versioninfo.PDReleaseVersion, diff --git a/tests/client/go.sum b/tests/client/go.sum index b9d76d704a8..ad49f3bc358 100644 --- a/tests/client/go.sum +++ b/tests/client/go.sum @@ -79,6 +79,7 @@ github.com/corona10/goimagehash v1.0.2/go.mod h1:/l9umBhvcHQXVtQO1V6Gp1yD20STawk github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -133,20 +134,24 @@ github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/go-openapi/jsonpointer v0.17.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M= github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= -github.com/go-openapi/jsonpointer v0.19.3 h1:gihV7YNZK1iK6Tgwwsxo2rJbD1GTbdm72325Bq8FI3w= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonreference v0.17.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= github.com/go-openapi/jsonreference v0.19.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc= -github.com/go-openapi/jsonreference v0.19.3 h1:5cxNfTy0UVC3X8JL5ymxzyoUZmo8iZb+jeTWn7tUa8o= github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL98+wF9xc8zWvFonSJ8= +github.com/go-openapi/jsonreference v0.19.6 h1:UBIxjkht+AWIgYzCDSv2GN+E/togfwXUJFRTWhl2Jjs= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= github.com/go-openapi/spec v0.19.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI= -github.com/go-openapi/spec v0.19.4 h1:ixzUSnHTd6hCemgtAJgluaTSGYpLNpJY4mA2DIkdOAo= github.com/go-openapi/spec v0.19.4/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= +github.com/go-openapi/spec v0.20.4 h1:O8hJrt0UMnhHcluhIdUgCLRWyM2x7QkBXRvOs7m+O1M= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= 
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= @@ -287,6 +292,8 @@ github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9q github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/joomcode/errorx v1.0.1 h1:CalpDWz14ZHd68fIqluJasJosAewpz2TFaJALrUxjrk= github.com/joomcode/errorx v1.0.1/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -308,13 +315,13 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxv github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -326,8 +333,9 @@ github.com/lucasb-eyer/go-colorful v1.0.3/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e h1:hB2xlXdHp/pmPZq0y3QnmWAArdw9PqbmotexnWx/FU8= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod 
h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= @@ -371,6 +379,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/oleiade/reflections v1.0.1 h1:D1XO3LVEYroYskEsoSiGItp9RUxG6jWnCVvrqH0HHQM= github.com/oleiade/reflections v1.0.1/go.mod h1:rdFxbxq4QXVZWj0F+e9jqjDkc7dbp97vkRixKo2JR60= @@ -515,8 +525,9 @@ github.com/swaggo/http-swagger v0.0.0-20200308142732-58ac5e232fba h1:lUPlXKqgbqT github.com/swaggo/http-swagger v0.0.0-20200308142732-58ac5e232fba/go.mod h1:O1lAbCgAAX/KZ80LM/OXwtWFI/5TvZlwxSg8Cq08PV0= github.com/swaggo/swag v1.5.1/go.mod h1:1Bl9F/ZBpVWh22nY0zmYyASPO1lI/zIwRDrpZU+tv8Y= github.com/swaggo/swag v1.6.3/go.mod h1:wcc83tB4Mb2aNiL/HP4MFeQdpHUrca+Rp/DRNgWAUio= -github.com/swaggo/swag v1.6.6-0.20200529100950-7c765ddd0476 h1:UjnSXdNPIG+5FJ6xLQODEdk7gSnJlMldu3sPAxxCO+4= github.com/swaggo/swag v1.6.6-0.20200529100950-7c765ddd0476/go.mod h1:xDhTyuFIujYiN3DKWC/H/83xcfHp+UE/IzWWampG7Zc= +github.com/swaggo/swag v1.8.3 h1:3pZSSCQ//gAH88lfmxM3Cd1+JCsxV8Md6f36b9hrZ5s= +github.com/swaggo/swag v1.8.3/go.mod h1:jMLeXOOmYyjk8PvHTsXBdrubsNd9gUJTTCzL5iBnseg= github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965 h1:1oFLiOyVl+W7bnBzGhf7BbIv9loSFQcieWWYIjLqcAw= github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965/go.mod h1:9OrXJhf154huy1nPWmuSrkgjPUtUNhA+Zmy+6AESzuA= github.com/thoas/go-funk v0.8.0 h1:JP9tKSvnpFVclYgDM0Is7FD9M4fhPvqA0s0BsXmzSRQ= @@ -549,6 +560,7 @@ github.com/unrolled/render v1.0.1 h1:VDDnQQVfBMsOsp3VaCJszSO0nkBIVEYoPWeRThk9spY github.com/unrolled/render v1.0.1/go.mod h1:gN9T0NhL4Bfbwu8ann7Ry/TGHYfosul+J0obPf6NBdM= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= +github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/negroni v0.3.0 h1:PaXOb61mWeZJxc1Ji2xJjpVg9QfPo0rrB+lHyBxGNSU= github.com/urfave/negroni v0.3.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= @@ -561,6 +573,7 @@ github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1: github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= @@ -614,8 +627,9 @@ golang.org/x/crypto 
v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200204104054-c9f3fb736b72/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/image v0.0.0-20200119044424-58c23975cae1 h1:5h3ngYt7+vXCDZCup/HkCQgW5XwmSvR/nA2JmJ0RErg= golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -629,8 +643,9 @@ golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKG golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -653,8 +668,12 @@ golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA= +golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421 
h1:Wo7BWFiOk0QRFMLYMqJGFMd9CgUAcGx7V+qEg/h5IBI= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -700,16 +719,23 @@ golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210217105451-b926d437f341/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac h1:oN6lz7iLW/YC7un8pq+9bOLyXrprv2+DKfkJY+2LJJw= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= @@ -743,8 +769,9 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210112230658-8b4aab62c064/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -779,8 +806,9 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= @@ -804,6 +832,7 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.0.6 h1:mA0XRPjIKi4bkE9nv+NKs6qj6QWOchqUSdWOcpd3x1E= From 3d53b06e1664ae176f33e4ee7636f077b8447da8 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 21 Jun 2022 13:30:36 +0800 Subject: [PATCH 61/82] Makefile: increase the golangci-lint timeout (#5192) close tikv/pd#5078 Increase the golangci-lint timeout. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- .golangci.yml | 2 +- Makefile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 778b8d2b047..bfb9954fff9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,5 +1,5 @@ run: - deadline: 120s + timeout: 3m linters: enable: - misspell diff --git a/Makefile b/Makefile index 7ac146e39eb..e76488e2b2d 100644 --- a/Makefile +++ b/Makefile @@ -148,7 +148,7 @@ static: install-tools @ echo "gofmt ..." @ gofmt -s -l -d $(PACKAGE_DIRECTORIES) 2>&1 | awk '{ print } END { if (NR > 0) { exit 1 } }' @ echo "golangci-lint ..." - @ golangci-lint run $(PACKAGE_DIRECTORIES) + @ golangci-lint run --verbose $(PACKAGE_DIRECTORIES) @ echo "revive ..." 
@ revive -formatter friendly -config revive.toml $(PACKAGES) From 3d1e6c5336fecc6ae299b14aa0f805a01e557b26 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 21 Jun 2022 15:06:37 +0800 Subject: [PATCH 62/82] bucket: migrate test framework to testify (#5195) ref tikv/pd#4813 Signed-off-by: lhy1024 Co-authored-by: Ti Chi Robot --- .../buckets/hot_bucket_cache_test.go | 83 +++++++++---------- .../buckets/hot_bucket_task_test.go | 58 ++++++------- 2 files changed, 69 insertions(+), 72 deletions(-) diff --git a/server/statistics/buckets/hot_bucket_cache_test.go b/server/statistics/buckets/hot_bucket_cache_test.go index 7c8cc85e99c..a55a505957b 100644 --- a/server/statistics/buckets/hot_bucket_cache_test.go +++ b/server/statistics/buckets/hot_bucket_cache_test.go @@ -18,19 +18,12 @@ import ( "context" "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testHotBucketCache{}) - -type testHotBucketCache struct{} - -func (t *testHotBucketCache) TestPutItem(c *C) { +func TestPutItem(t *testing.T) { + re := require.New(t) cache := NewBucketsCache(context.Background()) testdata := []struct { regionID uint64 @@ -90,17 +83,18 @@ func (t *testHotBucketCache) TestPutItem(c *C) { }} for _, v := range testdata { bucket := convertToBucketTreeItem(newTestBuckets(v.regionID, v.version, v.keys, 10)) - c.Assert(bucket.GetStartKey(), BytesEquals, v.keys[0]) - c.Assert(bucket.GetEndKey(), BytesEquals, v.keys[len(v.keys)-1]) + re.Equal(v.keys[0], bucket.GetStartKey()) + re.Equal(v.keys[len(v.keys)-1], bucket.GetEndKey()) cache.putItem(bucket, cache.getBucketsByKeyRange(bucket.GetStartKey(), bucket.GetEndKey())) - c.Assert(cache.bucketsOfRegion, HasLen, v.regionCount) - c.Assert(cache.tree.Len(), Equals, v.treeLen) - c.Assert(cache.bucketsOfRegion[v.regionID], NotNil) - c.Assert(cache.getBucketsByKeyRange([]byte("10"), nil), NotNil) + re.Len(cache.bucketsOfRegion, v.regionCount) + re.Equal(v.treeLen, cache.tree.Len()) + re.NotNil(cache.bucketsOfRegion[v.regionID]) + re.NotNil(cache.getBucketsByKeyRange([]byte("10"), nil)) } } -func (t *testHotBucketCache) TestConvertToBucketTreeStat(c *C) { +func TestConvertToBucketTreeStat(t *testing.T) { + re := require.New(t) buckets := &metapb.Buckets{ RegionId: 1, Version: 0, @@ -116,14 +110,15 @@ func (t *testHotBucketCache) TestConvertToBucketTreeStat(c *C) { PeriodInMs: 1000, } item := convertToBucketTreeItem(buckets) - c.Assert(item.startKey, BytesEquals, []byte{'1'}) - c.Assert(item.endKey, BytesEquals, []byte{'5'}) - c.Assert(item.regionID, Equals, uint64(1)) - c.Assert(item.version, Equals, uint64(0)) - c.Assert(item.stats, HasLen, 4) + re.Equal([]byte{'1'}, item.startKey) + re.Equal([]byte{'5'}, item.endKey) + re.Equal(uint64(1), item.regionID) + re.Equal(uint64(0), item.version) + re.Len(item.stats, 4) } -func (t *testHotBucketCache) TestGetBucketsByKeyRange(c *C) { +func TestGetBucketsByKeyRange(t *testing.T) { + re := require.New(t) cache := NewBucketsCache(context.Background()) bucket1 := newTestBuckets(1, 1, [][]byte{[]byte(""), []byte("015")}, 0) bucket2 := newTestBuckets(2, 1, [][]byte{[]byte("015"), []byte("020")}, 0) @@ -131,15 +126,16 @@ func (t *testHotBucketCache) TestGetBucketsByKeyRange(c *C) { cache.putItem(cache.checkBucketsFlow(bucket1)) cache.putItem(cache.checkBucketsFlow(bucket2)) cache.putItem(cache.checkBucketsFlow(bucket3)) - c.Assert(cache.getBucketsByKeyRange([]byte(""), []byte("100")), HasLen, 3) - 
c.Assert(cache.getBucketsByKeyRange([]byte("030"), []byte("100")), HasLen, 1) - c.Assert(cache.getBucketsByKeyRange([]byte("010"), []byte("030")), HasLen, 3) - c.Assert(cache.getBucketsByKeyRange([]byte("015"), []byte("020")), HasLen, 1) - c.Assert(cache.getBucketsByKeyRange([]byte("001"), []byte("")), HasLen, 3) - c.Assert(cache.bucketsOfRegion, HasLen, 3) + re.Len(cache.getBucketsByKeyRange([]byte(""), []byte("100")), 3) + re.Len(cache.getBucketsByKeyRange([]byte("030"), []byte("100")), 1) + re.Len(cache.getBucketsByKeyRange([]byte("010"), []byte("030")), 3) + re.Len(cache.getBucketsByKeyRange([]byte("015"), []byte("020")), 1) + re.Len(cache.getBucketsByKeyRange([]byte("001"), []byte("")), 3) + re.Len(cache.bucketsOfRegion, 3) } -func (t *testHotBucketCache) TestInherit(c *C) { +func TestInherit(t *testing.T) { + re := require.New(t) originBucketItem := convertToBucketTreeItem(newTestBuckets(1, 1, [][]byte{[]byte(""), []byte("20"), []byte("50"), []byte("")}, 0)) originBucketItem.stats[0].HotDegree = 3 originBucketItem.stats[1].HotDegree = 2 @@ -173,15 +169,15 @@ func (t *testHotBucketCache) TestInherit(c *C) { for _, v := range testdata { buckets := convertToBucketTreeItem(v.buckets) buckets.inherit([]*BucketTreeItem{originBucketItem}) - c.Assert(buckets.stats, HasLen, len(v.expect)) + re.Len(buckets.stats, len(v.expect)) for k, v := range v.expect { - c.Assert(buckets.stats[k].HotDegree, Equals, v) + re.Equal(v, buckets.stats[k].HotDegree) } } } -func (t *testHotBucketCache) TestBucketTreeItemClone(c *C) { - // bucket range: [010,020][020,100] +func TestBucketTreeItemClone(t *testing.T) { + re := require.New(t) origin := convertToBucketTreeItem(newTestBuckets(1, 1, [][]byte{[]byte("010"), []byte("020"), []byte("100")}, uint64(0))) testdata := []struct { startKey []byte @@ -221,30 +217,31 @@ func (t *testHotBucketCache) TestBucketTreeItemClone(c *C) { }} for _, v := range testdata { copy := origin.cloneBucketItemByRange(v.startKey, v.endKey) - c.Assert(copy.startKey, BytesEquals, v.startKey) - c.Assert(copy.endKey, BytesEquals, v.endKey) - c.Assert(copy.stats, HasLen, v.count) + re.Equal(v.startKey, copy.startKey) + re.Equal(v.endKey, copy.endKey) + re.Len(copy.stats, v.count) if v.count > 0 && v.strict { - c.Assert(copy.stats[0].StartKey, BytesEquals, v.startKey) - c.Assert(copy.stats[len(copy.stats)-1].EndKey, BytesEquals, v.endKey) + re.Equal(v.startKey, copy.stats[0].StartKey) + re.Equal(v.endKey, copy.stats[len(copy.stats)-1].EndKey) } } } -func (t *testHotBucketCache) TestCalculateHotDegree(c *C) { +func TestCalculateHotDegree(t *testing.T) { + re := require.New(t) origin := convertToBucketTreeItem(newTestBuckets(1, 1, [][]byte{[]byte("010"), []byte("100")}, uint64(0))) origin.calculateHotDegree() - c.Assert(origin.stats[0].HotDegree, Equals, -1) + re.Equal(-1, origin.stats[0].HotDegree) // case1: the dimension of read will be hot origin.stats[0].Loads = []uint64{minHotThresholds[0] + 1, minHotThresholds[1] + 1, 0, 0, 0, 0} origin.calculateHotDegree() - c.Assert(origin.stats[0].HotDegree, Equals, 0) + re.Equal(0, origin.stats[0].HotDegree) // case1: the dimension of write will be hot origin.stats[0].Loads = []uint64{0, 0, 0, minHotThresholds[3] + 1, minHotThresholds[4] + 1, 0} origin.calculateHotDegree() - c.Assert(origin.stats[0].HotDegree, Equals, 1) + re.Equal(1, origin.stats[0].HotDegree) } func newTestBuckets(regionID uint64, version uint64, keys [][]byte, flow uint64) *metapb.Buckets { diff --git a/server/statistics/buckets/hot_bucket_task_test.go 
b/server/statistics/buckets/hot_bucket_task_test.go index a5fe0d7ad8c..49f60116c9d 100644 --- a/server/statistics/buckets/hot_bucket_task_test.go +++ b/server/statistics/buckets/hot_bucket_task_test.go @@ -18,24 +18,21 @@ import ( "context" "math" "strconv" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testHotBucketTaskCache{}) - -type testHotBucketTaskCache struct { -} - func getAllBucketStats(ctx context.Context, hotCache *HotBucketCache) map[uint64][]*BucketStat { task := NewCollectBucketStatsTask(minHotDegree) hotCache.CheckAsync(task) return task.WaitRet(ctx) } -func (s *testHotBucketTaskCache) TestColdHot(c *C) { +func TestColdHot(t *testing.T) { + re := require.New(t) ctx, cancelFn := context.WithCancel(context.Background()) defer cancelFn() hotCache := NewBucketsCache(ctx) @@ -52,60 +49,63 @@ func (s *testHotBucketTaskCache) TestColdHot(c *C) { for _, v := range testdata { for i := 0; i < 20; i++ { task := NewCheckPeerTask(v.buckets) - c.Assert(hotCache.CheckAsync(task), IsTrue) + re.True(hotCache.CheckAsync(task)) hotBuckets := getAllBucketStats(ctx, hotCache) time.Sleep(time.Millisecond * 10) item := hotBuckets[v.buckets.RegionId] - c.Assert(item, NotNil) + re.NotNil(item) if v.isHot { - c.Assert(item[0].HotDegree, Equals, i+1) + re.Equal(i+1, item[0].HotDegree) } else { - c.Assert(item[0].HotDegree, Equals, -i-1) + re.Equal(-i-1, item[0].HotDegree) } } } } -func (s *testHotBucketTaskCache) TestCheckBucketsTask(c *C) { +func TestCheckBucketsTask(t *testing.T) { + re := require.New(t) ctx, cancelFn := context.WithCancel(context.Background()) defer cancelFn() hotCache := NewBucketsCache(ctx) // case1: add bucket successfully buckets := newTestBuckets(1, 1, [][]byte{[]byte("10"), []byte("20"), []byte("30")}, 0) task := NewCheckPeerTask(buckets) - c.Assert(hotCache.CheckAsync(task), IsTrue) + re.True(hotCache.CheckAsync(task)) time.Sleep(time.Millisecond * 10) hotBuckets := getAllBucketStats(ctx, hotCache) - c.Assert(hotBuckets, HasLen, 1) + re.Len(hotBuckets, 1) item := hotBuckets[uint64(1)] - c.Assert(item, NotNil) - c.Assert(item, HasLen, 2) - c.Assert(item[0].HotDegree, Equals, -1) - c.Assert(item[1].HotDegree, Equals, -1) + re.NotNil(item) + + re.Len(item, 2) + re.Equal(-1, item[0].HotDegree) + re.Equal(-1, item[1].HotDegree) // case2: add bucket successful and the hot degree should inherit from the old one. buckets = newTestBuckets(2, 1, [][]byte{[]byte("20"), []byte("30")}, 0) task = NewCheckPeerTask(buckets) - c.Assert(hotCache.CheckAsync(task), IsTrue) + re.True(hotCache.CheckAsync(task)) hotBuckets = getAllBucketStats(ctx, hotCache) time.Sleep(time.Millisecond * 10) item = hotBuckets[uint64(2)] - c.Assert(item, HasLen, 1) - c.Assert(item[0].HotDegree, Equals, -2) + re.Len(item, 1) + re.Equal(-2, item[0].HotDegree) // case3:add bucket successful and the hot degree should inherit from the old one. 
buckets = newTestBuckets(1, 1, [][]byte{[]byte("10"), []byte("20")}, 0) task = NewCheckPeerTask(buckets) - c.Assert(hotCache.CheckAsync(task), IsTrue) + re.True(hotCache.CheckAsync(task)) hotBuckets = getAllBucketStats(ctx, hotCache) time.Sleep(time.Millisecond * 10) item = hotBuckets[uint64(1)] - c.Assert(item, HasLen, 1) - c.Assert(item[0].HotDegree, Equals, -2) + re.Len(item, 1) + re.Equal(-2, item[0].HotDegree) } -func (s *testHotBucketTaskCache) TestCollectBucketStatsTask(c *C) { +func TestCollectBucketStatsTask(t *testing.T) { + re := require.New(t) ctx, cancelFn := context.WithCancel(context.Background()) defer cancelFn() hotCache := NewBucketsCache(ctx) @@ -117,11 +117,11 @@ func (s *testHotBucketTaskCache) TestCollectBucketStatsTask(c *C) { } time.Sleep(time.Millisecond * 10) task := NewCollectBucketStatsTask(-100) - c.Assert(hotCache.CheckAsync(task), IsTrue) + re.True(hotCache.CheckAsync(task)) stats := task.WaitRet(ctx) - c.Assert(stats, HasLen, 10) + re.Len(stats, 10) task = NewCollectBucketStatsTask(1) - c.Assert(hotCache.CheckAsync(task), IsTrue) + re.True(hotCache.CheckAsync(task)) stats = task.WaitRet(ctx) - c.Assert(stats, HasLen, 0) + re.Len(stats, 0) } From b100049b4548ac84c8da2566cb65b1f204c59a10 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 21 Jun 2022 21:14:37 +0800 Subject: [PATCH 63/82] *: replace `testcase` with `testCase` (#5206) ref tikv/pd#4399 Signed-off-by: lhy1024 --- server/core/region_test.go | 24 +++++++-------- server/schedule/checker/rule_checker_test.go | 12 ++++---- .../placement/region_rule_cache_test.go | 14 ++++----- server/schedule/placement/rule_test.go | 10 +++---- server/schedule/region_scatterer_test.go | 30 +++++++++---------- server/schedulers/hot_region_test.go | 24 +++++++-------- server/schedulers/scheduler_test.go | 10 +++---- 7 files changed, 62 insertions(+), 62 deletions(-) diff --git a/server/core/region_test.go b/server/core/region_test.go index edf55c8ac7b..a86a8490bd8 100644 --- a/server/core/region_test.go +++ b/server/core/region_test.go @@ -67,7 +67,7 @@ func TestNeedMerge(t *testing.T) { func TestSortedEqual(t *testing.T) { re := require.New(t) - testcases := []struct { + testCases := []struct { idsA []int idsB []int isEqual bool @@ -146,7 +146,7 @@ func TestSortedEqual(t *testing.T) { return peers } // test NewRegionInfo - for _, test := range testcases { + for _, test := range testCases { regionA := NewRegionInfo(&metapb.Region{Id: 100, Peers: pickPeers(test.idsA)}, nil) regionB := NewRegionInfo(&metapb.Region{Id: 100, Peers: pickPeers(test.idsB)}, nil) re.Equal(test.isEqual, SortedPeersEqual(regionA.GetVoters(), regionB.GetVoters())) @@ -154,7 +154,7 @@ func TestSortedEqual(t *testing.T) { } // test RegionFromHeartbeat - for _, test := range testcases { + for _, test := range testCases { regionA := RegionFromHeartbeat(&pdpb.RegionHeartbeatRequest{ Region: &metapb.Region{Id: 100, Peers: pickPeers(test.idsA)}, DownPeers: pickPeerStats(test.idsA), @@ -173,7 +173,7 @@ func TestSortedEqual(t *testing.T) { // test Clone region := NewRegionInfo(meta, meta.Peers[0]) - for _, test := range testcases { + for _, test := range testCases { downPeersA := pickPeerStats(test.idsA) downPeersB := pickPeerStats(test.idsB) pendingPeersA := pickPeers(test.idsA) @@ -190,7 +190,7 @@ func TestInherit(t *testing.T) { re := require.New(t) // size in MB // case for approximateSize - testcases := []struct { + testCases := []struct { originExists bool originSize uint64 size uint64 @@ -202,7 +202,7 @@ func TestInherit(t *testing.T) { {true, 1, 2, 
2}, {true, 2, 0, 2}, } - for _, test := range testcases { + for _, test := range testCases { var origin *RegionInfo if test.originExists { origin = NewRegionInfo(&metapb.Region{Id: 100}, nil) @@ -240,7 +240,7 @@ func TestInherit(t *testing.T) { func TestRegionRoundingFlow(t *testing.T) { re := require.New(t) - testcases := []struct { + testCases := []struct { flow uint64 digit int expect uint64 @@ -254,7 +254,7 @@ func TestRegionRoundingFlow(t *testing.T) { {252623, math.MaxInt64, 0}, {252623, math.MinInt64, 252623}, } - for _, test := range testcases { + for _, test := range testCases { r := NewRegionInfo(&metapb.Region{Id: 100}, nil, WithFlowRoundByDigit(test.digit)) r.readBytes = test.flow r.writtenBytes = test.flow @@ -264,7 +264,7 @@ func TestRegionRoundingFlow(t *testing.T) { func TestRegionWriteRate(t *testing.T) { re := require.New(t) - testcases := []struct { + testCases := []struct { bytes uint64 keys uint64 interval uint64 @@ -280,7 +280,7 @@ func TestRegionWriteRate(t *testing.T) { {0, 0, 500, 0, 0}, {10, 3, 500, 0, 0}, } - for _, test := range testcases { + for _, test := range testCases { r := NewRegionInfo(&metapb.Region{Id: 100}, nil, SetWrittenBytes(test.bytes), SetWrittenKeys(test.keys), SetReportInterval(test.interval)) bytesRate, keysRate := r.GetWriteRate() re.Equal(test.expectBytesRate, bytesRate) @@ -304,7 +304,7 @@ func TestNeedSync(t *testing.T) { } region := NewRegionInfo(meta, meta.Peers[0]) - testcases := []struct { + testCases := []struct { optionsA []RegionCreateOption optionsB []RegionCreateOption needSync bool @@ -357,7 +357,7 @@ func TestNeedSync(t *testing.T) { }, } - for _, test := range testcases { + for _, test := range testCases { regionA := region.Clone(test.optionsA...) regionB := region.Clone(test.optionsB...) 
_, _, _, needSync := RegionGuide(regionA, regionB) diff --git a/server/schedule/checker/rule_checker_test.go b/server/schedule/checker/rule_checker_test.go index ea9a369348e..3eb01d3655a 100644 --- a/server/schedule/checker/rule_checker_test.go +++ b/server/schedule/checker/rule_checker_test.go @@ -605,7 +605,7 @@ func (suite *ruleCheckerTestSuite) TestRuleCache() { region = region.Clone(core.WithIncConfVer(), core.WithIncVersion()) suite.Nil(suite.rc.Check(region)) - testcases := []struct { + testCases := []struct { name string region *core.RegionInfo stillCached bool @@ -643,15 +643,15 @@ func (suite *ruleCheckerTestSuite) TestRuleCache() { stillCached: false, }, } - for _, testcase := range testcases { - suite.T().Log(testcase.name) - if testcase.stillCached { + for _, testCase := range testCases { + suite.T().Log(testCase.name) + if testCase.stillCached { suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/checker/assertShouldCache", "return(true)")) - suite.rc.Check(testcase.region) + suite.rc.Check(testCase.region) suite.NoError(failpoint.Disable("github.com/tikv/pd/server/schedule/checker/assertShouldCache")) } else { suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/checker/assertShouldNotCache", "return(true)")) - suite.rc.Check(testcase.region) + suite.rc.Check(testCase.region) suite.NoError(failpoint.Disable("github.com/tikv/pd/server/schedule/checker/assertShouldNotCache")) } } diff --git a/server/schedule/placement/region_rule_cache_test.go b/server/schedule/placement/region_rule_cache_test.go index f38d13eba87..7c578a11e34 100644 --- a/server/schedule/placement/region_rule_cache_test.go +++ b/server/schedule/placement/region_rule_cache_test.go @@ -30,7 +30,7 @@ func TestRegionRuleFitCache(t *testing.T) { originRules := addExtraRules(0) originStores := mockStores(3) cache := mockRegionRuleFitCache(originRegion, originRules, originStores) - testcases := []struct { + testCases := []struct { name string region *core.RegionInfo rules []*Rule @@ -175,13 +175,13 @@ func TestRegionRuleFitCache(t *testing.T) { unchanged: false, }, } - for _, testcase := range testcases { - t.Log(testcase.name) - re.Equal(testcase.unchanged, cache.IsUnchanged(testcase.region, testcase.rules, mockStores(3))) + for _, testCase := range testCases { + t.Log(testCase.name) + re.Equal(testCase.unchanged, cache.IsUnchanged(testCase.region, testCase.rules, mockStores(3))) } - for _, testcase := range testcases { - t.Log(testcase.name) - re.Equal(false, cache.IsUnchanged(testcase.region, testcase.rules, mockStoresNoHeartbeat(3))) + for _, testCase := range testCases { + t.Log(testCase.name) + re.Equal(false, cache.IsUnchanged(testCase.region, testCase.rules, mockStoresNoHeartbeat(3))) } // Invalid Input4 re.False(cache.IsUnchanged(mockRegion(3, 0), addExtraRules(0), nil)) diff --git a/server/schedule/placement/rule_test.go b/server/schedule/placement/rule_test.go index 94f623ef93d..ba2d1bf50e4 100644 --- a/server/schedule/placement/rule_test.go +++ b/server/schedule/placement/rule_test.go @@ -132,7 +132,7 @@ func TestBuildRuleList(t *testing.T) { Count: 5, } - testcases := []struct { + testCases := []struct { name string rules map[[2]string]*Rule expect ruleList @@ -178,11 +178,11 @@ func TestBuildRuleList(t *testing.T) { }, } - for _, testcase := range testcases { - t.Log(testcase.name) - config := &ruleConfig{rules: testcase.rules} + for _, testCase := range testCases { + t.Log(testCase.name) + config := &ruleConfig{rules: testCase.rules} result, err := buildRuleList(config) 
re.NoError(err) - re.Equal(testcase.expect.ranges, result.ranges) + re.Equal(testCase.expect.ranges, result.ranges) } } diff --git a/server/schedule/region_scatterer_test.go b/server/schedule/region_scatterer_test.go index 373681f9f11..bfbd99f1e4e 100644 --- a/server/schedule/region_scatterer_test.go +++ b/server/schedule/region_scatterer_test.go @@ -255,7 +255,7 @@ func (s *testScatterRegionSuite) TestScatterCheck(c *C) { for i := uint64(1); i <= 5; i++ { tc.AddRegionStore(i, 0) } - testcases := []struct { + testCases := []struct { name string checkRegion *core.RegionInfo needFix bool @@ -276,11 +276,11 @@ func (s *testScatterRegionSuite) TestScatterCheck(c *C) { needFix: true, }, } - for _, testcase := range testcases { - c.Logf(testcase.name) + for _, testCase := range testCases { + c.Logf(testCase.name) scatterer := NewRegionScatterer(ctx, tc) - _, err := scatterer.Scatter(testcase.checkRegion, "") - if testcase.needFix { + _, err := scatterer.Scatter(testCase.checkRegion, "") + if testCase.needFix { c.Assert(err, NotNil) c.Assert(tc.CheckRegionUnderSuspect(1), IsTrue) } else { @@ -303,7 +303,7 @@ func (s *testScatterRegionSuite) TestScatterGroupInConcurrency(c *C) { tc.SetStoreLastHeartbeatInterval(i, -10*time.Minute) } - testcases := []struct { + testCases := []struct { name string groupCount int }{ @@ -322,12 +322,12 @@ func (s *testScatterRegionSuite) TestScatterGroupInConcurrency(c *C) { } // We send scatter interweave request for each group to simulate scattering multiple region groups in concurrency. - for _, testcase := range testcases { - c.Logf(testcase.name) + for _, testCase := range testCases { + c.Logf(testCase.name) scatterer := NewRegionScatterer(ctx, tc) regionID := 1 for i := 0; i < 100; i++ { - for j := 0; j < testcase.groupCount; j++ { + for j := 0; j < testCase.groupCount; j++ { scatterer.scatterRegion(tc.AddLeaderRegion(uint64(regionID), 1, 2, 3), fmt.Sprintf("group-%v", j)) regionID++ @@ -335,7 +335,7 @@ func (s *testScatterRegionSuite) TestScatterGroupInConcurrency(c *C) { } checker := func(ss *selectedStores, expected uint64, delta float64) { - for i := 0; i < testcase.groupCount; i++ { + for i := 0; i < testCase.groupCount; i++ { // comparing the leader distribution group := fmt.Sprintf("group-%v", i) max := uint64(0) @@ -369,7 +369,7 @@ func (s *testScatterRegionSuite) TestScattersGroup(c *C) { for i := uint64(1); i <= 5; i++ { tc.AddRegionStore(i, 0) } - testcases := []struct { + testCases := []struct { name string failure bool }{ @@ -383,15 +383,15 @@ func (s *testScatterRegionSuite) TestScattersGroup(c *C) { }, } group := "group" - for _, testcase := range testcases { + for _, testCase := range testCases { scatterer := NewRegionScatterer(ctx, tc) regions := map[uint64]*core.RegionInfo{} for i := 1; i <= 100; i++ { regions[uint64(i)] = tc.AddLeaderRegion(uint64(i), 1, 2, 3) } - c.Log(testcase.name) + c.Log(testCase.name) failures := map[uint64]error{} - if testcase.failure { + if testCase.failure { c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/scatterFail", `return(true)`), IsNil) } @@ -412,7 +412,7 @@ func (s *testScatterRegionSuite) TestScattersGroup(c *C) { c.Assert(min, LessEqual, uint64(20)) c.Assert(max, GreaterEqual, uint64(20)) c.Assert(max-min, LessEqual, uint64(3)) - if testcase.failure { + if testCase.failure { c.Assert(failures, HasLen, 1) _, ok := failures[1] c.Assert(ok, IsTrue) diff --git a/server/schedulers/hot_region_test.go b/server/schedulers/hot_region_test.go index 66ed9ec3c9c..a0abd95e2ba 100644 --- 
a/server/schedulers/hot_region_test.go +++ b/server/schedulers/hot_region_test.go @@ -1479,7 +1479,7 @@ func addRegionLeaderReadInfo(tc *mockcluster.Cluster, regions []testRegionInfo) } func (s *testHotCacheSuite) TestCheckRegionFlow(c *C) { - testcases := []struct { + testCases := []struct { kind statistics.RWType onlyLeader bool DegreeAfterTransferLeader int @@ -1501,7 +1501,7 @@ func (s *testHotCacheSuite) TestCheckRegionFlow(c *C) { }, } - for _, testcase := range testcases { + for _, testCase := range testCases { ctx, cancel := context.WithCancel(context.Background()) opt := config.NewTestOptions() tc := mockcluster.NewCluster(ctx, opt) @@ -1512,8 +1512,8 @@ func (s *testHotCacheSuite) TestCheckRegionFlow(c *C) { c.Assert(err, IsNil) hb := sche.(*hotScheduler) heartbeat := tc.AddLeaderRegionWithWriteInfo - if testcase.kind == statistics.Read { - if testcase.onlyLeader { + if testCase.kind == statistics.Read { + if testCase.onlyLeader { heartbeat = tc.AddRegionLeaderWithReadInfo } else { heartbeat = tc.AddRegionWithReadInfo @@ -1522,7 +1522,7 @@ func (s *testHotCacheSuite) TestCheckRegionFlow(c *C) { tc.AddRegionStore(2, 20) tc.UpdateStorageReadStats(2, 9.5*MB*statistics.StoreHeartBeatReportInterval, 9.5*MB*statistics.StoreHeartBeatReportInterval) reportInterval := uint64(statistics.WriteReportInterval) - if testcase.kind == statistics.Read { + if testCase.kind == statistics.Read { reportInterval = uint64(statistics.ReadReportInterval) } // hot degree increase @@ -1537,15 +1537,15 @@ func (s *testHotCacheSuite) TestCheckRegionFlow(c *C) { items = heartbeat(1, 2, 512*KB*reportInterval, 0, 0, reportInterval, []uint64{1, 3}, 1) for _, item := range items { if item.StoreID == 2 { - c.Check(item.HotDegree, Equals, testcase.DegreeAfterTransferLeader) + c.Check(item.HotDegree, Equals, testCase.DegreeAfterTransferLeader) } } - if testcase.DegreeAfterTransferLeader >= 3 { + if testCase.DegreeAfterTransferLeader >= 3 { // try schedule - hb.prepareForBalance(testcase.kind, tc) - leaderSolver := newBalanceSolver(hb, tc, testcase.kind, transferLeader) - leaderSolver.cur = &solution{srcStore: hb.stLoadInfos[toResourceType(testcase.kind, transferLeader)][2]} + hb.prepareForBalance(testCase.kind, tc) + leaderSolver := newBalanceSolver(hb, tc, testCase.kind, transferLeader) + leaderSolver.cur = &solution{srcStore: hb.stLoadInfos[toResourceType(testCase.kind, transferLeader)][2]} c.Check(leaderSolver.filterHotPeers(leaderSolver.cur.srcStore), HasLen, 0) // skip schedule threshold := tc.GetHotRegionCacheHitsThreshold() leaderSolver.minHotDegree = 0 @@ -1557,7 +1557,7 @@ func (s *testHotCacheSuite) TestCheckRegionFlow(c *C) { items = heartbeat(1, 2, 512*KB*reportInterval, 0, 0, reportInterval, []uint64{1, 3, 4}, 1) c.Check(len(items), Greater, 0) for _, item := range items { - c.Check(item.HotDegree, Equals, testcase.DegreeAfterTransferLeader+1) + c.Check(item.HotDegree, Equals, testCase.DegreeAfterTransferLeader+1) } items = heartbeat(1, 2, 512*KB*reportInterval, 0, 0, reportInterval, []uint64{1, 4}, 1) c.Check(len(items), Greater, 0) @@ -1566,7 +1566,7 @@ func (s *testHotCacheSuite) TestCheckRegionFlow(c *C) { c.Check(item.GetActionType(), Equals, statistics.Remove) continue } - c.Check(item.HotDegree, Equals, testcase.DegreeAfterTransferLeader+2) + c.Check(item.HotDegree, Equals, testCase.DegreeAfterTransferLeader+2) } cancel() } diff --git a/server/schedulers/scheduler_test.go b/server/schedulers/scheduler_test.go index b51fbbfd73f..64da676d0a8 100644 --- a/server/schedulers/scheduler_test.go +++ 
b/server/schedulers/scheduler_test.go @@ -471,7 +471,7 @@ func (s *testBalanceLeaderSchedulerWithRuleEnabledSuite) TestBalanceLeaderWithCo // Leaders: 16 0 0 // Region1: L F F s.tc.UpdateLeaderCount(1, 16) - testcases := []struct { + testCases := []struct { name string rule *placement.Rule schedule bool @@ -534,10 +534,10 @@ func (s *testBalanceLeaderSchedulerWithRuleEnabledSuite) TestBalanceLeaderWithCo }, } - for _, testcase := range testcases { - c.Logf(testcase.name) - c.Check(s.tc.SetRule(testcase.rule), IsNil) - if testcase.schedule { + for _, testCase := range testCases { + c.Logf(testCase.name) + c.Check(s.tc.SetRule(testCase.rule), IsNil) + if testCase.schedule { c.Check(len(s.schedule()), Equals, 1) } else { c.Assert(s.schedule(), HasLen, 0) From ff82285c147ad65464b707210c0dcdfa3f4f278b Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 22 Jun 2022 12:16:36 +0800 Subject: [PATCH 64/82] *: replace `Id` with `ID` (#5209) ref tikv/pd#4399 Signed-off-by: lhy1024 --- server/api/region_test.go | 36 ++++++++++----------- server/core/basic_cluster.go | 2 +- server/core/region.go | 4 +-- server/core/region_option.go | 8 ++--- server/core/region_tree_test.go | 2 +- server/schedule/checker/merge_checker.go | 2 +- server/schedule/checker/replica_strategy.go | 2 +- server/schedule/test_util.go | 4 +-- server/schedulers/balance_region.go | 2 +- server/schedulers/grant_hot_region.go | 2 +- server/schedulers/hot_region.go | 2 +- server/schedulers/shuffle_hot_region.go | 2 +- server/schedulers/shuffle_region.go | 2 +- tools/pd-simulator/simulator/event.go | 2 +- tools/pd-simulator/simulator/raft.go | 6 ++-- 15 files changed, 39 insertions(+), 39 deletions(-) diff --git a/server/api/region_test.go b/server/api/region_test.go index 7fc80254132..816fefcc092 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -456,57 +456,57 @@ func (s *testGetRegionSuite) TestScanRegionByKeys(c *C) { mustRegionHeartbeat(c, s.svr, r) url := fmt.Sprintf("%s/regions/key?key=%s", s.urlPrefix, "b") - regionIds := []uint64{3, 4, 5, 99} + regionIDs := []uint64{3, 4, 5, 99} regions := &RegionsInfo{} err := tu.ReadGetJSON(c, testDialClient, url, regions) c.Assert(err, IsNil) - c.Assert(regionIds, HasLen, regions.Count) - for i, v := range regionIds { + c.Assert(regionIDs, HasLen, regions.Count) + for i, v := range regionIDs { c.Assert(v, Equals, regions.Regions[i].ID) } url = fmt.Sprintf("%s/regions/key?key=%s", s.urlPrefix, "d") - regionIds = []uint64{4, 5, 99} + regionIDs = []uint64{4, 5, 99} regions = &RegionsInfo{} err = tu.ReadGetJSON(c, testDialClient, url, regions) c.Assert(err, IsNil) - c.Assert(regionIds, HasLen, regions.Count) - for i, v := range regionIds { + c.Assert(regionIDs, HasLen, regions.Count) + for i, v := range regionIDs { c.Assert(v, Equals, regions.Regions[i].ID) } url = fmt.Sprintf("%s/regions/key?key=%s", s.urlPrefix, "g") - regionIds = []uint64{5, 99} + regionIDs = []uint64{5, 99} regions = &RegionsInfo{} err = tu.ReadGetJSON(c, testDialClient, url, regions) c.Assert(err, IsNil) - c.Assert(regionIds, HasLen, regions.Count) - for i, v := range regionIds { + c.Assert(regionIDs, HasLen, regions.Count) + for i, v := range regionIDs { c.Assert(v, Equals, regions.Regions[i].ID) } url = fmt.Sprintf("%s/regions/key?end_key=%s", s.urlPrefix, "e") - regionIds = []uint64{2, 3, 4} + regionIDs = []uint64{2, 3, 4} regions = &RegionsInfo{} err = tu.ReadGetJSON(c, testDialClient, url, regions) c.Assert(err, IsNil) - c.Assert(len(regionIds), Equals, regions.Count) - for i, v := range regionIds { 
+ c.Assert(len(regionIDs), Equals, regions.Count) + for i, v := range regionIDs { c.Assert(v, Equals, regions.Regions[i].ID) } url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s", s.urlPrefix, "b", "g") - regionIds = []uint64{3, 4} + regionIDs = []uint64{3, 4} regions = &RegionsInfo{} err = tu.ReadGetJSON(c, testDialClient, url, regions) c.Assert(err, IsNil) - c.Assert(len(regionIds), Equals, regions.Count) - for i, v := range regionIds { + c.Assert(len(regionIDs), Equals, regions.Count) + for i, v := range regionIDs { c.Assert(v, Equals, regions.Regions[i].ID) } url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s", s.urlPrefix, "b", []byte{0xFF, 0xFF, 0xCC}) - regionIds = []uint64{3, 4, 5, 99} + regionIDs = []uint64{3, 4, 5, 99} regions = &RegionsInfo{} err = tu.ReadGetJSON(c, testDialClient, url, regions) c.Assert(err, IsNil) - c.Assert(len(regionIds), Equals, regions.Count) - for i, v := range regionIds { + c.Assert(len(regionIDs), Equals, regions.Count) + for i, v := range regionIDs { c.Assert(v, Equals, regions.Regions[i].ID) } } diff --git a/server/core/basic_cluster.go b/server/core/basic_cluster.go index 9100636d376..6509857398d 100644 --- a/server/core/basic_cluster.go +++ b/server/core/basic_cluster.go @@ -95,7 +95,7 @@ func (bc *BasicCluster) GetRegionStores(region *RegionInfo) []*StoreInfo { bc.RLock() defer bc.RUnlock() var Stores []*StoreInfo - for id := range region.GetStoreIds() { + for id := range region.GetStoreIDs() { if store := bc.Stores.GetStore(id); store != nil { Stores = append(Stores, store) } diff --git a/server/core/region.go b/server/core/region.go index cc688712ad8..8114a0e411d 100644 --- a/server/core/region.go +++ b/server/core/region.go @@ -349,8 +349,8 @@ func (r *RegionInfo) GetStoreLearner(storeID uint64) *metapb.Peer { return nil } -// GetStoreIds returns a map indicate the region distributed. -func (r *RegionInfo) GetStoreIds() map[uint64]struct{} { +// GetStoreIDs returns a map indicate the region distributed. +func (r *RegionInfo) GetStoreIDs() map[uint64]struct{} { peers := r.meta.GetPeers() stores := make(map[uint64]struct{}, len(peers)) for _, peer := range peers { diff --git a/server/core/region_option.go b/server/core/region_option.go index b06c784c0f6..c0d2204de2b 100644 --- a/server/core/region_option.go +++ b/server/core/region_option.go @@ -96,14 +96,14 @@ func WithNewRegionID(id uint64) RegionCreateOption { } } -// WithNewPeerIds sets new ids for peers. -func WithNewPeerIds(peerIds ...uint64) RegionCreateOption { +// WithNewPeerIDs sets new ids for peers. 
+func WithNewPeerIDs(peerIDs ...uint64) RegionCreateOption { return func(region *RegionInfo) { - if len(peerIds) != len(region.meta.GetPeers()) { + if len(peerIDs) != len(region.meta.GetPeers()) { return } for i, p := range region.meta.GetPeers() { - p.Id = peerIds[i] + p.Id = peerIDs[i] } } } diff --git a/server/core/region_tree_test.go b/server/core/region_tree_test.go index 0f813717fcb..7538f04dd74 100644 --- a/server/core/region_tree_test.go +++ b/server/core/region_tree_test.go @@ -80,7 +80,7 @@ func TestRegionInfo(t *testing.T) { r = r.Clone(WithEndKey([]byte{1})) re.Regexp(".*EndKey Changed.*", DiffRegionKeyInfo(r, info)) - stores := r.GetStoreIds() + stores := r.GetStoreIDs() re.Len(stores, int(n)) for i := uint64(0); i < n; i++ { _, ok := stores[i] diff --git a/server/schedule/checker/merge_checker.go b/server/schedule/checker/merge_checker.go index e5624d48d2d..fcfc087a3cb 100644 --- a/server/schedule/checker/merge_checker.go +++ b/server/schedule/checker/merge_checker.go @@ -282,7 +282,7 @@ func isTableIDSame(region, adjacent *core.RegionInfo) bool { // while the source region has no peer on it. This is to prevent from bringing // any other peer into an offline store to slow down the offline process. func checkPeerStore(cluster schedule.Cluster, region, adjacent *core.RegionInfo) bool { - regionStoreIDs := region.GetStoreIds() + regionStoreIDs := region.GetStoreIDs() for _, peer := range adjacent.GetPeers() { storeID := peer.GetStoreId() store := cluster.GetStore(storeID) diff --git a/server/schedule/checker/replica_strategy.go b/server/schedule/checker/replica_strategy.go index 6ccad30a32d..4fb9d87410e 100644 --- a/server/schedule/checker/replica_strategy.go +++ b/server/schedule/checker/replica_strategy.go @@ -53,7 +53,7 @@ func (s *ReplicaStrategy) SelectStoreToAdd(coLocationStores []*core.StoreInfo, e // The reason for it is to prevent the non-optimal replica placement due // to the short-term state, resulting in redundant scheduling. filters := []filter.Filter{ - filter.NewExcludedFilter(s.checkerName, nil, s.region.GetStoreIds()), + filter.NewExcludedFilter(s.checkerName, nil, s.region.GetStoreIDs()), filter.NewStorageThresholdFilter(s.checkerName), filter.NewSpecialUseFilter(s.checkerName), &filter.StoreStateFilter{ActionScope: s.checkerName, MoveRegion: true, AllowTemporaryStates: true}, diff --git a/server/schedule/test_util.go b/server/schedule/test_util.go index b6484042d3b..67b6f891474 100644 --- a/server/schedule/test_util.go +++ b/server/schedule/test_util.go @@ -79,10 +79,10 @@ func ApplyOperator(mc *mockcluster.Cluster, op *operator.Operator) { region = ApplyOperatorStep(region, op) } mc.PutRegion(region) - for id := range region.GetStoreIds() { + for id := range region.GetStoreIDs() { mc.UpdateStoreStatus(id) } - for id := range origin.GetStoreIds() { + for id := range origin.GetStoreIDs() { mc.UpdateStoreStatus(id) } } diff --git a/server/schedulers/balance_region.go b/server/schedulers/balance_region.go index f8d4679acbc..033446ae380 100644 --- a/server/schedulers/balance_region.go +++ b/server/schedulers/balance_region.go @@ -216,7 +216,7 @@ func (s *balanceRegionScheduler) Schedule(cluster schedule.Cluster) []*operator. // transferPeer selects the best store to create a new peer to replace the old peer. 
func (s *balanceRegionScheduler) transferPeer(plan *balancePlan) *operator.Operator { filters := []filter.Filter{ - filter.NewExcludedFilter(s.GetName(), nil, plan.region.GetStoreIds()), + filter.NewExcludedFilter(s.GetName(), nil, plan.region.GetStoreIDs()), filter.NewPlacementSafeguard(s.GetName(), plan.GetOpts(), plan.GetBasicCluster(), plan.GetRuleManager(), plan.region, plan.source), filter.NewRegionScoreFilter(s.GetName(), plan.source, plan.GetOpts()), filter.NewSpecialUseFilter(s.GetName()), diff --git a/server/schedulers/grant_hot_region.go b/server/schedulers/grant_hot_region.go index 4decd1b1340..086bd7e26e2 100644 --- a/server/schedulers/grant_hot_region.go +++ b/server/schedulers/grant_hot_region.go @@ -361,7 +361,7 @@ func (s *grantHotRegionScheduler) transfer(cluster schedule.Cluster, regionID ui candidate = []uint64{s.conf.GetStoreLeaderID()} } else { filters = append(filters, &filter.StoreStateFilter{ActionScope: s.GetName(), MoveRegion: true}, - filter.NewExcludedFilter(s.GetName(), srcRegion.GetStoreIds(), srcRegion.GetStoreIds())) + filter.NewExcludedFilter(s.GetName(), srcRegion.GetStoreIDs(), srcRegion.GetStoreIDs())) candidate = s.conf.StoreIDs } for _, storeID := range candidate { diff --git a/server/schedulers/hot_region.go b/server/schedulers/hot_region.go index 3ca40a7133f..f6a38e8ecb5 100644 --- a/server/schedulers/hot_region.go +++ b/server/schedulers/hot_region.go @@ -728,7 +728,7 @@ func (bs *balanceSolver) filterDstStores() map[uint64]*statistics.StoreLoadDetai case movePeer: filters = []filter.Filter{ &filter.StoreStateFilter{ActionScope: bs.sche.GetName(), MoveRegion: true}, - filter.NewExcludedFilter(bs.sche.GetName(), bs.cur.region.GetStoreIds(), bs.cur.region.GetStoreIds()), + filter.NewExcludedFilter(bs.sche.GetName(), bs.cur.region.GetStoreIDs(), bs.cur.region.GetStoreIDs()), filter.NewSpecialUseFilter(bs.sche.GetName(), filter.SpecialUseHotRegion), filter.NewPlacementSafeguard(bs.sche.GetName(), bs.GetOpts(), bs.GetBasicCluster(), bs.GetRuleManager(), bs.cur.region, srcStore), } diff --git a/server/schedulers/shuffle_hot_region.go b/server/schedulers/shuffle_hot_region.go index 55a894f3c70..d33f3a5159b 100644 --- a/server/schedulers/shuffle_hot_region.go +++ b/server/schedulers/shuffle_hot_region.go @@ -179,7 +179,7 @@ func (s *shuffleHotRegionScheduler) randomSchedule(cluster schedule.Cluster, loa filters := []filter.Filter{ &filter.StoreStateFilter{ActionScope: s.GetName(), MoveRegion: true}, - filter.NewExcludedFilter(s.GetName(), srcRegion.GetStoreIds(), srcRegion.GetStoreIds()), + filter.NewExcludedFilter(s.GetName(), srcRegion.GetStoreIDs(), srcRegion.GetStoreIDs()), filter.NewPlacementSafeguard(s.GetName(), cluster.GetOpts(), cluster.GetBasicCluster(), cluster.GetRuleManager(), srcRegion, srcStore), } stores := cluster.GetStores() diff --git a/server/schedulers/shuffle_region.go b/server/schedulers/shuffle_region.go index b0a6286ceb0..f947ca79ad2 100644 --- a/server/schedulers/shuffle_region.go +++ b/server/schedulers/shuffle_region.go @@ -159,7 +159,7 @@ func (s *shuffleRegionScheduler) scheduleAddPeer(cluster schedule.Cluster, regio return nil } scoreGuard := filter.NewPlacementSafeguard(s.GetName(), cluster.GetOpts(), cluster.GetBasicCluster(), cluster.GetRuleManager(), region, store) - excludedFilter := filter.NewExcludedFilter(s.GetName(), nil, region.GetStoreIds()) + excludedFilter := filter.NewExcludedFilter(s.GetName(), nil, region.GetStoreIDs()) target := filter.NewCandidates(cluster.GetStores()). 
FilterTarget(cluster.GetOpts(), s.filters...). diff --git a/tools/pd-simulator/simulator/event.go b/tools/pd-simulator/simulator/event.go index 13032e45d20..f2504de9021 100644 --- a/tools/pd-simulator/simulator/event.go +++ b/tools/pd-simulator/simulator/event.go @@ -187,7 +187,7 @@ func (e *DeleteNodes) Run(raft *RaftEngine, tickCount int64) bool { regions := raft.GetRegions() for _, region := range regions { - storeIDs := region.GetStoreIds() + storeIDs := region.GetStoreIDs() if _, ok := storeIDs[id]; ok { downPeer := &pdpb.PeerStats{ Peer: region.GetStorePeer(id), diff --git a/tools/pd-simulator/simulator/raft.go b/tools/pd-simulator/simulator/raft.go index c3194071884..644a86ef7d5 100644 --- a/tools/pd-simulator/simulator/raft.go +++ b/tools/pd-simulator/simulator/raft.go @@ -146,7 +146,7 @@ func (r *RaftEngine) stepSplit(region *core.RegionInfo) { } left := region.Clone( core.WithNewRegionID(ids[len(ids)-1]), - core.WithNewPeerIds(ids[0:len(ids)-1]...), + core.WithNewPeerIDs(ids[0:len(ids)-1]...), core.WithIncVersion(), core.SetApproximateKeys(region.GetApproximateKeys()/2), core.SetApproximateSize(region.GetApproximateSize()/2), @@ -196,7 +196,7 @@ func (r *RaftEngine) updateRegionStore(region *core.RegionInfo, size int64) { core.SetApproximateSize(region.GetApproximateSize()+size), core.SetWrittenBytes(uint64(size)), ) - storeIDs := region.GetStoreIds() + storeIDs := region.GetStoreIDs() for storeID := range storeIDs { r.conn.Nodes[storeID].incUsedSize(uint64(size)) } @@ -220,7 +220,7 @@ func (r *RaftEngine) electNewLeader(region *core.RegionInfo) *metapb.Peer { unhealthy int newLeaderStoreID uint64 ) - ids := region.GetStoreIds() + ids := region.GetStoreIDs() for id := range ids { if r.conn.nodeHealth(id) { newLeaderStoreID = id From ac9383f16008c83484444047ab5b2849429c5e12 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 22 Jun 2022 13:56:37 +0800 Subject: [PATCH 65/82] *: replace `re.Len 0` with `re.Empty` (#5210) ref tikv/pd#4813 Signed-off-by: lhy1024 --- pkg/autoscaling/calculation_test.go | 2 +- pkg/cache/cache_test.go | 2 +- pkg/encryption/master_key_test.go | 4 +-- pkg/progress/progress_test.go | 2 +- pkg/rangetree/range_tree_test.go | 12 ++++---- pkg/reflectutil/tag_test.go | 2 +- server/region_syncer/history_buffer_test.go | 2 +- .../schedule/placement/rule_manager_test.go | 4 +-- .../buckets/hot_bucket_task_test.go | 2 +- server/statistics/region_collection_test.go | 28 +++++++++---------- server/storage/storage_gc_test.go | 2 +- tests/pdctl/hot/hot_test.go | 2 +- tests/server/cluster/cluster_test.go | 4 +-- 13 files changed, 34 insertions(+), 34 deletions(-) diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index f5db53fdabd..a411cce632e 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -194,7 +194,7 @@ func TestGetScaledTiKVGroups(t *testing.T) { t.Log(testCase.name) plans, err := getScaledTiKVGroups(testCase.informer, testCase.healthyInstances) if testCase.expectedPlan == nil { - re.Len(plans, 0) + re.Empty(plans) testCase.errorChecker(err) } else { re.Equal(testCase.expectedPlan, plans) diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 05db409f6dc..f708f966401 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -355,7 +355,7 @@ func TestPriorityQueue(t *testing.T) { // case4 remove all element pq.Remove(uint64(2)) re.Equal(0, pq.Len()) - re.Len(pq.items, 0) + re.Empty(pq.items) re.Nil(pq.Peek()) re.Nil(pq.Tail()) } diff --git 
a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index 79a6bb390d9..fb32da5e79b 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -34,12 +34,12 @@ func TestPlaintextMasterKey(t *testing.T) { masterKey, err := NewMasterKey(config, nil) re.NoError(err) re.NotNil(masterKey) - re.Len(masterKey.key, 0) + re.Empty(masterKey.key) plaintext := "this is a plaintext" ciphertext, iv, err := masterKey.Encrypt([]byte(plaintext)) re.NoError(err) - re.Len(iv, 0) + re.Empty(iv) re.Equal(plaintext, string(ciphertext)) plaintext2, err := masterKey.Decrypt(ciphertext, iv) diff --git a/pkg/progress/progress_test.go b/pkg/progress/progress_test.go index cdb60c9573f..e6799fb0ff8 100644 --- a/pkg/progress/progress_test.go +++ b/pkg/progress/progress_test.go @@ -64,7 +64,7 @@ func TestProgress(t *testing.T) { ps = m.GetProgresses(func(p string) bool { return strings.Contains(p, "a") }) - re.Len(ps, 0) + re.Empty(ps) re.True(m.RemoveProgress(n)) re.False(m.RemoveProgress(n)) } diff --git a/pkg/rangetree/range_tree_test.go b/pkg/rangetree/range_tree_test.go index a9071ed5f1f..9e8e7e9dca8 100644 --- a/pkg/rangetree/range_tree_test.go +++ b/pkg/rangetree/range_tree_test.go @@ -95,11 +95,11 @@ func TestRingPutItem(t *testing.T) { re.Equal(2, bucketTree.Len()) // init key range: [002,100], [100,200] - re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("002"))), 0) + re.Empty(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("002")))) re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("009"))), 1) re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 1) re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("110"))), 2) - re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("200"), []byte("300"))), 0) + re.Empty(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("200"), []byte("300")))) // test1: insert one key range, the old overlaps will retain like split buckets. 
// key range: [002,010],[010,090],[090,100],[100,200] @@ -126,15 +126,15 @@ func TestDebris(t *testing.T) { ringItem := newSimpleBucketItem([]byte("010"), []byte("090")) var overlaps []RangeItem overlaps = bucketDebrisFactory([]byte("000"), []byte("100"), ringItem) - re.Len(overlaps, 0) + re.Empty(overlaps) overlaps = bucketDebrisFactory([]byte("000"), []byte("080"), ringItem) re.Len(overlaps, 1) overlaps = bucketDebrisFactory([]byte("020"), []byte("080"), ringItem) re.Len(overlaps, 2) overlaps = bucketDebrisFactory([]byte("010"), []byte("090"), ringItem) - re.Len(overlaps, 0) + re.Empty(overlaps) overlaps = bucketDebrisFactory([]byte("010"), []byte("100"), ringItem) - re.Len(overlaps, 0) + re.Empty(overlaps) overlaps = bucketDebrisFactory([]byte("100"), []byte("200"), ringItem) - re.Len(overlaps, 0) + re.Empty(overlaps) } diff --git a/pkg/reflectutil/tag_test.go b/pkg/reflectutil/tag_test.go index bff74b36f40..d8619898ea5 100644 --- a/pkg/reflectutil/tag_test.go +++ b/pkg/reflectutil/tag_test.go @@ -47,7 +47,7 @@ func TestFindJSONFullTagByChildTag(t *testing.T) { key = "disable" result = FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) - re.Len(result, 0) + re.Empty(result) } func TestFindSameFieldByJSON(t *testing.T) { diff --git a/server/region_syncer/history_buffer_test.go b/server/region_syncer/history_buffer_test.go index 49cbebdf266..26e6a113dad 100644 --- a/server/region_syncer/history_buffer_test.go +++ b/server/region_syncer/history_buffer_test.go @@ -80,7 +80,7 @@ func TestBufferSize(t *testing.T) { re.Equal("106", s) histories := h2.RecordsFrom(uint64(1)) - re.Len(histories, 0) + re.Empty(histories) histories = h2.RecordsFrom(h2.firstIndex()) re.Len(histories, 100) re.Equal(uint64(7), h2.firstIndex()) diff --git a/server/schedule/placement/rule_manager_test.go b/server/schedule/placement/rule_manager_test.go index 0fdfd2f67a8..997a3d9a12b 100644 --- a/server/schedule/placement/rule_manager_test.go +++ b/server/schedule/placement/rule_manager_test.go @@ -44,8 +44,8 @@ func TestDefault(t *testing.T) { re.Equal("pd", rules[0].GroupID) re.Equal("default", rules[0].ID) re.Equal(0, rules[0].Index) - re.Len(rules[0].StartKey, 0) - re.Len(rules[0].EndKey, 0) + re.Empty(rules[0].StartKey) + re.Empty(rules[0].EndKey) re.Equal(Voter, rules[0].Role) re.Equal([]string{"zone", "rack", "host"}, rules[0].LocationLabels) } diff --git a/server/statistics/buckets/hot_bucket_task_test.go b/server/statistics/buckets/hot_bucket_task_test.go index 49f60116c9d..f2f28ef3d02 100644 --- a/server/statistics/buckets/hot_bucket_task_test.go +++ b/server/statistics/buckets/hot_bucket_task_test.go @@ -123,5 +123,5 @@ func TestCollectBucketStatsTask(t *testing.T) { task = NewCollectBucketStatsTask(1) re.True(hotCache.CheckAsync(task)) stats = task.WaitRet(ctx) - re.Len(stats, 0) + re.Empty(stats) } diff --git a/server/statistics/region_collection_test.go b/server/statistics/region_collection_test.go index 932c35f139e..7c686f1c9ce 100644 --- a/server/statistics/region_collection_test.go +++ b/server/statistics/region_collection_test.go @@ -81,15 +81,15 @@ func TestRegionStatistics(t *testing.T) { ) regionStats.Observe(region1, stores) re.Len(regionStats.stats[ExtraPeer], 1) - re.Len(regionStats.stats[MissPeer], 0) + re.Empty(regionStats.stats[MissPeer]) re.Len(regionStats.stats[DownPeer], 1) re.Len(regionStats.stats[PendingPeer], 1) re.Len(regionStats.stats[LearnerPeer], 1) - re.Len(regionStats.stats[EmptyRegion], 0) + re.Empty(regionStats.stats[EmptyRegion]) 
re.Len(regionStats.stats[OversizedRegion], 1) - re.Len(regionStats.stats[UndersizedRegion], 0) + re.Empty(regionStats.stats[UndersizedRegion]) re.Len(regionStats.offlineStats[ExtraPeer], 1) - re.Len(regionStats.offlineStats[MissPeer], 0) + re.Empty(regionStats.offlineStats[MissPeer]) re.Len(regionStats.offlineStats[DownPeer], 1) re.Len(regionStats.offlineStats[PendingPeer], 1) re.Len(regionStats.offlineStats[LearnerPeer], 1) @@ -105,7 +105,7 @@ func TestRegionStatistics(t *testing.T) { re.Len(regionStats.stats[OversizedRegion], 1) re.Len(regionStats.stats[UndersizedRegion], 1) re.Len(regionStats.offlineStats[ExtraPeer], 1) - re.Len(regionStats.offlineStats[MissPeer], 0) + re.Empty(regionStats.offlineStats[MissPeer]) re.Len(regionStats.offlineStats[DownPeer], 1) re.Len(regionStats.offlineStats[PendingPeer], 1) re.Len(regionStats.offlineStats[LearnerPeer], 1) @@ -113,22 +113,22 @@ func TestRegionStatistics(t *testing.T) { region1 = region1.Clone(core.WithRemoveStorePeer(7)) regionStats.Observe(region1, stores[0:3]) - re.Len(regionStats.stats[ExtraPeer], 0) + re.Empty(regionStats.stats[ExtraPeer]) re.Len(regionStats.stats[MissPeer], 1) re.Len(regionStats.stats[DownPeer], 2) re.Len(regionStats.stats[PendingPeer], 1) - re.Len(regionStats.stats[LearnerPeer], 0) - re.Len(regionStats.offlineStats[ExtraPeer], 0) - re.Len(regionStats.offlineStats[MissPeer], 0) - re.Len(regionStats.offlineStats[DownPeer], 0) - re.Len(regionStats.offlineStats[PendingPeer], 0) - re.Len(regionStats.offlineStats[LearnerPeer], 0) - re.Len(regionStats.offlineStats[OfflinePeer], 0) + re.Empty(regionStats.stats[LearnerPeer]) + re.Empty(regionStats.offlineStats[ExtraPeer]) + re.Empty(regionStats.offlineStats[MissPeer]) + re.Empty(regionStats.offlineStats[DownPeer]) + re.Empty(regionStats.offlineStats[PendingPeer]) + re.Empty(regionStats.offlineStats[LearnerPeer]) + re.Empty(regionStats.offlineStats[OfflinePeer]) store3 = stores[3].Clone(core.UpStore()) stores[3] = store3 regionStats.Observe(region1, stores) - re.Len(regionStats.stats[OfflinePeer], 0) + re.Empty(regionStats.stats[OfflinePeer]) } func TestRegionStatisticsWithPlacementRule(t *testing.T) { diff --git a/server/storage/storage_gc_test.go b/server/storage/storage_gc_test.go index 371dc5759f9..76540f05a9c 100644 --- a/server/storage/storage_gc_test.go +++ b/server/storage/storage_gc_test.go @@ -214,5 +214,5 @@ func TestLoadEmpty(t *testing.T) { // loading empty key spaces should return empty slices safePoints, err := storage.LoadAllKeySpaceGCSafePoints(true) re.NoError(err) - re.Len(safePoints, 0) + re.Empty(safePoints) } diff --git a/tests/pdctl/hot/hot_test.go b/tests/pdctl/hot/hot_test.go index 74148b40955..787da8b1bd4 100644 --- a/tests/pdctl/hot/hot_test.go +++ b/tests/pdctl/hot/hot_test.go @@ -336,7 +336,7 @@ func TestHistoryHotRegions(t *testing.T) { output, err = pdctl.ExecuteCommand(cmd, args...) re.NoError(err) re.NoError(json.Unmarshal(output, &hotRegions)) - re.Len(hotRegions.HistoryHotRegion, 0) + re.Empty(hotRegions.HistoryHotRegion) args = []string{"-u", pdAddr, "hot", "history"} output, err = pdctl.ExecuteCommand(cmd, args...) 
re.NoError(err) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 5c3ea03827e..1bf2ecad485 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -708,7 +708,7 @@ func TestSetScheduleOpt(t *testing.T) { re.Equal("testKey", persistOptions.GetLabelPropertyConfig()[typ][0].Key) re.Equal("testValue", persistOptions.GetLabelPropertyConfig()[typ][0].Value) re.NoError(svr.DeleteLabelProperty(typ, labelKey, labelValue)) - re.Len(persistOptions.GetLabelPropertyConfig()[typ], 0) + re.Empty(persistOptions.GetLabelPropertyConfig()[typ]) // PUT GET failed re.NoError(failpoint.Enable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed", `return(true)`)) @@ -722,7 +722,7 @@ func TestSetScheduleOpt(t *testing.T) { re.Equal(5, persistOptions.GetMaxReplicas()) re.Equal(uint64(10), persistOptions.GetMaxSnapshotCount()) re.True(persistOptions.GetPDServerConfig().UseRegionStorage) - re.Len(persistOptions.GetLabelPropertyConfig()[typ], 0) + re.Empty(persistOptions.GetLabelPropertyConfig()[typ]) // DELETE failed re.NoError(failpoint.Disable("github.com/tikv/pd/server/storage/kv/etcdSaveFailed")) From cff28d79e8863b1567a8a1f38394c95eb6089fa7 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 22 Jun 2022 14:04:37 +0800 Subject: [PATCH 66/82] tests/server: migrate test framework to testify (#5197) ref tikv/pd#4813 Signed-off-by: lhy1024 Co-authored-by: Ti Chi Robot --- tests/server/server_test.go | 92 +++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 51 deletions(-) diff --git a/tests/server/server_test.go b/tests/server/server_test.go index f75ef4e15f0..4369d03c966 100644 --- a/tests/server/server_test.go +++ b/tests/server/server_test.go @@ -18,7 +18,7 @@ import ( "context" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/tempurl" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" @@ -29,46 +29,30 @@ import ( _ "github.com/tikv/pd/server/schedulers" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&serverTestSuite{}) - -type serverTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *serverTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *serverTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *serverTestSuite) TestUpdateAdvertiseUrls(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) +func TestUpdateAdvertiseUrls(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + re := require.New(t) + cluster, err := tests.NewTestCluster(ctx, 2) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) // AdvertisePeerUrls should equals to PeerUrls. for _, conf := range cluster.GetConfig().InitialServers { serverConf := cluster.GetServer(conf.Name).GetConfig() - c.Assert(serverConf.AdvertisePeerUrls, Equals, conf.PeerURLs) - c.Assert(serverConf.AdvertiseClientUrls, Equals, conf.ClientURLs) + re.Equal(conf.PeerURLs, serverConf.AdvertisePeerUrls) + re.Equal(conf.ClientURLs, serverConf.AdvertiseClientUrls) } err = cluster.StopAll() - c.Assert(err, IsNil) + re.NoError(err) // Change config will not affect peer urls. // Recreate servers with new peer URLs. 
@@ -77,66 +61,72 @@ func (s *serverTestSuite) TestUpdateAdvertiseUrls(c *C) { } for _, conf := range cluster.GetConfig().InitialServers { serverConf, e := conf.Generate() - c.Assert(e, IsNil) - s, e := tests.NewTestServer(s.ctx, serverConf) - c.Assert(e, IsNil) + re.NoError(e) + s, e := tests.NewTestServer(ctx, serverConf) + re.NoError(e) cluster.GetServers()[conf.Name] = s } err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) for _, conf := range cluster.GetConfig().InitialServers { serverConf := cluster.GetServer(conf.Name).GetConfig() - c.Assert(serverConf.AdvertisePeerUrls, Equals, conf.PeerURLs) + re.Equal(conf.PeerURLs, serverConf.AdvertisePeerUrls) } } -func (s *serverTestSuite) TestClusterID(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) +func TestClusterID(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + re := require.New(t) + cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) clusterID := cluster.GetServer("pd1").GetClusterID() for _, s := range cluster.GetServers() { - c.Assert(s.GetClusterID(), Equals, clusterID) + re.Equal(clusterID, s.GetClusterID()) } // Restart all PDs. err = cluster.StopAll() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) // All PDs should have the same cluster ID as before. for _, s := range cluster.GetServers() { - c.Assert(s.GetClusterID(), Equals, clusterID) + re.Equal(clusterID, s.GetClusterID()) } - cluster2, err := tests.NewTestCluster(s.ctx, 3, func(conf *config.Config, serverName string) { conf.InitialClusterToken = "foobar" }) + cluster2, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, serverName string) { conf.InitialClusterToken = "foobar" }) defer cluster2.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster2.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) clusterID2 := cluster2.GetServer("pd1").GetClusterID() - c.Assert(clusterID2, Not(Equals), clusterID) + re.NotEqual(clusterID, clusterID2) } -func (s *serverTestSuite) TestLeader(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) +func TestLeader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + re := require.New(t) + cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) leader1 := cluster.WaitLeader() - c.Assert(leader1, Not(Equals), "") + re.NotEqual("", leader1) err = cluster.GetServer(leader1).Stop() - c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { + re.NoError(err) + testutil.Eventually(re, func() bool { leader := cluster.GetLeader() return leader != leader1 }) From da533ebdeec03dd54651ad7e29ba52eaf08aa39d Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 22 Jun 2022 14:14:37 +0800 Subject: [PATCH 67/82] *: migrate test framework to testify (#5193) ref tikv/pd#4813 Signed-off-by: Ryan Leung Co-authored-by: Ti Chi Robot --- pkg/assertutil/assertutil.go | 12 +- pkg/assertutil/assertutil_test.go | 5 +- server/api/version_test.go | 5 +- server/join/join_test.go | 26 +-- server/schedule/labeler/labeler_test.go | 212 +++++++++--------- server/schedule/labeler/rule_test.go | 36 ++- server/server_test.go | 5 +- tests/client/client_test.go | 5 +- tests/pdctl/helper.go | 5 +- 
.../global_config/global_config_test.go | 4 +- tests/server/member/member_test.go | 5 +- 11 files changed, 163 insertions(+), 157 deletions(-) diff --git a/pkg/assertutil/assertutil.go b/pkg/assertutil/assertutil.go index d750b63a7aa..5da16155674 100644 --- a/pkg/assertutil/assertutil.go +++ b/pkg/assertutil/assertutil.go @@ -22,20 +22,14 @@ type Checker struct { } // NewChecker creates Checker with FailNow function. -func NewChecker(failNow func()) *Checker { - return &Checker{ - FailNow: failNow, - } -} - -func (c *Checker) failNow() { - c.FailNow() +func NewChecker() *Checker { + return &Checker{} } // AssertNil calls the injected IsNil assertion. func (c *Checker) AssertNil(obtained interface{}) { if c.IsNil == nil { - c.failNow() + c.FailNow() return } c.IsNil(obtained) diff --git a/pkg/assertutil/assertutil_test.go b/pkg/assertutil/assertutil_test.go index 8da2ad2b164..dcdbbd4252d 100644 --- a/pkg/assertutil/assertutil_test.go +++ b/pkg/assertutil/assertutil_test.go @@ -25,9 +25,10 @@ func TestNilFail(t *testing.T) { t.Parallel() re := require.New(t) var failErr error - checker := NewChecker(func() { + checker := NewChecker() + checker.FailNow = func() { failErr = errors.New("called assert func not exist") - }) + } re.Nil(checker.IsNil) checker.AssertNil(nil) re.NotNil(failErr) diff --git a/server/api/version_test.go b/server/api/version_test.go index e28dbf08df9..7cbbab688cf 100644 --- a/server/api/version_test.go +++ b/server/api/version_test.go @@ -33,7 +33,10 @@ import ( var _ = Suite(&testVersionSuite{}) func checkerWithNilAssert(c *C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) + checker := assertutil.NewChecker() + checker.FailNow = func() { + c.FailNow() + } checker.IsNil = func(obtained interface{}) { c.Assert(obtained, IsNil) } diff --git a/server/join/join_test.go b/server/join/join_test.go index 4ef4eb65c17..b8f001b5398 100644 --- a/server/join/join_test.go +++ b/server/join/join_test.go @@ -17,32 +17,28 @@ package join import ( "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" ) -func TestJoin(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testJoinServerSuite{}) - -type testJoinServerSuite struct{} - -func checkerWithNilAssert(c *C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) +func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { + checker := assertutil.NewChecker() + checker.FailNow = func() { + re.FailNow("") + } checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, IsNil) + re.Nil(obtained) } return checker } // A PD joins itself. -func (s *testJoinServerSuite) TestPDJoinsItself(c *C) { - cfg := server.NewTestSingleConfig(checkerWithNilAssert(c)) +func TestPDJoinsItself(t *testing.T) { + re := require.New(t) + cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) defer testutil.CleanServer(cfg.DataDir) cfg.Join = cfg.AdvertiseClientUrls - c.Assert(PrepareJoinCluster(cfg), NotNil) + re.Error(PrepareJoinCluster(cfg)) } diff --git a/server/schedule/labeler/labeler_test.go b/server/schedule/labeler/labeler_test.go index 375e87851ec..e848ace1b12 100644 --- a/server/schedule/labeler/labeler_test.go +++ b/server/schedule/labeler/labeler_test.go @@ -23,32 +23,14 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/storage" - "github.com/tikv/pd/server/storage/endpoint" ) -func TestT(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testLabelerSuite{}) - -type testLabelerSuite struct { - store endpoint.RuleStorage - labeler *RegionLabeler -} - -func (s *testLabelerSuite) SetUpTest(c *C) { - s.store = storage.NewStorageWithMemoryBackend() - var err error - s.labeler, err = NewRegionLabeler(context.Background(), s.store, time.Millisecond*10) - c.Assert(err, IsNil) -} - -func (s *testLabelerSuite) TestAdjustRule(c *C) { +func TestAdjustRule(t *testing.T) { + re := require.New(t) rule := LabelRule{ ID: "foo", Labels: []RegionLabel{ @@ -58,21 +40,22 @@ func (s *testLabelerSuite) TestAdjustRule(c *C) { Data: makeKeyRanges("12abcd", "34cdef", "56abcd", "78cdef"), } err := rule.checkAndAdjust() - c.Assert(err, IsNil) - c.Assert(rule.Data.([]*KeyRangeRule), HasLen, 2) - c.Assert(rule.Data.([]*KeyRangeRule)[0].StartKey, BytesEquals, []byte{0x12, 0xab, 0xcd}) - c.Assert(rule.Data.([]*KeyRangeRule)[0].EndKey, BytesEquals, []byte{0x34, 0xcd, 0xef}) - c.Assert(rule.Data.([]*KeyRangeRule)[1].StartKey, BytesEquals, []byte{0x56, 0xab, 0xcd}) - c.Assert(rule.Data.([]*KeyRangeRule)[1].EndKey, BytesEquals, []byte{0x78, 0xcd, 0xef}) + re.NoError(err) + re.Len(rule.Data.([]*KeyRangeRule), 2) + re.Equal([]byte{0x12, 0xab, 0xcd}, rule.Data.([]*KeyRangeRule)[0].StartKey) + re.Equal([]byte{0x34, 0xcd, 0xef}, rule.Data.([]*KeyRangeRule)[0].EndKey) + re.Equal([]byte{0x56, 0xab, 0xcd}, rule.Data.([]*KeyRangeRule)[1].StartKey) + re.Equal([]byte{0x78, 0xcd, 0xef}, rule.Data.([]*KeyRangeRule)[1].EndKey) } -func (s *testLabelerSuite) TestAdjustRule2(c *C) { +func TestAdjustRule2(t *testing.T) { + re := require.New(t) ruleData := `{"id":"id", "labels": [{"key": "k1", "value": "v1"}], "rule_type":"key-range", "data": [{"start_key":"", "end_key":""}]}` var rule LabelRule err := json.Unmarshal([]byte(ruleData), &rule) - c.Assert(err, IsNil) + re.NoError(err) err = rule.checkAndAdjust() - c.Assert(err, IsNil) + re.NoError(err) badRuleData := []string{ // no id @@ -95,40 +78,44 @@ func (s *testLabelerSuite) TestAdjustRule2(c *C) { `{"id":"id", "labels": [{"key": "k1", "value": "v1"}], "rule_type":"key-range", "data": [{"start_key":"abcd", "end_key":"123"}]}`, `{"id":"id", "labels": [{"key": "k1", "value": "v1"}], "rule_type":"key-range", "data": [{"start_key":"abcd", "end_key":"1234"}]}`, } - for i, str := range badRuleData { + for _, str := range badRuleData { var rule LabelRule err := json.Unmarshal([]byte(str), &rule) - c.Assert(err, IsNil, Commentf("#%d", i)) + re.NoError(err) err = rule.checkAndAdjust() - c.Assert(err, NotNil, Commentf("#%d", i)) + re.Error(err) } } -func (s *testLabelerSuite) TestGetSetRule(c *C) { +func TestGetSetRule(t *testing.T) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() + labeler, err := NewRegionLabeler(context.Background(), store, time.Millisecond*10) + re.NoError(err) rules := []*LabelRule{ {ID: "rule1", Labels: []RegionLabel{{Key: "k1", Value: "v1"}}, RuleType: "key-range", Data: makeKeyRanges("1234", "5678")}, {ID: "rule2", Labels: []RegionLabel{{Key: "k2", Value: "v2"}}, RuleType: "key-range", Data: makeKeyRanges("ab12", "cd12")}, {ID: "rule3", Labels: []RegionLabel{{Key: "k3", Value: "v3"}}, RuleType: "key-range", Data: makeKeyRanges("abcd", "efef")}, } for _, r := range rules { - err := s.labeler.SetLabelRule(r) - 
c.Assert(err, IsNil) + err := labeler.SetLabelRule(r) + re.NoError(err) } - allRules := s.labeler.GetAllLabelRules() + allRules := labeler.GetAllLabelRules() sort.Slice(allRules, func(i, j int) bool { return allRules[i].ID < allRules[j].ID }) - c.Assert(allRules, DeepEquals, rules) + re.Equal(rules, allRules) - byIDs, err := s.labeler.GetLabelRules([]string{"rule3", "rule1"}) - c.Assert(err, IsNil) - c.Assert(byIDs, DeepEquals, []*LabelRule{rules[2], rules[0]}) + byIDs, err := labeler.GetLabelRules([]string{"rule3", "rule1"}) + re.NoError(err) + re.Equal([]*LabelRule{rules[2], rules[0]}, byIDs) - err = s.labeler.DeleteLabelRule("rule2") - c.Assert(err, IsNil) - c.Assert(s.labeler.GetLabelRule("rule2"), IsNil) - byIDs, err = s.labeler.GetLabelRules([]string{"rule1", "rule2"}) - c.Assert(err, IsNil) - c.Assert(byIDs, DeepEquals, []*LabelRule{rules[0]}) + err = labeler.DeleteLabelRule("rule2") + re.NoError(err) + re.Nil(labeler.GetLabelRule("rule2")) + byIDs, err = labeler.GetLabelRules([]string{"rule1", "rule2"}) + re.NoError(err) + re.Equal([]*LabelRule{rules[0]}, byIDs) // patch patch := LabelRulePatch{ @@ -137,16 +124,20 @@ func (s *testLabelerSuite) TestGetSetRule(c *C) { }, DeleteRules: []string{"rule1"}, } - err = s.labeler.Patch(patch) - c.Assert(err, IsNil) - allRules = s.labeler.GetAllLabelRules() + err = labeler.Patch(patch) + re.NoError(err) + allRules = labeler.GetAllLabelRules() sort.Slice(allRules, func(i, j int) bool { return allRules[i].ID < allRules[j].ID }) for id, rule := range allRules { - expectSameRules(c, rule, rules[id+1]) + expectSameRules(re, rule, rules[id+1]) } } -func (s *testLabelerSuite) TestIndex(c *C) { +func TestIndex(t *testing.T) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() + labeler, err := NewRegionLabeler(context.Background(), store, time.Millisecond*10) + re.NoError(err) rules := []*LabelRule{ {ID: "rule0", Labels: []RegionLabel{{Key: "k1", Value: "v0"}}, RuleType: "key-range", Data: makeKeyRanges("", "")}, {ID: "rule1", Index: 1, Labels: []RegionLabel{{Key: "k1", Value: "v1"}}, RuleType: "key-range", Data: makeKeyRanges("1234", "5678")}, @@ -154,8 +145,8 @@ func (s *testLabelerSuite) TestIndex(c *C) { {ID: "rule3", Index: 1, Labels: []RegionLabel{{Key: "k2", Value: "v3"}}, RuleType: "key-range", Data: makeKeyRanges("abcd", "efef")}, } for _, r := range rules { - err := s.labeler.SetLabelRule(r) - c.Assert(err, IsNil) + err := labeler.SetLabelRule(r) + re.NoError(err) } type testCase struct { @@ -173,67 +164,75 @@ func (s *testLabelerSuite) TestIndex(c *C) { start, _ := hex.DecodeString(tc.start) end, _ := hex.DecodeString(tc.end) region := core.NewTestRegionInfo(start, end) - labels := s.labeler.GetRegionLabels(region) - c.Assert(labels, HasLen, len(tc.labels)) + labels := labeler.GetRegionLabels(region) + re.Len(labels, len(tc.labels)) for _, l := range labels { - c.Assert(l.Value, Equals, tc.labels[l.Key]) + re.Equal(tc.labels[l.Key], l.Value) } for _, k := range []string{"k1", "k2"} { - c.Assert(s.labeler.GetRegionLabel(region, k), Equals, tc.labels[k]) + re.Equal(tc.labels[k], labeler.GetRegionLabel(region, k)) } } } -func (s *testLabelerSuite) TestSaveLoadRule(c *C) { +func TestSaveLoadRule(t *testing.T) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() + labeler, err := NewRegionLabeler(context.Background(), store, time.Millisecond*10) + re.NoError(err) rules := []*LabelRule{ {ID: "rule1", Labels: []RegionLabel{{Key: "k1", Value: "v1"}}, RuleType: "key-range", Data: makeKeyRanges("1234", 
"5678")}, {ID: "rule2", Labels: []RegionLabel{{Key: "k2", Value: "v2"}}, RuleType: "key-range", Data: makeKeyRanges("ab12", "cd12")}, {ID: "rule3", Labels: []RegionLabel{{Key: "k3", Value: "v3"}}, RuleType: "key-range", Data: makeKeyRanges("abcd", "efef")}, } for _, r := range rules { - err := s.labeler.SetLabelRule(r) - c.Assert(err, IsNil) + err := labeler.SetLabelRule(r) + re.NoError(err) } - labeler, err := NewRegionLabeler(context.Background(), s.store, time.Millisecond*100) - c.Assert(err, IsNil) + labeler, err = NewRegionLabeler(context.Background(), store, time.Millisecond*100) + re.NoError(err) for _, r := range rules { r2 := labeler.GetLabelRule(r.ID) - expectSameRules(c, r2, r) + expectSameRules(re, r2, r) } } -func expectSameRegionLabels(c *C, r1, r2 *RegionLabel) { +func expectSameRegionLabels(re *require.Assertions, r1, r2 *RegionLabel) { r1.checkAndAdjustExpire() r2.checkAndAdjustExpire() if len(r1.TTL) == 0 { - c.Assert(r2, DeepEquals, r1) + re.Equal(r1, r2) } r2.StartAt = r1.StartAt r2.checkAndAdjustExpire() - c.Assert(r2, DeepEquals, r1) + re.Equal(r1, r2) } -func expectSameRules(c *C, r1, r2 *LabelRule) { - c.Assert(r1.Labels, HasLen, len(r2.Labels)) +func expectSameRules(re *require.Assertions, r1, r2 *LabelRule) { + re.Len(r1.Labels, len(r2.Labels)) for id := 0; id < len(r1.Labels); id++ { - expectSameRegionLabels(c, &r1.Labels[id], &r2.Labels[id]) + expectSameRegionLabels(re, &r1.Labels[id], &r2.Labels[id]) } - c.Assert(r2, DeepEquals, r1) + re.Equal(r1, r2) } -func (s *testLabelerSuite) TestKeyRange(c *C) { +func TestKeyRange(t *testing.T) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() + labeler, err := NewRegionLabeler(context.Background(), store, time.Millisecond*10) + re.NoError(err) rules := []*LabelRule{ {ID: "rule1", Labels: []RegionLabel{{Key: "k1", Value: "v1"}}, RuleType: "key-range", Data: makeKeyRanges("1234", "5678")}, {ID: "rule2", Labels: []RegionLabel{{Key: "k2", Value: "v2"}}, RuleType: "key-range", Data: makeKeyRanges("ab12", "cd12")}, {ID: "rule3", Labels: []RegionLabel{{Key: "k3", Value: "v3"}}, RuleType: "key-range", Data: makeKeyRanges("abcd", "efef")}, } for _, r := range rules { - err := s.labeler.SetLabelRule(r) - c.Assert(err, IsNil) + err := labeler.SetLabelRule(r) + re.NoError(err) } type testCase struct { @@ -251,18 +250,22 @@ func (s *testLabelerSuite) TestKeyRange(c *C) { start, _ := hex.DecodeString(tc.start) end, _ := hex.DecodeString(tc.end) region := core.NewTestRegionInfo(start, end) - labels := s.labeler.GetRegionLabels(region) - c.Assert(labels, HasLen, len(tc.labels)) + labels := labeler.GetRegionLabels(region) + re.Len(labels, len(tc.labels)) for _, l := range labels { - c.Assert(tc.labels[l.Key], Equals, l.Value) + re.Equal(l.Value, tc.labels[l.Key]) } for _, k := range []string{"k1", "k2", "k3"} { - c.Assert(s.labeler.GetRegionLabel(region, k), Equals, tc.labels[k]) + re.Equal(tc.labels[k], labeler.GetRegionLabel(region, k)) } } } -func (s *testLabelerSuite) TestLabelerRuleTTL(c *C) { +func TestLabelerRuleTTL(t *testing.T) { + re := require.New(t) + store := storage.NewStorageWithMemoryBackend() + labeler, err := NewRegionLabeler(context.Background(), store, time.Millisecond*10) + re.NoError(err) rules := []*LabelRule{ { ID: "rule1", @@ -291,56 +294,57 @@ func (s *testLabelerSuite) TestLabelerRuleTTL(c *C) { end, _ := hex.DecodeString("5678") region := core.NewTestRegionInfo(start, end) // the region has no lable rule at the beginning. 
- c.Assert(s.labeler.GetRegionLabels(region), HasLen, 0) + re.Empty(labeler.GetRegionLabels(region)) // set rules for the region. for _, r := range rules { - err := s.labeler.SetLabelRule(r) - c.Assert(err, IsNil) + err := labeler.SetLabelRule(r) + re.NoError(err) } // get rule with "rule2". - c.Assert(s.labeler.GetLabelRule("rule2"), NotNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/labeler/regionLabelExpireSub1Minute", "return(true)"), IsNil) + re.NotNil(labeler.GetLabelRule("rule2")) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/labeler/regionLabelExpireSub1Minute", "return(true)")) // rule2 should expire and only 2 labels left. - labels := s.labeler.GetRegionLabels(region) - c.Assert(labels, HasLen, 2) + labels := labeler.GetRegionLabels(region) + re.Len(labels, 2) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/schedule/labeler/regionLabelExpireSub1Minute"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/schedule/labeler/regionLabelExpireSub1Minute")) // rule2 should be exist since `GetRegionLabels` won't clear it physically. - s.checkRuleInMemoryAndStoage(c, "rule2", true) - c.Assert(s.labeler.GetLabelRule("rule2"), IsNil) + checkRuleInMemoryAndStoage(re, labeler, "rule2", true) + re.Nil(labeler.GetLabelRule("rule2")) // rule2 should be physically clear. - s.checkRuleInMemoryAndStoage(c, "rule2", false) + checkRuleInMemoryAndStoage(re, labeler, "rule2", false) - c.Assert(s.labeler.GetRegionLabel(region, "k2"), Equals, "") + re.Equal("", labeler.GetRegionLabel(region, "k2")) - c.Assert(s.labeler.GetLabelRule("rule3"), NotNil) - c.Assert(s.labeler.GetLabelRule("rule1"), NotNil) + re.NotNil(labeler.GetLabelRule("rule3")) + re.NotNil(labeler.GetLabelRule("rule1")) } -func (s *testLabelerSuite) checkRuleInMemoryAndStoage(c *C, ruleID string, exist bool) { - c.Assert(s.labeler.labelRules[ruleID] != nil, Equals, exist) +func checkRuleInMemoryAndStoage(re *require.Assertions, labeler *RegionLabeler, ruleID string, exist bool) { + re.Equal(exist, labeler.labelRules[ruleID] != nil) existInStorage := false - s.labeler.storage.LoadRegionRules(func(k, v string) { + labeler.storage.LoadRegionRules(func(k, v string) { if k == ruleID { existInStorage = true } }) - c.Assert(existInStorage, Equals, exist) + re.Equal(exist, existInStorage) } -func (s *testLabelerSuite) TestGC(c *C) { +func TestGC(t *testing.T) { + re := require.New(t) // set gcInterval to 1 hour. store := storage.NewStorageWithMemoryBackend() labeler, err := NewRegionLabeler(context.Background(), store, time.Hour) - c.Assert(err, IsNil) + re.NoError(err) ttls := []string{"1ms", "1ms", "1ms", "5ms", "5ms", "10ms", "1h", "24h"} start, _ := hex.DecodeString("1234") end, _ := hex.DecodeString("5678") region := core.NewTestRegionInfo(start, end) // the region has no lable rule at the beginning. - c.Assert(labeler.GetRegionLabels(region), HasLen, 0) + re.Empty(labeler.GetRegionLabels(region)) labels := []RegionLabel{} for id, ttl := range ttls { @@ -351,10 +355,10 @@ func (s *testLabelerSuite) TestGC(c *C) { RuleType: "key-range", Data: makeKeyRanges("1234", "5678")} err := labeler.SetLabelRule(rule) - c.Assert(err, IsNil) + re.NoError(err) } - c.Assert(labeler.labelRules, HasLen, len(ttls)) + re.Len(labeler.labelRules, len(ttls)) // check all rules unitl some rule expired. for { @@ -366,14 +370,14 @@ func (s *testLabelerSuite) TestGC(c *C) { } // no rule was cleared because the gc interval is big. 
- c.Assert(labeler.labelRules, HasLen, len(ttls)) + re.Len(labeler.labelRules, len(ttls)) labeler.checkAndClearExpiredLabels() labeler.RLock() currentRuleLen := len(labeler.labelRules) labeler.RUnlock() - c.Assert(currentRuleLen <= 5, IsTrue) + re.LessOrEqual(currentRuleLen, 5) } func makeKeyRanges(keys ...string) []interface{} { diff --git a/server/schedule/labeler/rule_test.go b/server/schedule/labeler/rule_test.go index 5b7b3b1f76e..0b341754007 100644 --- a/server/schedule/labeler/rule_test.go +++ b/server/schedule/labeler/rule_test.go @@ -17,46 +17,44 @@ package labeler import ( "encoding/json" "math" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testRuleSuite{}) - -type testRuleSuite struct{} - -func (s *testLabelerSuite) TestRegionLabelTTL(c *C) { +func TestRegionLabelTTL(t *testing.T) { + re := require.New(t) label := RegionLabel{Key: "k1", Value: "v1"} // test label with no ttl. err := label.checkAndAdjustExpire() - c.Assert(err, IsNil) - c.Assert(label.StartAt, HasLen, 0) - c.Assert(label.expire, IsNil) + re.NoError(err) + re.Empty(label.StartAt) + re.Empty(label.expire) // test rule with illegal ttl. label.TTL = "ttl" err = label.checkAndAdjustExpire() - c.Assert(err, NotNil) + re.Error(err) // test legal rule with ttl label.TTL = "10h10m10s10ms" err = label.checkAndAdjustExpire() - c.Assert(err, IsNil) - c.Assert(len(label.StartAt) > 0, IsTrue) - c.Assert(label.expireBefore(time.Now().Add(time.Hour)), IsFalse) - c.Assert(label.expireBefore(time.Now().Add(24*time.Hour)), IsTrue) + re.NoError(err) + re.Greater(len(label.StartAt), 0) + re.False(label.expireBefore(time.Now().Add(time.Hour))) + re.True(label.expireBefore(time.Now().Add(24 * time.Hour))) // test legal rule with ttl, rule unmarshal from json. data, err := json.Marshal(label) - c.Assert(err, IsNil) + re.NoError(err) var label2 RegionLabel err = json.Unmarshal(data, &label2) - c.Assert(err, IsNil) - c.Assert(label2.StartAt, Equals, label.StartAt) - c.Assert(label2.TTL, Equals, label.TTL) + re.NoError(err) + re.Equal(label.StartAt, label2.StartAt) + re.Equal(label.TTL, label2.TTL) label2.checkAndAdjustExpire() // The `expire` should be the same with minor inaccuracies. 
- c.Assert(math.Abs(label2.expire.Sub(*label.expire).Seconds()) < 1, IsTrue) + re.True(math.Abs(label2.expire.Sub(*label.expire).Seconds()) < 1) } diff --git a/server/server_test.go b/server/server_test.go index c6c14fe011f..2a31cfb3b1c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -55,7 +55,10 @@ func mustWaitLeader(c *C, svrs []*Server) *Server { } func checkerWithNilAssert(c *C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) + checker := assertutil.NewChecker() + checker.FailNow = func() { + c.FailNow() + } checker.IsNil = func(obtained interface{}) { c.Assert(obtained, IsNil) } diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 86824261d12..a9022d04a20 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -729,9 +729,10 @@ func (suite *clientTestSuite) TearDownSuite() { } func (suite *clientTestSuite) checkerWithNilAssert() *assertutil.Checker { - checker := assertutil.NewChecker(func() { + checker := assertutil.NewChecker() + checker.FailNow = func() { suite.FailNow("should be nil") - }) + } checker.IsNil = func(obtained interface{}) { suite.Nil(obtained) } diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go index 775f0b40f15..bbfa351cc32 100644 --- a/tests/pdctl/helper.go +++ b/tests/pdctl/helper.go @@ -124,9 +124,10 @@ func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID, } func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker(func() { + checker := assertutil.NewChecker() + checker.FailNow = func() { re.FailNow("should be nil") - }) + } checker.IsNil = func(obtained interface{}) { re.Nil(obtained) } diff --git a/tests/server/global_config/global_config_test.go b/tests/server/global_config/global_config_test.go index 87dc62e35a2..c2b12353eea 100644 --- a/tests/server/global_config/global_config_test.go +++ b/tests/server/global_config/global_config_test.go @@ -65,7 +65,9 @@ func (s TestReceiver) Send(m *pdpb.WatchGlobalConfigResponse) error { func (s *GlobalConfigTestSuite) SetUpSuite(c *C) { var err error var gsi *server.Server - gsi, s.cleanup, err = server.NewTestServer(assertutil.NewChecker(func() {})) + checker := assertutil.NewChecker() + checker.FailNow = func() {} + gsi, s.cleanup, err = server.NewTestServer(checker) s.server = &server.GrpcServer{Server: gsi} c.Assert(err, IsNil) addr := s.server.GetAddr() diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 0215d95e5ff..9c856b739a1 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -47,7 +47,10 @@ func TestMain(m *testing.M) { } func checkerWithNilAssert(c *C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) + checker := assertutil.NewChecker() + checker.FailNow = func() { + c.FailNow() + } checker.IsNil = func(obtained interface{}) { c.Assert(obtained, IsNil) } From dcaa318e4470f23ed71549cfb61394e8bdfbd972 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 22 Jun 2022 15:16:37 +0800 Subject: [PATCH 68/82] tests: testify some server tests (#5186) ref tikv/pd#4813 Testify some server tests. 
Signed-off-by: JmPotato --- tests/server/id/id_test.go | 111 ++++---- tests/server/join/join_fail/join_fail_test.go | 32 +-- tests/server/join/join_test.go | 136 +++++----- tests/server/member/member_test.go | 248 ++++++++---------- .../region_syncer/region_syncer_test.go | 163 ++++++------ tests/server/server_test.go | 19 +- tests/server/watch/leader_watch_test.go | 87 +++--- 7 files changed, 356 insertions(+), 440 deletions(-) diff --git a/tests/server/id/id_test.go b/tests/server/id/id_test.go index b624ceb056f..d9279cd3616 100644 --- a/tests/server/id/id_test.go +++ b/tests/server/id/id_test.go @@ -19,53 +19,37 @@ import ( "sync" "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/tests" "go.uber.org/goleak" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } const allocStep = uint64(1000) -var _ = Suite(&testAllocIDSuite{}) - -type testAllocIDSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testAllocIDSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testAllocIDSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *testAllocIDSuite) TestID(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestID(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) var last uint64 for i := uint64(0); i < allocStep; i++ { id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) - c.Assert(id, Greater, last) + re.NoError(err) + re.Greater(id, last) last = id } @@ -81,12 +65,12 @@ func (s *testAllocIDSuite) TestID(c *C) { for i := 0; i < 200; i++ { id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) + re.NoError(err) m.Lock() _, ok := ids[id] ids[id] = struct{}{} m.Unlock() - c.Assert(ok, IsFalse) + re.False(ok) } }() } @@ -94,98 +78,107 @@ func (s *testAllocIDSuite) TestID(c *C) { wg.Wait() } -func (s *testAllocIDSuite) TestCommand(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestCommand(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) req := &pdpb.AllocIDRequest{Header: testutil.NewRequestHeader(leaderServer.GetClusterID())} - grpcPDClient := testutil.MustNewGrpcClient(c, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) var last uint64 for i := uint64(0); i < 2*allocStep; i++ { resp, err := grpcPDClient.AllocID(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetId(), Greater, last) + re.NoError(err) + re.Greater(resp.GetId(), last) last = resp.GetId() } } -func (s *testAllocIDSuite) TestMonotonicID(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestMonotonicID(t *testing.T) { + re := require.New(t) + ctx, 
cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) var last1 uint64 for i := uint64(0); i < 10; i++ { id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) - c.Assert(id, Greater, last1) + re.NoError(err) + re.Greater(id, last1) last1 = id } err = cluster.ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer = cluster.GetServer(cluster.GetLeader()) var last2 uint64 for i := uint64(0); i < 10; i++ { id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) - c.Assert(id, Greater, last2) + re.NoError(err) + re.Greater(id, last2) last2 = id } err = cluster.ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer = cluster.GetServer(cluster.GetLeader()) id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) - c.Assert(id, Greater, last2) + re.NoError(err) + re.Greater(id, last2) var last3 uint64 for i := uint64(0); i < 1000; i++ { id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) - c.Assert(id, Greater, last3) + re.NoError(err) + re.Greater(id, last3) last3 = id } } -func (s *testAllocIDSuite) TestPDRestart(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestPDRestart(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) var last uint64 for i := uint64(0); i < 10; i++ { id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) - c.Assert(id, Greater, last) + re.NoError(err) + re.Greater(id, last) last = id } - c.Assert(leaderServer.Stop(), IsNil) - c.Assert(leaderServer.Run(), IsNil) + re.NoError(leaderServer.Stop()) + re.NoError(leaderServer.Run()) cluster.WaitLeader() for i := uint64(0); i < 10; i++ { id, err := leaderServer.GetAllocator().Alloc() - c.Assert(err, IsNil) - c.Assert(id, Greater, last) + re.NoError(err) + re.Greater(id, last) last = id } } diff --git a/tests/server/join/join_fail/join_fail_test.go b/tests/server/join/join_fail/join_fail_test.go index bc4e98abdce..26dca0d3b52 100644 --- a/tests/server/join/join_fail/join_fail_test.go +++ b/tests/server/join/join_fail/join_fail_test.go @@ -16,43 +16,29 @@ package join_fail_test import ( "context" - "strings" "testing" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" - "github.com/tikv/pd/pkg/testutil" + "github.com/stretchr/testify/require" "github.com/tikv/pd/tests" - "go.uber.org/goleak" ) -func Test(t *testing.T) { - TestingT(t) -} - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m, testutil.LeakOptions...) 
-} - -var _ = Suite(&joinTestSuite{}) - -type joinTestSuite struct{} - -func (s *joinTestSuite) TestFailedPDJoinInStep1(c *C) { +func TestFailedPDJoinInStep1(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() // Join the second PD. - c.Assert(failpoint.Enable("github.com/tikv/pd/server/join/add-member-failed", `return`), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/join/add-member-failed", `return`)) _, err = cluster.Join(ctx) - c.Assert(err, NotNil) - c.Assert(strings.Contains(err.Error(), "join failed"), IsTrue) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/join/add-member-failed"), IsNil) + re.Error(err) + re.Contains(err.Error(), "join failed") + re.NoError(failpoint.Disable("github.com/tikv/pd/server/join/add-member-failed")) } diff --git a/tests/server/join/join_test.go b/tests/server/join/join_test.go index 8cc9cdcdb34..e7e01d74668 100644 --- a/tests/server/join/join_test.go +++ b/tests/server/join/join_test.go @@ -21,119 +21,110 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/join" "github.com/tikv/pd/tests" ) -func Test(t *testing.T) { - TestingT(t) -} - // TODO: enable it when we fix TestFailedAndDeletedPDJoinsPreviousCluster // func TestMain(m *testing.M) { // goleak.VerifyTestMain(m, testutil.LeakOptions...) // } -var _ = Suite(&joinTestSuite{}) - -type joinTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *joinTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EtcdStartTimeout = 10 * time.Second -} - -func (s *joinTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *joinTestSuite) TestSimpleJoin(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func TestSimpleJoin(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pd1 := cluster.GetServer("pd1") client := pd1.GetEtcdClient() members, err := etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 1) + re.NoError(err) + re.Len(members.Members, 1) // Join the second PD. - pd2, err := cluster.Join(s.ctx) - c.Assert(err, IsNil) + pd2, err := cluster.Join(ctx) + re.NoError(err) err = pd2.Run() - c.Assert(err, IsNil) + re.NoError(err) _, err = os.Stat(path.Join(pd2.GetConfig().DataDir, "join")) - c.Assert(os.IsNotExist(err), IsFalse) + re.False(os.IsNotExist(err)) members, err = etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 2) - c.Assert(pd2.GetClusterID(), Equals, pd1.GetClusterID()) + re.NoError(err) + re.Len(members.Members, 2) + re.Equal(pd1.GetClusterID(), pd2.GetClusterID()) // Wait for all nodes becoming healthy. time.Sleep(time.Second * 5) // Join another PD. 
- pd3, err := cluster.Join(s.ctx) - c.Assert(err, IsNil) + pd3, err := cluster.Join(ctx) + re.NoError(err) err = pd3.Run() - c.Assert(err, IsNil) + re.NoError(err) _, err = os.Stat(path.Join(pd3.GetConfig().DataDir, "join")) - c.Assert(os.IsNotExist(err), IsFalse) + re.False(os.IsNotExist(err)) members, err = etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 3) - c.Assert(pd3.GetClusterID(), Equals, pd1.GetClusterID()) + re.NoError(err) + re.Len(members.Members, 3) + re.Equal(pd1.GetClusterID(), pd3.GetClusterID()) } // A failed PD tries to join the previous cluster but it has been deleted // during its downtime. -func (s *joinTestSuite) TestFailedAndDeletedPDJoinsPreviousCluster(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) +func TestFailedAndDeletedPDJoinsPreviousCluster(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + server.EtcdStartTimeout = 10 * time.Second + cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() // Wait for all nodes becoming healthy. time.Sleep(time.Second * 5) pd3 := cluster.GetServer("pd3") err = pd3.Stop() - c.Assert(err, IsNil) + re.NoError(err) client := cluster.GetServer("pd1").GetEtcdClient() _, err = client.MemberRemove(context.TODO(), pd3.GetServerID()) - c.Assert(err, IsNil) + re.NoError(err) // The server should not successfully start. res := cluster.RunServer(pd3) - c.Assert(<-res, NotNil) + re.Error(<-res) members, err := etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 2) + re.NoError(err) + re.Len(members.Members, 2) } // A deleted PD joins the previous cluster. -func (s *joinTestSuite) TestDeletedPDJoinsPreviousCluster(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) +func TestDeletedPDJoinsPreviousCluster(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + server.EtcdStartTimeout = 10 * time.Second + cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() // Wait for all nodes becoming healthy. time.Sleep(time.Second * 5) @@ -141,37 +132,36 @@ func (s *joinTestSuite) TestDeletedPDJoinsPreviousCluster(c *C) { pd3 := cluster.GetServer("pd3") client := cluster.GetServer("pd1").GetEtcdClient() _, err = client.MemberRemove(context.TODO(), pd3.GetServerID()) - c.Assert(err, IsNil) + re.NoError(err) err = pd3.Stop() - c.Assert(err, IsNil) + re.NoError(err) // The server should not successfully start. 
res := cluster.RunServer(pd3) - c.Assert(<-res, NotNil) + re.Error(<-res) members, err := etcdutil.ListEtcdMembers(client) - c.Assert(err, IsNil) - c.Assert(members.Members, HasLen, 2) + re.NoError(err) + re.Len(members.Members, 2) } -func (s *joinTestSuite) TestFailedPDJoinsPreviousCluster(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) +func TestFailedPDJoinsPreviousCluster(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() // Join the second PD. - pd2, err := cluster.Join(s.ctx) - c.Assert(err, IsNil) - err = pd2.Run() - c.Assert(err, IsNil) - err = pd2.Stop() - c.Assert(err, IsNil) - err = pd2.Destroy() - c.Assert(err, IsNil) - c.Assert(join.PrepareJoinCluster(pd2.GetConfig()), NotNil) + pd2, err := cluster.Join(ctx) + re.NoError(err) + re.NoError(pd2.Run()) + re.NoError(pd2.Stop()) + re.NoError(pd2.Destroy()) + re.Error(join.PrepareJoinCluster(pd2.GetConfig())) } diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 9c856b739a1..1864500df74 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -25,10 +25,10 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/etcdutil" "github.com/tikv/pd/pkg/testutil" @@ -38,58 +38,42 @@ import ( "go.uber.org/goleak" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) 
} -func checkerWithNilAssert(c *C) *assertutil.Checker { +func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { checker := assertutil.NewChecker() checker.FailNow = func() { - c.FailNow() + re.FailNow("should be nil") } checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, IsNil) + re.Nil(obtained) } return checker } -var _ = Suite(&memberTestSuite{}) - -type memberTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *memberTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *memberTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *memberTestSuite) TestMemberDelete(c *C) { +func TestMemberDelete(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) leaderName := cluster.WaitLeader() - c.Assert(leaderName, Not(Equals), "") + re.NotEmpty(leaderName) leader := cluster.GetServer(leaderName) var members []*tests.TestServer for _, s := range cluster.GetConfig().InitialServers { @@ -97,9 +81,9 @@ func (s *memberTestSuite) TestMemberDelete(c *C) { members = append(members, cluster.GetServer(s.Name)) } } - c.Assert(members, HasLen, 2) + re.Len(members, 2) - var table = []struct { + var tables = []struct { path string status int members []*config.Config @@ -111,18 +95,18 @@ func (s *memberTestSuite) TestMemberDelete(c *C) { } httpClient := &http.Client{Timeout: 15 * time.Second} - for _, t := range table { - c.Log(time.Now(), "try to delete:", t.path) - testutil.WaitUntil(c, func() bool { - addr := leader.GetConfig().ClientUrls + "/pd/api/v1/members/" + t.path + for _, table := range tables { + t.Log(time.Now(), "try to delete:", table.path) + testutil.Eventually(re, func() bool { + addr := leader.GetConfig().ClientUrls + "/pd/api/v1/members/" + table.path req, err := http.NewRequest(http.MethodDelete, addr, nil) - c.Assert(err, IsNil) + re.NoError(err) res, err := httpClient.Do(req) - c.Assert(err, IsNil) + re.NoError(err) defer res.Body.Close() // Check by status. - if t.status != 0 { - if res.StatusCode != t.status { + if table.status != 0 { + if res.StatusCode != table.status { time.Sleep(time.Second) return false } @@ -130,8 +114,8 @@ func (s *memberTestSuite) TestMemberDelete(c *C) { } // Check by member list. 
cluster.WaitLeader() - if err = s.checkMemberList(c, leader.GetConfig().ClientUrls, t.members); err != nil { - c.Logf("check member fail: %v", err) + if err = checkMemberList(re, leader.GetConfig().ClientUrls, table.members); err != nil { + t.Logf("check member fail: %v", err) time.Sleep(time.Second) return false } @@ -142,19 +126,19 @@ func (s *memberTestSuite) TestMemberDelete(c *C) { for _, member := range members { key := member.GetServer().GetMember().GetDCLocationPath(member.GetServerID()) resp, err := etcdutil.EtcdKVGet(leader.GetEtcdClient(), key) - c.Assert(err, IsNil) - c.Assert(resp.Kvs, HasLen, 0) + re.NoError(err) + re.Len(resp.Kvs, 0) } } -func (s *memberTestSuite) checkMemberList(c *C, clientURL string, configs []*config.Config) error { +func checkMemberList(re *require.Assertions, clientURL string, configs []*config.Config) error { httpClient := &http.Client{Timeout: 15 * time.Second} addr := clientURL + "/pd/api/v1/members" res, err := httpClient.Get(addr) - c.Assert(err, IsNil) + re.NoError(err) defer res.Body.Close() buf, err := io.ReadAll(res.Body) - c.Assert(err, IsNil) + re.NoError(err) if res.StatusCode != http.StatusOK { return errors.Errorf("load members failed, status: %v, data: %q", res.StatusCode, buf) } @@ -166,114 +150,118 @@ func (s *memberTestSuite) checkMemberList(c *C, clientURL string, configs []*con for _, member := range data["members"] { for _, cfg := range configs { if member.GetName() == cfg.Name { - c.Assert(member.ClientUrls, DeepEquals, []string{cfg.ClientUrls}) - c.Assert(member.PeerUrls, DeepEquals, []string{cfg.PeerUrls}) + re.Equal([]string{cfg.ClientUrls}, member.ClientUrls) + re.Equal([]string{cfg.PeerUrls}, member.PeerUrls) } } } return nil } -func (s *memberTestSuite) TestLeaderPriority(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) +func TestLeaderPriority(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leader1, err := cluster.GetServer("pd1").GetEtcdLeader() - c.Assert(err, IsNil) + re.NoError(err) server1 := cluster.GetServer(leader1) addr := server1.GetConfig().ClientUrls // PD leader should sync with etcd leader. - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { return cluster.GetLeader() == leader1 }) // Bind a lower priority to current leader. - s.post(c, addr+"/pd/api/v1/members/name/"+leader1, `{"leader-priority": -1}`) + post(t, re, addr+"/pd/api/v1/members/name/"+leader1, `{"leader-priority": -1}`) // Wait etcd leader change. - leader2 := s.waitEtcdLeaderChange(c, server1, leader1) + leader2 := waitEtcdLeaderChange(re, server1, leader1) // PD leader should sync with etcd leader again. 
- testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { return cluster.GetLeader() == leader2 }) } -func (s *memberTestSuite) post(c *C, url string, body string) { - testutil.WaitUntil(c, func() bool { +func post(t *testing.T, re *require.Assertions, url string, body string) { + testutil.Eventually(re, func() bool { res, err := http.Post(url, "", bytes.NewBufferString(body)) // #nosec - c.Assert(err, IsNil) + re.NoError(err) b, err := io.ReadAll(res.Body) res.Body.Close() - c.Assert(err, IsNil) - c.Logf("post %s, status: %v res: %s", url, res.StatusCode, string(b)) + re.NoError(err) + t.Logf("post %s, status: %v res: %s", url, res.StatusCode, string(b)) return res.StatusCode == http.StatusOK }) } -func (s *memberTestSuite) waitEtcdLeaderChange(c *C, server *tests.TestServer, old string) string { +func waitEtcdLeaderChange(re *require.Assertions, server *tests.TestServer, old string) string { var leader string - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { var err error leader, err = server.GetEtcdLeader() if err != nil { return false } - if leader == old { - // Priority check could be slow. So we sleep longer here. - time.Sleep(5 * time.Second) - } return leader != old - }) + }, testutil.WithWaitFor(time.Second*90), testutil.WithSleepInterval(time.Second)) return leader } -func (s *memberTestSuite) TestLeaderResign(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) +func TestLeaderResign(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) leader1 := cluster.WaitLeader() addr1 := cluster.GetServer(leader1).GetConfig().ClientUrls - s.post(c, addr1+"/pd/api/v1/leader/resign", "") - leader2 := s.waitLeaderChange(c, cluster, leader1) - c.Log("leader2:", leader2) + post(t, re, addr1+"/pd/api/v1/leader/resign", "") + leader2 := waitLeaderChange(re, cluster, leader1) + t.Log("leader2:", leader2) addr2 := cluster.GetServer(leader2).GetConfig().ClientUrls - s.post(c, addr2+"/pd/api/v1/leader/transfer/"+leader1, "") - leader3 := s.waitLeaderChange(c, cluster, leader2) - c.Assert(leader3, Equals, leader1) + post(t, re, addr2+"/pd/api/v1/leader/transfer/"+leader1, "") + leader3 := waitLeaderChange(re, cluster, leader2) + re.Equal(leader1, leader3) } -func (s *memberTestSuite) TestLeaderResignWithBlock(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) +func TestLeaderResignWithBlock(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) leader1 := cluster.WaitLeader() addr1 := cluster.GetServer(leader1).GetConfig().ClientUrls - err = failpoint.Enable("github.com/tikv/pd/server/raftclusterIsBusy", `pause`) - c.Assert(err, IsNil) - defer failpoint.Disable("github.com/tikv/pd/server/raftclusterIsBusy") - s.post(c, addr1+"/pd/api/v1/leader/resign", "") - leader2 := s.waitLeaderChange(c, cluster, leader1) - c.Log("leader2:", leader2) - c.Assert(leader2, Not(Equals), leader1) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/raftclusterIsBusy", `pause`)) + post(t, re, addr1+"/pd/api/v1/leader/resign", "") + leader2 := waitLeaderChange(re, cluster, 
leader1) + t.Log("leader2:", leader2) + re.NotEqual(leader1, leader2) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/raftclusterIsBusy")) } -func (s *memberTestSuite) waitLeaderChange(c *C, cluster *tests.TestCluster, old string) string { +func waitLeaderChange(re *require.Assertions, cluster *tests.TestCluster, old string) string { var leader string - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { leader = cluster.GetLeader() if leader == old || leader == "" { return false @@ -283,13 +271,16 @@ func (s *memberTestSuite) waitLeaderChange(c *C, cluster *tests.TestCluster, old return leader } -func (s *memberTestSuite) TestMoveLeader(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 5) +func TestMoveLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 5) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() var wg sync.WaitGroup @@ -315,66 +306,49 @@ func (s *memberTestSuite) TestMoveLeader(c *C) { select { case <-done: case <-time.After(10 * time.Second): - c.Fatal("move etcd leader does not return in 10 seconds") + t.Fatal("move etcd leader does not return in 10 seconds") } } -var _ = Suite(&leaderTestSuite{}) - -type leaderTestSuite struct { - ctx context.Context - cancel context.CancelFunc - svr *server.Server - wg sync.WaitGroup - done chan bool - cfg *config.Config -} - -func (s *leaderTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - s.cfg = server.NewTestSingleConfig(checkerWithNilAssert(c)) - s.wg.Add(1) - s.done = make(chan bool) - svr, err := server.CreateServer(s.ctx, s.cfg) - c.Assert(err, IsNil) - err = svr.Run() +func TestGetLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + wg := &sync.WaitGroup{} + wg.Add(1) + done := make(chan bool) + svr, err := server.CreateServer(ctx, cfg) + re.NoError(err) + defer svr.Close() + re.NoError(svr.Run()) // Send requests after server has started. - go s.sendRequest(c, s.cfg.ClientUrls) + go sendRequest(re, wg, done, cfg.ClientUrls) time.Sleep(100 * time.Millisecond) - c.Assert(err, IsNil) - - s.svr = svr -} -func (s *leaderTestSuite) TearDownSuite(c *C) { - s.cancel() - s.svr.Close() - testutil.CleanServer(s.cfg.DataDir) -} + mustWaitLeader(re, []*server.Server{svr}) -func (s *leaderTestSuite) TestGetLeader(c *C) { - mustWaitLeader(c, []*server.Server{s.svr}) + re.NotNil(svr.GetLeader()) - leader := s.svr.GetLeader() - c.Assert(leader, NotNil) + done <- true + wg.Wait() - s.done <- true - s.wg.Wait() + testutil.CleanServer(cfg.DataDir) } -func (s *leaderTestSuite) sendRequest(c *C, addr string) { - defer s.wg.Done() +func sendRequest(re *require.Assertions, wg *sync.WaitGroup, done <-chan bool, addr string) { + defer wg.Done() req := &pdpb.AllocIDRequest{Header: testutil.NewRequestHeader(0)} for { select { - case <-s.done: + case <-done: return default: // We don't need to check the response and error, // just make sure the server will not panic. 
- grpcPDClient := testutil.MustNewGrpcClient(c, addr) + grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, addr) if grpcPDClient != nil { _, _ = grpcPDClient.AllocID(context.Background(), req) } @@ -383,9 +357,9 @@ func (s *leaderTestSuite) sendRequest(c *C, addr string) { } } -func mustWaitLeader(c *C, svrs []*server.Server) *server.Server { +func mustWaitLeader(re *require.Assertions, svrs []*server.Server) *server.Server { var leader *server.Server - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { for _, s := range svrs { if !s.IsClosed() && s.GetMember().IsLeader() { leader = s diff --git a/tests/server/region_syncer/region_syncer_test.go b/tests/server/region_syncer/region_syncer_test.go index c4c91806c9f..3efe04e506d 100644 --- a/tests/server/region_syncer/region_syncer_test.go +++ b/tests/server/region_syncer/region_syncer_test.go @@ -19,9 +19,10 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" @@ -30,29 +31,10 @@ import ( "go.uber.org/goleak" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = SerialSuites(®ionSyncerTestSuite{}) - -type regionSyncerTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *regionSyncerTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *regionSyncerTestSuite) TearDownSuite(c *C) { - s.cancel() -} - type idAllocator struct { allocator *mockid.IDAllocator } @@ -62,51 +44,51 @@ func (i *idAllocator) alloc() uint64 { return v } -func (s *regionSyncerTestSuite) TestRegionSyncer(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/storage/regionStorageFastFlush", `return(true)`), IsNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/syncer/noFastExitSync", `return(true)`), IsNil) - defer failpoint.Disable("github.com/tikv/pd/server/storage/regionStorageFastFlush") - defer failpoint.Disable("github.com/tikv/pd/server/syncer/noFastExitSync") +func TestRegionSyncer(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + re.NoError(failpoint.Enable("github.com/tikv/pd/server/storage/regionStorageFastFlush", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/syncer/noFastExitSync", `return(true)`)) - cluster, err := tests.NewTestCluster(s.ctx, 3, func(conf *config.Config, serverName string) { conf.PDServerCfg.UseRegionStorage = true }) + cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, serverName string) { conf.PDServerCfg.UseRegionStorage = true }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) - err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) rc := leaderServer.GetServer().GetRaftCluster() - c.Assert(rc, NotNil) - c.Assert(cluster.WaitRegionSyncerClientsReady(2), IsTrue) + re.NotNil(rc) + re.True(cluster.WaitRegionSyncerClientsReady(2)) regionLen := 110 regions := initRegions(regionLen) for _, region := range regions { err = 
rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) } // merge case // region2 -> region1 -> region0 // merge A to B will increases version to max(versionA, versionB)+1, but does not increase conver regions[0] = regions[0].Clone(core.WithEndKey(regions[2].GetEndKey()), core.WithIncVersion(), core.WithIncVersion()) err = rc.HandleRegionHeartbeat(regions[2]) - c.Assert(err, IsNil) + re.NoError(err) // merge case // region3 -> region4 // merge A to B will increases version to max(versionA, versionB)+1, but does not increase conver regions[4] = regions[3].Clone(core.WithEndKey(regions[4].GetEndKey()), core.WithIncVersion()) err = rc.HandleRegionHeartbeat(regions[4]) - c.Assert(err, IsNil) + re.NoError(err) // merge case // region0 -> region4 // merge A to B will increases version to max(versionA, versionB)+1, but does not increase conver regions[4] = regions[0].Clone(core.WithEndKey(regions[4].GetEndKey()), core.WithIncVersion(), core.WithIncVersion()) err = rc.HandleRegionHeartbeat(regions[4]) - c.Assert(err, IsNil) + re.NoError(err) regions = regions[4:] regionLen = len(regions) @@ -119,14 +101,14 @@ func (s *regionSyncerTestSuite) TestRegionSyncer(c *C) { core.SetReadBytes(idx+30), core.SetReadKeys(idx+40)) err = rc.HandleRegionHeartbeat(regions[i]) - c.Assert(err, IsNil) + re.NoError(err) } // change the leader of region for i := 0; i < len(regions); i++ { regions[i] = regions[i].Clone(core.WithLeader(regions[i].GetPeers()[1])) err = rc.HandleRegionHeartbeat(regions[i]) - c.Assert(err, IsNil) + re.NoError(err) } // ensure flush to region storage, we use a duration larger than the @@ -135,15 +117,16 @@ func (s *regionSyncerTestSuite) TestRegionSyncer(c *C) { // test All regions have been synchronized to the cache of followerServer followerServer := cluster.GetServer(cluster.GetFollower()) - c.Assert(followerServer, NotNil) + re.NotNil(followerServer) cacheRegions := leaderServer.GetServer().GetBasicCluster().GetRegions() - c.Assert(cacheRegions, HasLen, regionLen) - testutil.WaitUntil(c, func() bool { + re.Len(cacheRegions, regionLen) + testutil.Eventually(re, func() bool { + assert := assert.New(t) for _, region := range cacheRegions { r := followerServer.GetServer().GetBasicCluster().GetRegion(region.GetID()) - if !(c.Check(r.GetMeta(), DeepEquals, region.GetMeta()) && - c.Check(r.GetStat(), DeepEquals, region.GetStat()) && - c.Check(r.GetLeader(), DeepEquals, region.GetLeader())) { + if !(assert.Equal(region.GetMeta(), r.GetMeta()) && + assert.Equal(region.GetStat(), r.GetStat()) && + assert.Equal(region.GetLeader(), r.GetLeader())) { return false } } @@ -151,106 +134,112 @@ func (s *regionSyncerTestSuite) TestRegionSyncer(c *C) { }) err = leaderServer.Stop() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer = cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer, NotNil) + re.NotNil(leaderServer) loadRegions := leaderServer.GetServer().GetRaftCluster().GetRegions() - c.Assert(loadRegions, HasLen, regionLen) + re.Len(loadRegions, regionLen) for _, region := range regions { r := leaderServer.GetRegionInfoByID(region.GetID()) - c.Assert(r.GetMeta(), DeepEquals, region.GetMeta()) - c.Assert(r.GetStat(), DeepEquals, region.GetStat()) - c.Assert(r.GetLeader(), DeepEquals, region.GetLeader()) - c.Assert(r.GetBuckets(), DeepEquals, region.GetBuckets()) + re.Equal(region.GetMeta(), r.GetMeta()) + re.Equal(region.GetStat(), r.GetStat()) + re.Equal(region.GetLeader(), r.GetLeader()) + re.Equal(region.GetBuckets(), r.GetBuckets()) } + 
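The testutil.Eventually(re, ...) calls introduced throughout these hunks replace the old testutil.WaitUntil(c, ...) helper and are built on testify's require.Eventually. The following standalone sketch (not part of the patch) shows what such a helper can look like, assuming the server-side pkg/testutil mirrors the client-side version added later in this series; the default durations are illustrative only.

package testutil

import (
    "time"

    "github.com/stretchr/testify/require"
)

const (
    // Illustrative defaults; the real package may use different values.
    defaultWaitFor       = time.Second * 20
    defaultSleepInterval = time.Millisecond * 100
)

// WaitOp holds the options used by Eventually.
type WaitOp struct {
    waitFor       time.Duration
    sleepInterval time.Duration
}

// WaitOption configures a WaitOp.
type WaitOption func(op *WaitOp)

// WithWaitFor sets the maximum time to wait for the condition.
func WithWaitFor(waitFor time.Duration) WaitOption {
    return func(op *WaitOp) { op.waitFor = waitFor }
}

// WithSleepInterval sets the polling interval.
func WithSleepInterval(sleep time.Duration) WaitOption {
    return func(op *WaitOp) { op.sleepInterval = sleep }
}

// Eventually asserts that the condition becomes true before the deadline,
// polling at the configured interval. Callers pass extra options as in
// Eventually(re, cond, WithWaitFor(90*time.Second), WithSleepInterval(time.Second)).
func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOption) {
    option := &WaitOp{waitFor: defaultWaitFor, sleepInterval: defaultSleepInterval}
    for _, opt := range opts {
        opt(option)
    }
    re.Eventually(condition, option.waitFor, option.sleepInterval)
}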
re.NoError(failpoint.Disable("github.com/tikv/pd/server/syncer/noFastExitSync")) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/storage/regionStorageFastFlush")) } -func (s *regionSyncerTestSuite) TestFullSyncWithAddMember(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { conf.PDServerCfg.UseRegionStorage = true }) +func TestFullSyncWithAddMember(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.PDServerCfg.UseRegionStorage = true }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) rc := leaderServer.GetServer().GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) regionLen := 110 regions := initRegions(regionLen) for _, region := range regions { err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) } // ensure flush to region storage time.Sleep(3 * time.Second) // restart pd1 err = leaderServer.Stop() - c.Assert(err, IsNil) + re.NoError(err) err = leaderServer.Run() - c.Assert(err, IsNil) - c.Assert(cluster.WaitLeader(), Equals, "pd1") + re.NoError(err) + re.Equal("pd1", cluster.WaitLeader()) // join new PD - pd2, err := cluster.Join(s.ctx) - c.Assert(err, IsNil) - err = pd2.Run() - c.Assert(err, IsNil) - c.Assert(cluster.WaitLeader(), Equals, "pd1") + pd2, err := cluster.Join(ctx) + re.NoError(err) + re.NoError(pd2.Run()) + re.Equal("pd1", cluster.WaitLeader()) // waiting for synchronization to complete time.Sleep(3 * time.Second) - err = cluster.ResignLeader() - c.Assert(err, IsNil) - c.Assert(cluster.WaitLeader(), Equals, "pd2") + re.NoError(cluster.ResignLeader()) + re.Equal("pd2", cluster.WaitLeader()) loadRegions := pd2.GetServer().GetRaftCluster().GetRegions() - c.Assert(loadRegions, HasLen, regionLen) + re.Len(loadRegions, regionLen) } -func (s *regionSyncerTestSuite) TestPrepareChecker(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/cluster/changeCoordinatorTicker", `return(true)`), IsNil) - defer failpoint.Disable("github.com/tikv/pd/server/cluster/changeCoordinatorTicker") - cluster, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { conf.PDServerCfg.UseRegionStorage = true }) +func TestPrepareChecker(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/changeCoordinatorTicker", `return(true)`)) + cluster, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.PDServerCfg.UseRegionStorage = true }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) rc := leaderServer.GetServer().GetRaftCluster() - c.Assert(rc, NotNil) + re.NotNil(rc) regionLen := 110 regions := initRegions(regionLen) for _, region := range regions { err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) } // ensure 
flush to region storage time.Sleep(3 * time.Second) - c.Assert(leaderServer.GetRaftCluster().IsPrepared(), IsTrue) + re.True(leaderServer.GetRaftCluster().IsPrepared()) // join new PD - pd2, err := cluster.Join(s.ctx) - c.Assert(err, IsNil) + pd2, err := cluster.Join(ctx) + re.NoError(err) err = pd2.Run() - c.Assert(err, IsNil) + re.NoError(err) // waiting for synchronization to complete time.Sleep(3 * time.Second) err = cluster.ResignLeader() - c.Assert(err, IsNil) - c.Assert(cluster.WaitLeader(), Equals, "pd2") + re.NoError(err) + re.Equal("pd2", cluster.WaitLeader()) leaderServer = cluster.GetServer(cluster.GetLeader()) rc = leaderServer.GetServer().GetRaftCluster() for _, region := range regions { err = rc.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) } time.Sleep(time.Second) - c.Assert(rc.IsPrepared(), IsTrue) + re.True(rc.IsPrepared()) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/changeCoordinatorTicker")) } func initRegions(regionLen int) []*core.RegionInfo { diff --git a/tests/server/server_test.go b/tests/server/server_test.go index 4369d03c966..85f1b2e4241 100644 --- a/tests/server/server_test.go +++ b/tests/server/server_test.go @@ -34,9 +34,9 @@ func TestMain(m *testing.M) { } func TestUpdateAdvertiseUrls(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - re := require.New(t) cluster, err := tests.NewTestCluster(ctx, 2) defer cluster.Destroy() re.NoError(err) @@ -60,10 +60,10 @@ func TestUpdateAdvertiseUrls(t *testing.T) { conf.AdvertisePeerURLs = conf.PeerURLs + "," + tempurl.Alloc() } for _, conf := range cluster.GetConfig().InitialServers { - serverConf, e := conf.Generate() - re.NoError(e) - s, e := tests.NewTestServer(ctx, serverConf) - re.NoError(e) + serverConf, err := conf.Generate() + re.NoError(err) + s, err := tests.NewTestServer(ctx, serverConf) + re.NoError(err) cluster.GetServers()[conf.Name] = s } err = cluster.RunInitialServers() @@ -75,9 +75,9 @@ func TestUpdateAdvertiseUrls(t *testing.T) { } func TestClusterID(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - re := require.New(t) cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() re.NoError(err) @@ -111,9 +111,9 @@ func TestClusterID(t *testing.T) { } func TestLeader(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - re := require.New(t) cluster, err := tests.NewTestCluster(ctx, 3) defer cluster.Destroy() re.NoError(err) @@ -122,12 +122,11 @@ func TestLeader(t *testing.T) { re.NoError(err) leader1 := cluster.WaitLeader() - re.NotEqual("", leader1) + re.NotEmpty(leader1) err = cluster.GetServer(leader1).Stop() re.NoError(err) testutil.Eventually(re, func() bool { - leader := cluster.GetLeader() - return leader != leader1 + return cluster.GetLeader() != leader1 }) } diff --git a/tests/server/watch/leader_watch_test.go b/tests/server/watch/leader_watch_test.go index 88d1470d733..4cdbee9d868 100644 --- a/tests/server/watch/leader_watch_test.go +++ b/tests/server/watch/leader_watch_test.go @@ -19,94 +19,79 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" "go.uber.org/goleak" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&watchTestSuite{}) - -type watchTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *watchTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *watchTestSuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *watchTestSuite) TestWatcher(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { conf.AutoCompactionRetention = "1s" }) +func TestWatcher(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.AutoCompactionRetention = "1s" }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pd1 := cluster.GetServer(cluster.GetLeader()) - c.Assert(pd1, NotNil) + re.NotNil(pd1) - pd2, err := cluster.Join(s.ctx) - c.Assert(err, IsNil) + pd2, err := cluster.Join(ctx) + re.NoError(err) err = pd2.Run() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() time.Sleep(5 * time.Second) - pd3, err := cluster.Join(s.ctx) - c.Assert(err, IsNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/delayWatcher", `pause`), IsNil) + pd3, err := cluster.Join(ctx) + re.NoError(err) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/delayWatcher", `pause`)) err = pd3.Run() - c.Assert(err, IsNil) + re.NoError(err) time.Sleep(200 * time.Millisecond) - c.Assert(pd3.GetLeader().GetName(), Equals, pd1.GetConfig().Name) + re.Equal(pd1.GetConfig().Name, pd3.GetLeader().GetName()) err = pd1.Stop() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() - c.Assert(pd2.GetLeader().GetName(), Equals, pd2.GetConfig().Name) - failpoint.Disable("github.com/tikv/pd/server/delayWatcher") - testutil.WaitUntil(c, func() bool { - return c.Check(pd3.GetLeader().GetName(), Equals, pd2.GetConfig().Name) + re.Equal(pd2.GetConfig().Name, pd2.GetLeader().GetName()) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/delayWatcher")) + testutil.Eventually(re, func() bool { + return pd3.GetLeader().GetName() == pd2.GetConfig().Name }) - c.Succeed() } -func (s *watchTestSuite) TestWatcherCompacted(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { conf.AutoCompactionRetention = "1s" }) +func TestWatcherCompacted(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.AutoCompactionRetention = "1s" }) defer cluster.Destroy() - c.Assert(err, IsNil) + re.NoError(err) err = cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() pd1 := cluster.GetServer(cluster.GetLeader()) - c.Assert(pd1, NotNil) + re.NotNil(pd1) client := pd1.GetEtcdClient() _, err = client.Put(context.Background(), "test", "v") - c.Assert(err, IsNil) + re.NoError(err) // wait compaction time.Sleep(2 * time.Second) - pd2, err := cluster.Join(s.ctx) - c.Assert(err, 
IsNil) + pd2, err := cluster.Join(ctx) + re.NoError(err) err = pd2.Run() - c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { - return c.Check(pd2.GetLeader().GetName(), Equals, pd1.GetConfig().Name) + re.NoError(err) + testutil.Eventually(re, func() bool { + return pd2.GetLeader().GetName() == pd1.GetConfig().Name }) - c.Succeed() } From 170b4984c6dc1405feccac7460f89b6c6d1d0da6 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 22 Jun 2022 17:30:37 +0800 Subject: [PATCH 69/82] *: some misc modify for testify (#5211) ref tikv/pd#4813, ref tikv/pd#5193 Signed-off-by: lhy1024 Co-authored-by: Ti Chi Robot --- pkg/errs/errs_test.go | 2 +- pkg/movingaverage/moving_average_test.go | 4 +-- .../checker/joint_state_checker_test.go | 4 +-- .../checker/priority_inspector_test.go | 4 +-- server/schedule/placement/rule_test.go | 2 +- tests/client/client_test.go | 28 +++++++++---------- tests/pdctl/config/config_test.go | 4 +-- tests/server/tso/manager_test.go | 2 +- tools/pd-ctl/pdctl/ctl_test.go | 2 +- 9 files changed, 26 insertions(+), 26 deletions(-) diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index 4556898d9fd..591d9f899ce 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -141,5 +141,5 @@ func TestErrorWithStack(t *testing.T) { re.GreaterOrEqual(idx1, -1) idx2 := strings.Index(m2, "[stack=") re.GreaterOrEqual(idx2, -1) - re.Equal(len(m1[idx1:]), len(m2[idx2:])) + re.Len(m2[idx2:], len(m1[idx1:])) } diff --git a/pkg/movingaverage/moving_average_test.go b/pkg/movingaverage/moving_average_test.go index 9f6864b007b..173eea773bb 100644 --- a/pkg/movingaverage/moving_average_test.go +++ b/pkg/movingaverage/moving_average_test.go @@ -40,7 +40,7 @@ func checkReset(re *require.Assertions, ma MovingAvg, emptyValue float64) { // checkAddGet checks Add works properly. func checkAdd(re *require.Assertions, ma MovingAvg, data []float64, expected []float64) { - re.Equal(len(expected), len(data)) + re.Len(data, len(expected)) for i, x := range data { ma.Add(x) re.LessOrEqual(math.Abs(ma.Get()-expected[i]), 1e-7) @@ -49,7 +49,7 @@ func checkAdd(re *require.Assertions, ma MovingAvg, data []float64, expected []f // checkSet checks Set = Reset + Add func checkSet(re *require.Assertions, ma MovingAvg, data []float64, expected []float64) { - re.Equal(len(expected), len(data)) + re.Len(data, len(expected)) // Reset + Add addRandData(ma, 100, 1000) diff --git a/server/schedule/checker/joint_state_checker_test.go b/server/schedule/checker/joint_state_checker_test.go index b350de469c4..e2986ae11ee 100644 --- a/server/schedule/checker/joint_state_checker_test.go +++ b/server/schedule/checker/joint_state_checker_test.go @@ -132,8 +132,8 @@ func checkSteps(re *require.Assertions, op *operator.Operator, steps []operator. 
switch obtain := op.Step(i).(type) { case operator.ChangePeerV2Leave: expect := steps[i].(operator.ChangePeerV2Leave) - re.Equal(len(expect.PromoteLearners), len(obtain.PromoteLearners)) - re.Equal(len(expect.DemoteVoters), len(obtain.DemoteVoters)) + re.Len(obtain.PromoteLearners, len(expect.PromoteLearners)) + re.Len(obtain.DemoteVoters, len(expect.DemoteVoters)) for j, p := range expect.PromoteLearners { re.Equal(p.ToStore, obtain.PromoteLearners[j].ToStore) } diff --git a/server/schedule/checker/priority_inspector_test.go b/server/schedule/checker/priority_inspector_test.go index 35662846c4a..9e0b0264a45 100644 --- a/server/schedule/checker/priority_inspector_test.go +++ b/server/schedule/checker/priority_inspector_test.go @@ -56,7 +56,7 @@ func checkPriorityRegionTest(re *require.Assertions, pc *PriorityInspector, tc * pc.Inspect(region) re.Equal(1, pc.queue.Len()) // the region will not rerun after it checks - re.Equal(0, len(pc.GetPriorityRegions())) + re.Len(pc.GetPriorityRegions(), 0) // case3: inspect region 3, it will has high priority region = tc.GetRegion(3) @@ -65,7 +65,7 @@ func checkPriorityRegionTest(re *require.Assertions, pc *PriorityInspector, tc * time.Sleep(opt.GetPatrolRegionInterval() * 10) // region 3 has higher priority ids := pc.GetPriorityRegions() - re.Equal(2, len(ids)) + re.Len(ids, 2) re.Equal(uint64(3), ids[0]) re.Equal(uint64(2), ids[1]) diff --git a/server/schedule/placement/rule_test.go b/server/schedule/placement/rule_test.go index ba2d1bf50e4..52d25c82e03 100644 --- a/server/schedule/placement/rule_test.go +++ b/server/schedule/placement/rule_test.go @@ -54,7 +54,7 @@ func TestPrepareRulesForApply(t *testing.T) { sortRules(rules) rules = prepareRulesForApply(rules) - re.Equal(len(expected), len(rules)) + re.Len(rules, len(expected)) for i := range rules { re.Equal(expected[i], rules[i].Key()) } diff --git a/tests/client/client_test.go b/tests/client/client_test.go index a9022d04a20..c6c4f45ee47 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -192,11 +192,11 @@ func TestUpdateAfterResetTSO(t *testing.T) { return err == nil }) // Transfer leader to trigger the TSO resetting. - re.Nil(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)")) oldLeaderName := cluster.WaitLeader() err = cluster.GetServer(oldLeaderName).ResignLeader() re.NoError(err) - re.Nil(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO")) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO")) newLeaderName := cluster.WaitLeader() re.NotEqual(oldLeaderName, newLeaderName) // Request a new TSO. @@ -205,7 +205,7 @@ func TestUpdateAfterResetTSO(t *testing.T) { return err == nil }) // Transfer leader back. - re.Nil(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`)) err = cluster.GetServer(newLeaderName).ResignLeader() re.NoError(err) // Should NOT panic here. 
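The hunks in this file and the ones above converge on two testify conventions: require.Len instead of Equal on len(...), and require.NoError instead of Nil for error-typed values. The short, hypothetical test below (not part of the patch) passes and illustrates both replacements.

package example

import (
    "testing"

    "github.com/stretchr/testify/require"
)

// TestAssertionConventions is a standalone sketch of the assertion style
// these hunks apply; the names and values are illustrative.
func TestAssertionConventions(t *testing.T) {
    re := require.New(t)

    members := []string{"pd1", "pd2"}
    // Len reports the slice contents on failure, unlike Equal(len(a), len(b)).
    re.Len(members, 2)

    var err error
    // NoError states the intent directly and formats the underlying error on failure.
    re.NoError(err)
}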
@@ -213,7 +213,7 @@ func TestUpdateAfterResetTSO(t *testing.T) { _, _, err := cli.GetTS(context.TODO()) return err == nil }) - re.Nil(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp")) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp")) } func TestTSOAllocatorLeader(t *testing.T) { @@ -360,7 +360,7 @@ func TestGlobalAndLocalTSO(t *testing.T) { requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) // assert global tso after resign leader - re.Nil(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`)) err = cluster.ResignLeader() re.NoError(err) cluster.WaitLeader() @@ -369,7 +369,7 @@ func TestGlobalAndLocalTSO(t *testing.T) { re.True(pd.IsLeaderChange(err)) _, _, err = cli.GetTS(ctx) re.NoError(err) - re.Nil(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember")) // Test the TSO follower proxy while enabling the Local TSO. cli.UpdateOption(pd.EnableTSOFollowerProxy, true) @@ -427,9 +427,9 @@ func TestCustomTimeout(t *testing.T) { cli := setupCli(re, ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) start := time.Now() - re.Nil(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)")) _, err = cli.GetAllStores(context.TODO()) - re.Nil(failpoint.Disable("github.com/tikv/pd/server/customTimeout")) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/customTimeout")) re.Error(err) re.GreaterOrEqual(time.Since(start), 1*time.Second) re.Less(time.Since(start), 2*time.Second) @@ -447,13 +447,13 @@ func TestGetRegionFromFollowerClient(t *testing.T) { endpoints := runServer(re, cluster) cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) time.Sleep(200 * time.Millisecond) r, err := cli.GetRegion(context.Background(), []byte("a")) re.NoError(err) re.NotNil(r) - re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) time.Sleep(200 * time.Millisecond) r, err = cli.GetRegion(context.Background(), []byte("a")) re.NoError(err) @@ -473,7 +473,7 @@ func TestGetTsoFromFollowerClient1(t *testing.T) { endpoints := runServer(re, cluster) cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 testutil.Eventually(re, func() bool { physical, logical, err := cli.GetTS(context.TODO()) @@ -486,7 +486,7 @@ func TestGetTsoFromFollowerClient1(t *testing.T) { }) lastTS = checkTS(re, cli, lastTS) - re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(2 * time.Second) checkTS(re, cli, lastTS) } @@ -504,7 +504,7 @@ func TestGetTsoFromFollowerClient2(t *testing.T) { endpoints := runServer(re, cluster) cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - 
re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 testutil.Eventually(re, func() bool { physical, logical, err := cli.GetTS(context.TODO()) @@ -521,7 +521,7 @@ func TestGetTsoFromFollowerClient2(t *testing.T) { cluster.WaitLeader() lastTS = checkTS(re, cli, lastTS) - re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(5 * time.Second) checkTS(re, cli, lastTS) } diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index f5acd3fd3ff..a0de13e30d6 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -761,7 +761,7 @@ func TestPDServerConfig(t *testing.T) { } func assertBundles(re *require.Assertions, a, b []placement.GroupBundle) { - re.Equal(len(a), len(b)) + re.Len(b, len(a)) for i := 0; i < len(a); i++ { assertBundle(re, a[i], b[i]) } @@ -771,7 +771,7 @@ func assertBundle(re *require.Assertions, a, b placement.GroupBundle) { re.Equal(a.ID, b.ID) re.Equal(a.Index, b.Index) re.Equal(a.Override, b.Override) - re.Equal(len(a.Rules), len(b.Rules)) + re.Len(b.Rules, len(a.Rules)) for i := 0; i < len(a.Rules); i++ { assertRule(re, a.Rules[i], b.Rules[i]) } diff --git a/tests/server/tso/manager_test.go b/tests/server/tso/manager_test.go index 5ea8bc4be92..00278544f55 100644 --- a/tests/server/tso/manager_test.go +++ b/tests/server/tso/manager_test.go @@ -130,7 +130,7 @@ func TestLocalTSOSuffix(t *testing.T) { clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByValue, clientv3.SortAscend)) re.NoError(err) - re.Equal(len(testCase.dcLocations), len(allSuffixResp.Kvs)) + re.Len(allSuffixResp.Kvs, len(testCase.dcLocations)) var lastSuffixNum int64 for _, kv := range allSuffixResp.Kvs { suffixNum, err := strconv.ParseInt(string(kv.Value), 10, 64) diff --git a/tools/pd-ctl/pdctl/ctl_test.go b/tools/pd-ctl/pdctl/ctl_test.go index 6dc29058e34..d9cea460e21 100644 --- a/tools/pd-ctl/pdctl/ctl_test.go +++ b/tools/pd-ctl/pdctl/ctl_test.go @@ -86,7 +86,7 @@ func TestReadStdin(t *testing.T) { for _, v := range s { in, err := ReadStdin(v.in) re.NoError(err) - re.Equal(len(v.targets), len(in)) + re.Len(in, len(v.targets)) for i, target := range v.targets { re.Equal(target, in[i]) } From 18fbc4a62b04cf48129927bfce45322bf7c825e6 Mon Sep 17 00:00:00 2001 From: LLThomas Date: Wed, 22 Jun 2022 18:56:37 +0800 Subject: [PATCH 70/82] pkg/testutil: add a WithTestify func for CheckTransferLeader (#5213) ref tikv/pd#4813 Add a WithTestify func for CheckTransferLeader and CheckTransferLeaderWithTestify will be used in server/cluster/coordinator_test.go. Signed-off-by: LLThomas --- pkg/testutil/operator_check.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pkg/testutil/operator_check.go b/pkg/testutil/operator_check.go index 1df641e7e0a..e8dca8dff47 100644 --- a/pkg/testutil/operator_check.go +++ b/pkg/testutil/operator_check.go @@ -165,6 +165,15 @@ func CheckRemovePeerWithTestify(re *require.Assertions, op *operator.Operator, s } } +// CheckTransferLeaderWithTestify checks if the operator is to transfer leader between the specified source and target stores. 
+func CheckTransferLeaderWithTestify(re *require.Assertions, op *operator.Operator, kind operator.OpKind, sourceID, targetID uint64) { + re.NotNil(op) + re.Equal(1, op.Len()) + re.Equal(operator.TransferLeader{FromStore: sourceID, ToStore: targetID}, op.Step(0)) + kind |= operator.OpLeader + re.Equal(kind, op.Kind()&kind) +} + // CheckTransferPeerWithTestify checks if the operator is to transfer peer between the specified source and target stores. func CheckTransferPeerWithTestify(re *require.Assertions, op *operator.Operator, kind operator.OpKind, sourceID, targetID uint64) { re.NotNil(op) From 3b3ff6973da682b04970df60c3fd3984aa14a761 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 22 Jun 2022 19:52:38 +0800 Subject: [PATCH 71/82] pkg, api, tests: migrate the server/api tests to testify (#5204) close tikv/pd#5199 Migrate the server/api tests to testify. Signed-off-by: JmPotato Co-authored-by: Ti Chi Robot --- client/option_test.go | 2 +- client/testutil/testutil.go | 42 +- pkg/testutil/api_check.go | 36 +- pkg/testutil/testutil.go | 11 +- pkg/typeutil/duration_test.go | 2 +- server/api/admin_test.go | 129 ++-- server/api/checker_test.go | 136 ++-- server/api/cluster_test.go | 88 +-- server/api/config_test.go | 365 +++++------ server/api/etcd_api_test.go | 20 +- server/api/health_test.go | 32 +- server/api/hot_status_test.go | 85 +-- server/api/label_test.go | 119 ++-- server/api/log_test.go | 39 +- server/api/member_test.go | 124 ++-- server/api/min_resolved_ts_test.go | 57 +- server/api/operator_test.go | 263 ++++---- server/api/pprof_test.go | 43 +- server/api/region.go | 10 +- server/api/region_label_test.go | 74 ++- server/api/region_test.go | 588 +++++++++--------- server/api/rule_test.go | 497 +++++++-------- server/api/scheduler_test.go | 444 +++++++------ server/api/server_test.go | 137 ++-- server/api/service_gc_safepoint_test.go | 53 +- server/api/service_middleware_test.go | 270 ++++---- server/api/stats_test.go | 54 +- server/api/status_test.go | 28 +- server/api/store_test.go | 394 ++++++------ server/api/trend_test.go | 71 +-- server/api/tso_test.go | 42 +- server/api/unsafe_operation_test.go | 60 +- server/api/version_test.go | 35 +- .../placement/region_rule_cache_test.go | 2 +- tests/client/client_test.go | 2 +- tests/pdctl/helper.go | 6 +- tests/server/api/api_test.go | 4 +- tests/server/cluster/cluster_test.go | 32 +- tests/server/cluster/cluster_work_test.go | 6 +- tests/server/config/config_test.go | 64 +- tests/server/id/id_test.go | 2 +- tests/server/member/member_test.go | 2 +- tests/server/tso/consistency_test.go | 20 +- tests/server/tso/global_tso_test.go | 6 +- tests/server/tso/tso_test.go | 2 +- 45 files changed, 2300 insertions(+), 2198 deletions(-) diff --git a/client/option_test.go b/client/option_test.go index 2a7f7824e12..1b5604f4d19 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -45,7 +45,7 @@ func TestDynamicOptionChange(t *testing.T) { expectBool := true o.setEnableTSOFollowerProxy(expectBool) // Check the value changing notification. 
- testutil.WaitUntil(t, func() bool { + testutil.Eventually(re, func() bool { <-o.enableTSOFollowerProxyCh return true }) diff --git a/client/testutil/testutil.go b/client/testutil/testutil.go index 095a31ae74a..79a3c9eb913 100644 --- a/client/testutil/testutil.go +++ b/client/testutil/testutil.go @@ -15,49 +15,47 @@ package testutil import ( - "testing" "time" + + "github.com/stretchr/testify/require" ) const ( - waitMaxRetry = 200 - waitRetrySleep = time.Millisecond * 100 + defaultWaitFor = time.Second * 20 + defaultSleepInterval = time.Millisecond * 100 ) -// WaitOp represents available options when execute WaitUntil +// WaitOp represents available options when execute Eventually. type WaitOp struct { - retryTimes int + waitFor time.Duration sleepInterval time.Duration } // WaitOption configures WaitOp type WaitOption func(op *WaitOp) -// WithRetryTimes specify the retry times -func WithRetryTimes(retryTimes int) WaitOption { - return func(op *WaitOp) { op.retryTimes = retryTimes } -} - // WithSleepInterval specify the sleep duration func WithSleepInterval(sleep time.Duration) WaitOption { return func(op *WaitOp) { op.sleepInterval = sleep } } -// WaitUntil repeatedly evaluates f() for a period of time, util it returns true. -func WaitUntil(t *testing.T, f func() bool, opts ...WaitOption) { - t.Log("wait start") +// WithWaitFor specify the max wait for duration +func WithWaitFor(waitFor time.Duration) WaitOption { + return func(op *WaitOp) { op.waitFor = waitFor } +} + +// Eventually asserts that given condition will be met in a period of time. +func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOption) { option := &WaitOp{ - retryTimes: waitMaxRetry, - sleepInterval: waitRetrySleep, + waitFor: defaultWaitFor, + sleepInterval: defaultSleepInterval, } for _, opt := range opts { opt(option) } - for i := 0; i < option.retryTimes; i++ { - if f() { - return - } - time.Sleep(option.sleepInterval) - } - t.Fatal("wait timeout") + re.Eventually( + condition, + option.waitFor, + option.sleepInterval, + ) } diff --git a/pkg/testutil/api_check.go b/pkg/testutil/api_check.go index cbe0e01d166..015b7168a8c 100644 --- a/pkg/testutil/api_check.go +++ b/pkg/testutil/api_check.go @@ -18,69 +18,67 @@ import ( "encoding/json" "io" "net/http" - "strings" - "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/apiutil" ) // Status is used to check whether http response code is equal given code -func Status(c *check.C, code int) func([]byte, int) { +func Status(re *require.Assertions, code int) func([]byte, int) { return func(_ []byte, i int) { - c.Assert(i, check.Equals, code) + re.Equal(code, i) } } // StatusOK is used to check whether http response code is equal http.StatusOK -func StatusOK(c *check.C) func([]byte, int) { - return Status(c, http.StatusOK) +func StatusOK(re *require.Assertions) func([]byte, int) { + return Status(re, http.StatusOK) } // StatusNotOK is used to check whether http response code is not equal http.StatusOK -func StatusNotOK(c *check.C) func([]byte, int) { +func StatusNotOK(re *require.Assertions) func([]byte, int) { return func(_ []byte, i int) { - c.Assert(i == http.StatusOK, check.IsFalse) + re.NotEqual(http.StatusOK, i) } } // ExtractJSON is used to check whether given data can be extracted successfully -func ExtractJSON(c *check.C, data interface{}) func([]byte, int) { +func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) { return func(res []byte, _ int) { - err := json.Unmarshal(res, data) 
- c.Assert(err, check.IsNil) + re.NoError(json.Unmarshal(res, data)) } } // StringContain is used to check whether response context contains given string -func StringContain(c *check.C, sub string) func([]byte, int) { +func StringContain(re *require.Assertions, sub string) func([]byte, int) { return func(res []byte, _ int) { - c.Assert(strings.Contains(string(res), sub), check.IsTrue) + re.Contains(string(res), sub) } } // StringEqual is used to check whether response context equal given string -func StringEqual(c *check.C, str string) func([]byte, int) { +func StringEqual(re *require.Assertions, str string) func([]byte, int) { return func(res []byte, _ int) { - c.Assert(strings.Contains(string(res), str), check.IsTrue) + re.Contains(string(res), str) } } // ReadGetJSON is used to do get request and check whether given data can be extracted successfully -func ReadGetJSON(c *check.C, client *http.Client, url string, data interface{}) error { +func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error { resp, err := apiutil.GetJSON(client, url, nil) if err != nil { return err } - return checkResp(resp, StatusOK(c), ExtractJSON(c, data)) + return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) } // ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully -func ReadGetJSONWithBody(c *check.C, client *http.Client, url string, input []byte, data interface{}) error { +func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error { resp, err := apiutil.GetJSON(client, url, input) if err != nil { return err } - return checkResp(resp, StatusOK(c), ExtractJSON(c, data)) + return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) } // CheckPostJSON is used to do post request and do check options diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index 59063aa5385..236438ecfd1 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -104,16 +104,7 @@ func NewRequestHeader(clusterID uint64) *pdpb.RequestHeader { } // MustNewGrpcClient must create a new grpc client. -func MustNewGrpcClient(c *check.C, addr string) pdpb.PDClient { - conn, err := grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) - - c.Assert(err, check.IsNil) - return pdpb.NewPDClient(conn) -} - -// MustNewGrpcClientWithTestify must create a new grpc client. -// NOTICE: this is a temporary function that we will be used to replace `MustNewGrpcClient` later. -func MustNewGrpcClientWithTestify(re *require.Assertions, addr string) pdpb.PDClient { +func MustNewGrpcClient(re *require.Assertions, addr string) pdpb.PDClient { conn, err := grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) re.NoError(err) diff --git a/pkg/typeutil/duration_test.go b/pkg/typeutil/duration_test.go index f815b29ab6b..9a0beda7979 100644 --- a/pkg/typeutil/duration_test.go +++ b/pkg/typeutil/duration_test.go @@ -32,7 +32,7 @@ func TestDurationJSON(t *testing.T) { example := &example{} text := []byte(`{"interval":"1h1m1s"}`) - re.Nil(json.Unmarshal(text, example)) + re.NoError(json.Unmarshal(text, example)) re.Equal(float64(60*60+60+1), example.Interval.Seconds()) b, err := json.Marshal(example) diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 1ece28a5239..ba9aaa875a4 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -18,39 +18,44 @@ import ( "encoding/json" "fmt" "net/http" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testAdminSuite{}) - -type testAdminSuite struct { +type adminTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testAdminSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestAdminTestSuite(t *testing.T) { + suite.Run(t, new(adminTestSuite)) +} + +func (suite *adminTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testAdminSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *adminTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testAdminSuite) TestDropRegion(c *C) { - cluster := s.svr.GetRaftCluster() +func (suite *adminTestSuite) TestDropRegion() { + cluster := suite.svr.GetRaftCluster() // Update region's epoch to (100, 100). region := cluster.GetRegionByKey([]byte("foo")).Clone( @@ -63,7 +68,7 @@ func (s *testAdminSuite) TestDropRegion(c *C) { }, })) err := cluster.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + suite.NoError(err) // Region epoch cannot decrease. region = region.Clone( @@ -71,39 +76,32 @@ func (s *testAdminSuite) TestDropRegion(c *C) { core.SetRegionVersion(50), ) err = cluster.HandleRegionHeartbeat(region) - c.Assert(err, NotNil) + suite.Error(err) // After drop region from cache, lower version is accepted. 
- url := fmt.Sprintf("%s/admin/cache/region/%d", s.urlPrefix, region.GetID()) + url := fmt.Sprintf("%s/admin/cache/region/%d", suite.urlPrefix, region.GetID()) req, err := http.NewRequest(http.MethodDelete, url, nil) - c.Assert(err, IsNil) + suite.NoError(err) res, err := testDialClient.Do(req) - c.Assert(err, IsNil) - c.Assert(res.StatusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, res.StatusCode) res.Body.Close() err = cluster.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + suite.NoError(err) region = cluster.GetRegionByKey([]byte("foo")) - c.Assert(region.GetRegionEpoch().ConfVer, Equals, uint64(50)) - c.Assert(region.GetRegionEpoch().Version, Equals, uint64(50)) + suite.Equal(uint64(50), region.GetRegionEpoch().ConfVer) + suite.Equal(uint64(50), region.GetRegionEpoch().Version) } -func (s *testAdminSuite) TestPersistFile(c *C) { +func (suite *adminTestSuite) TestPersistFile() { data := []byte("#!/bin/sh\nrm -rf /") - err := tu.CheckPostJSON(testDialClient, s.urlPrefix+"/admin/persist-file/fun.sh", data, tu.StatusNotOK(c)) - c.Assert(err, IsNil) + re := suite.Require() + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/fun.sh", data, tu.StatusNotOK(re)) + suite.NoError(err) data = []byte(`{"foo":"bar"}`) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/admin/persist-file/good.json", data, tu.StatusOK(c)) - c.Assert(err, IsNil) -} - -var _ = Suite(&testTSOSuite{}) - -type testTSOSuite struct { - svr *server.Server - cleanup cleanUpFunc - urlPrefix string + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/good.json", data, tu.StatusOK(re)) + suite.NoError(err) } func makeTS(offset time.Duration) uint64 { @@ -111,62 +109,49 @@ func makeTS(offset time.Duration) uint64 { return uint64(physical << 18) } -func (s *testTSOSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/admin/reset-ts", addr, apiPrefix) - - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) -} - -func (s *testTSOSuite) TearDownSuite(c *C) { - s.cleanup() -} - -func (s *testTSOSuite) TestResetTS(c *C) { +func (suite *adminTestSuite) TestResetTS() { args := make(map[string]interface{}) t1 := makeTS(time.Hour) - url := s.urlPrefix + url := fmt.Sprintf("%s/admin/reset-ts", suite.urlPrefix) args["tso"] = fmt.Sprintf("%d", t1) values, err := json.Marshal(args) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() err = tu.CheckPostJSON(testDialClient, url, values, - tu.StatusOK(c), - tu.StringEqual(c, "\"Reset ts successfully.\"\n")) - c.Assert(err, IsNil) + tu.StatusOK(re), + tu.StringEqual(re, "\"Reset ts successfully.\"\n")) + suite.NoError(err) t2 := makeTS(32 * time.Hour) args["tso"] = fmt.Sprintf("%d", t2) values, err = json.Marshal(args) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, - tu.Status(c, http.StatusForbidden), - tu.StringContain(c, "too large")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusForbidden), + tu.StringContain(re, "too large")) + suite.NoError(err) t3 := makeTS(-2 * time.Hour) args["tso"] = fmt.Sprintf("%d", t3) values, err = json.Marshal(args) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, - tu.Status(c, http.StatusForbidden), - tu.StringContain(c, "small")) - c.Assert(err, IsNil) + tu.Status(re, 
http.StatusForbidden), + tu.StringContain(re, "small")) + suite.NoError(err) args["tso"] = "" values, err = json.Marshal(args) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, - tu.Status(c, http.StatusBadRequest), - tu.StringEqual(c, "\"invalid tso value\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), + tu.StringEqual(re, "\"invalid tso value\"\n")) + suite.NoError(err) args["tso"] = "test" values, err = json.Marshal(args) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, - tu.Status(c, http.StatusBadRequest), - tu.StringEqual(c, "\"invalid tso value\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), + tu.StringEqual(re, "\"invalid tso value\"\n")) + suite.NoError(err) } diff --git a/server/api/checker_test.go b/server/api/checker_test.go index 99fbb9ee68a..a3ab815ffb7 100644 --- a/server/api/checker_test.go +++ b/server/api/checker_test.go @@ -17,40 +17,45 @@ package api import ( "encoding/json" "fmt" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" ) -var _ = Suite(&testCheckerSuite{}) - -type testCheckerSuite struct { +type checkerTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testCheckerSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestCheckerTestSuite(t *testing.T) { + suite.Run(t, new(checkerTestSuite)) +} + +func (suite *checkerTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/checker", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/checker", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) - mustPutStore(c, s.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustBootstrapCluster(re, suite.svr) + mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustPutStore(re, suite.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, nil) } -func (s *testCheckerSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *checkerTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testCheckerSuite) TestAPI(c *C) { - s.testErrCases(c) +func (suite *checkerTestSuite) TestAPI() { + suite.testErrCases() cases := []struct { name string @@ -63,101 +68,104 @@ func (s *testCheckerSuite) TestAPI(c *C) { {name: "joint-state"}, } for _, ca := range cases { - s.testGetStatus(ca.name, c) - s.testPauseOrResume(ca.name, c) + suite.testGetStatus(ca.name) + suite.testPauseOrResume(ca.name) } } -func (s *testCheckerSuite) testErrCases(c *C) { +func (suite *checkerTestSuite) testErrCases() { // missing args input := make(map[string]interface{}) pauseArgs, err := json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) + suite.NoError(err) // negative delay input["delay"] = -10 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = 
tu.CheckPostJSON(testDialClient, s.urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) + suite.NoError(err) // wrong name name := "dummy" input["delay"] = 30 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) + suite.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) + suite.NoError(err) } -func (s *testCheckerSuite) testGetStatus(name string, c *C) { - handler := s.svr.GetHandler() +func (suite *checkerTestSuite) testGetStatus(name string) { + handler := suite.svr.GetHandler() // normal run resp := make(map[string]interface{}) - err := tu.ReadGetJSON(c, testDialClient, fmt.Sprintf("%s/%s", s.urlPrefix, name), &resp) - c.Assert(err, IsNil) - c.Assert(resp["paused"], IsFalse) + re := suite.Require() + err := tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", suite.urlPrefix, name), &resp) + suite.NoError(err) + suite.False(resp["paused"].(bool)) // paused err = handler.PauseOrResumeChecker(name, 30) - c.Assert(err, IsNil) + suite.NoError(err) resp = make(map[string]interface{}) - err = tu.ReadGetJSON(c, testDialClient, fmt.Sprintf("%s/%s", s.urlPrefix, name), &resp) - c.Assert(err, IsNil) - c.Assert(resp["paused"], IsTrue) + err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", suite.urlPrefix, name), &resp) + suite.NoError(err) + suite.True(resp["paused"].(bool)) // resumed err = handler.PauseOrResumeChecker(name, 1) - c.Assert(err, IsNil) + suite.NoError(err) time.Sleep(time.Second) resp = make(map[string]interface{}) - err = tu.ReadGetJSON(c, testDialClient, fmt.Sprintf("%s/%s", s.urlPrefix, name), &resp) - c.Assert(err, IsNil) - c.Assert(resp["paused"], IsFalse) + err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", suite.urlPrefix, name), &resp) + suite.NoError(err) + suite.False(resp["paused"].(bool)) } -func (s *testCheckerSuite) testPauseOrResume(name string, c *C) { - handler := s.svr.GetHandler() +func (suite *checkerTestSuite) testPauseOrResume(name string) { + handler := suite.svr.GetHandler() input := make(map[string]interface{}) // test pause. 
input["delay"] = 30 pauseArgs, err := json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) isPaused, err := handler.IsCheckerPaused(name) - c.Assert(err, IsNil) - c.Assert(isPaused, IsTrue) + suite.NoError(err) + suite.True(isPaused) input["delay"] = 1 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) time.Sleep(time.Second) isPaused, err = handler.IsCheckerPaused(name) - c.Assert(err, IsNil) - c.Assert(isPaused, IsFalse) + suite.NoError(err) + suite.False(isPaused) // test resume. input = make(map[string]interface{}) input["delay"] = 30 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) isPaused, err = handler.IsCheckerPaused(name) - c.Assert(err, IsNil) - c.Assert(isPaused, IsFalse) + suite.NoError(err) + suite.False(isPaused) } diff --git a/server/api/cluster_test.go b/server/api/cluster_test.go index d1ece1041a1..496d75e6f38 100644 --- a/server/api/cluster_test.go +++ b/server/api/cluster_test.go @@ -16,83 +16,91 @@ package api import ( "fmt" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/cluster" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testClusterSuite{}) - -type testClusterSuite struct { +type clusterTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testClusterSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestClusterTestSuite(t *testing.T) { + suite.Run(t, new(clusterTestSuite)) +} + +func (suite *clusterTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) } -func (s *testClusterSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *clusterTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testClusterSuite) TestCluster(c *C) { +func (suite *clusterTestSuite) TestCluster() { // Test get cluster status, and bootstrap cluster - s.testGetClusterStatus(c) - s.svr.GetPersistOptions().SetPlacementRuleEnabled(true) - s.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} - rm := s.svr.GetRaftCluster().GetRuleManager() + suite.testGetClusterStatus() + suite.svr.GetPersistOptions().SetPlacementRuleEnabled(true) + suite.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} + rm := suite.svr.GetRaftCluster().GetRuleManager() rule := rm.GetRule("pd", "default") rule.LocationLabels = []string{"host"} rule.Count = 1 rm.SetRule(rule) // Test set the config - url := fmt.Sprintf("%s/cluster", s.urlPrefix) + url := fmt.Sprintf("%s/cluster", suite.urlPrefix) c1 := &metapb.Cluster{} - err := tu.ReadGetJSON(c, testDialClient, url, c1) - c.Assert(err, IsNil) + re := suite.Require() + err := tu.ReadGetJSON(re, testDialClient, url, c1) + suite.NoError(err) c2 := &metapb.Cluster{} r := config.ReplicationConfig{ MaxReplicas: 6, EnablePlacementRules: true, } - c.Assert(s.svr.SetReplicationConfig(r), IsNil) - err = tu.ReadGetJSON(c, testDialClient, url, c2) - c.Assert(err, IsNil) + suite.NoError(suite.svr.SetReplicationConfig(r)) + + err = tu.ReadGetJSON(re, testDialClient, url, c2) + suite.NoError(err) c1.MaxPeerCount = 6 - c.Assert(c1, DeepEquals, c2) - c.Assert(int(r.MaxReplicas), Equals, s.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default").Count) + suite.Equal(c2, c1) + suite.Equal(int(r.MaxReplicas), suite.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default").Count) } -func (s *testClusterSuite) testGetClusterStatus(c *C) { - url := fmt.Sprintf("%s/cluster/status", s.urlPrefix) +func (suite *clusterTestSuite) testGetClusterStatus() { + url := fmt.Sprintf("%s/cluster/status", suite.urlPrefix) status := cluster.Status{} - err := tu.ReadGetJSON(c, testDialClient, url, &status) - c.Assert(err, IsNil) - c.Assert(status.RaftBootstrapTime.IsZero(), IsTrue) - c.Assert(status.IsInitialized, IsFalse) + re := suite.Require() + err := tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.True(status.RaftBootstrapTime.IsZero()) + suite.False(status.IsInitialized) now := time.Now() - mustBootstrapCluster(c, s.svr) - err = tu.ReadGetJSON(c, testDialClient, url, &status) - c.Assert(err, IsNil) - 
c.Assert(status.RaftBootstrapTime.After(now), IsTrue) - c.Assert(status.IsInitialized, IsFalse) - s.svr.SetReplicationConfig(config.ReplicationConfig{MaxReplicas: 1}) - err = tu.ReadGetJSON(c, testDialClient, url, &status) - c.Assert(err, IsNil) - c.Assert(status.RaftBootstrapTime.After(now), IsTrue) - c.Assert(status.IsInitialized, IsTrue) + mustBootstrapCluster(re, suite.svr) + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.True(status.RaftBootstrapTime.After(now)) + suite.False(status.IsInitialized) + suite.svr.SetReplicationConfig(config.ReplicationConfig{MaxReplicas: 1}) + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.True(status.RaftBootstrapTime.After(now)) + suite.True(status.IsInitialized) } diff --git a/server/api/config_test.go b/server/api/config_test.go index 7abfafd04a6..144e511979a 100644 --- a/server/api/config_test.go +++ b/server/api/config_test.go @@ -17,9 +17,10 @@ package api import ( "encoding/json" "fmt" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server" @@ -27,65 +28,70 @@ import ( "github.com/tikv/pd/server/versioninfo" ) -var _ = Suite(&testConfigSuite{}) - -type testConfigSuite struct { +type configTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testConfigSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { +func TestConfigTestSuite(t *testing.T) { + suite.Run(t, new(configTestSuite)) +} + +func (suite *configTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(c, []*server.Server{s.svr}) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) } -func (s *testConfigSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *configTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testConfigSuite) TestConfigAll(c *C) { - addr := fmt.Sprintf("%s/config", s.urlPrefix) +func (suite *configTestSuite) TestConfigAll() { + re := suite.Require() + addr := fmt.Sprintf("%s/config", suite.urlPrefix) cfg := &config.Config{} - err := tu.ReadGetJSON(c, testDialClient, addr, cfg) - c.Assert(err, IsNil) + err := tu.ReadGetJSON(re, testDialClient, addr, cfg) + suite.NoError(err) // the original way r := map[string]int{"max-replicas": 5} postData, err := json.Marshal(r) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) l := map[string]interface{}{ "location-labels": "zone,rack", "region-schedule-limit": 10, } postData, err = json.Marshal(l) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) l = map[string]interface{}{ "metric-storage": "http://127.0.0.1:9090", } postData, err = json.Marshal(l) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + 
suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) newCfg := &config.Config{} - err = tu.ReadGetJSON(c, testDialClient, addr, newCfg) - c.Assert(err, IsNil) + err = tu.ReadGetJSON(re, testDialClient, addr, newCfg) + suite.NoError(err) cfg.Replication.MaxReplicas = 5 cfg.Replication.LocationLabels = []string{"zone", "rack"} cfg.Schedule.RegionScheduleLimit = 10 cfg.PDServerCfg.MetricStorage = "http://127.0.0.1:9090" - c.Assert(cfg, DeepEquals, newCfg) + suite.Equal(newCfg, cfg) // the new way l = map[string]interface{}{ @@ -98,12 +104,12 @@ func (s *testConfigSuite) TestConfigAll(c *C) { "replication-mode.dr-auto-sync.label-key": "foobar", } postData, err = json.Marshal(l) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) newCfg1 := &config.Config{} - err = tu.ReadGetJSON(c, testDialClient, addr, newCfg1) - c.Assert(err, IsNil) + err = tu.ReadGetJSON(re, testDialClient, addr, newCfg1) + suite.NoError(err) cfg.Schedule.TolerantSizeRatio = 2.5 cfg.Replication.LocationLabels = []string{"idc", "host"} cfg.PDServerCfg.MetricStorage = "http://127.0.0.1:1234" @@ -111,109 +117,110 @@ func (s *testConfigSuite) TestConfigAll(c *C) { cfg.ReplicationMode.DRAutoSync.LabelKey = "foobar" cfg.ReplicationMode.ReplicationMode = "dr-auto-sync" v, err := versioninfo.ParseVersion("v4.0.0-beta") - c.Assert(err, IsNil) + suite.NoError(err) cfg.ClusterVersion = *v - c.Assert(newCfg1, DeepEquals, cfg) + suite.Equal(cfg, newCfg1) postData, err = json.Marshal(l) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) // illegal prefix l = map[string]interface{}{ "replicate.max-replicas": 1, } postData, err = json.Marshal(l) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, addr, postData, - tu.StatusNotOK(c), - tu.StringContain(c, "not found")) - c.Assert(err, IsNil) + tu.StatusNotOK(re), + tu.StringContain(re, "not found")) + suite.NoError(err) // update prefix directly l = map[string]interface{}{ "replication-mode": nil, } postData, err = json.Marshal(l) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, addr, postData, - tu.StatusNotOK(c), - tu.StringContain(c, "cannot update config prefix")) - c.Assert(err, IsNil) + tu.StatusNotOK(re), + tu.StringContain(re, "cannot update config prefix")) + suite.NoError(err) // config item not found l = map[string]interface{}{ "schedule.region-limit": 10, } postData, err = json.Marshal(l) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(c), tu.StringContain(c, "not found")) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) + suite.NoError(err) } -func (s *testConfigSuite) TestConfigSchedule(c *C) { - addr := fmt.Sprintf("%s/config/schedule", s.urlPrefix) +func (suite *configTestSuite) TestConfigSchedule() { + re := suite.Require() + addr := fmt.Sprintf("%s/config/schedule", suite.urlPrefix) sc := &config.ScheduleConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) 
sc.MaxStoreDownTime.Duration = time.Second postData, err := json.Marshal(sc) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) sc1 := &config.ScheduleConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc1), IsNil) - c.Assert(*sc, DeepEquals, *sc1) + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc1)) + suite.Equal(*sc1, *sc) } -func (s *testConfigSuite) TestConfigReplication(c *C) { - addr := fmt.Sprintf("%s/config/replicate", s.urlPrefix) +func (suite *configTestSuite) TestConfigReplication() { + re := suite.Require() + addr := fmt.Sprintf("%s/config/replicate", suite.urlPrefix) rc := &config.ReplicationConfig{} - err := tu.ReadGetJSON(c, testDialClient, addr, rc) - c.Assert(err, IsNil) + err := tu.ReadGetJSON(re, testDialClient, addr, rc) + suite.NoError(err) rc.MaxReplicas = 5 rc1 := map[string]int{"max-replicas": 5} postData, err := json.Marshal(rc1) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) rc.LocationLabels = []string{"zone", "rack"} rc2 := map[string]string{"location-labels": "zone,rack"} postData, err = json.Marshal(rc2) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) rc.IsolationLevel = "zone" rc3 := map[string]string{"isolation-level": "zone"} postData, err = json.Marshal(rc3) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) rc4 := &config.ReplicationConfig{} - err = tu.ReadGetJSON(c, testDialClient, addr, rc4) - c.Assert(err, IsNil) + err = tu.ReadGetJSON(re, testDialClient, addr, rc4) + suite.NoError(err) - c.Assert(*rc, DeepEquals, *rc4) + suite.Equal(*rc4, *rc) } -func (s *testConfigSuite) TestConfigLabelProperty(c *C) { - addr := s.svr.GetAddr() + apiPrefix + "/api/v1/config/label-property" - +func (suite *configTestSuite) TestConfigLabelProperty() { + re := suite.Require() + addr := suite.svr.GetAddr() + apiPrefix + "/api/v1/config/label-property" loadProperties := func() config.LabelPropertyConfig { var cfg config.LabelPropertyConfig - err := tu.ReadGetJSON(c, testDialClient, addr, &cfg) - c.Assert(err, IsNil) + err := tu.ReadGetJSON(re, testDialClient, addr, &cfg) + suite.NoError(err) return cfg } cfg := loadProperties() - c.Assert(cfg, HasLen, 0) + suite.Len(cfg, 0) cmds := []string{ `{"type": "foo", "action": "set", "label-key": "zone", "label-value": "cn1"}`, @@ -221,90 +228,89 @@ func (s *testConfigSuite) TestConfigLabelProperty(c *C) { `{"type": "bar", "action": "set", "label-key": "host", "label-value": "h1"}`, } for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(c)) - c.Assert(err, IsNil) + err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + suite.NoError(err) } cfg = loadProperties() - c.Assert(cfg, HasLen, 2) - c.Assert(cfg["foo"], DeepEquals, []config.StoreLabel{ + suite.Len(cfg, 2) + suite.Equal([]config.StoreLabel{ {Key: "zone", 
Value: "cn1"}, {Key: "zone", Value: "cn2"}, - }) - c.Assert(cfg["bar"], DeepEquals, []config.StoreLabel{{Key: "host", Value: "h1"}}) + }, cfg["foo"]) + suite.Equal([]config.StoreLabel{{Key: "host", Value: "h1"}}, cfg["bar"]) cmds = []string{ `{"type": "foo", "action": "delete", "label-key": "zone", "label-value": "cn1"}`, `{"type": "bar", "action": "delete", "label-key": "host", "label-value": "h1"}`, } for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(c)) - c.Assert(err, IsNil) + err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + suite.NoError(err) } cfg = loadProperties() - c.Assert(cfg, HasLen, 1) - c.Assert(cfg["foo"], DeepEquals, []config.StoreLabel{{Key: "zone", Value: "cn2"}}) + suite.Len(cfg, 1) + suite.Equal([]config.StoreLabel{{Key: "zone", Value: "cn2"}}, cfg["foo"]) } -func (s *testConfigSuite) TestConfigDefault(c *C) { - addr := fmt.Sprintf("%s/config", s.urlPrefix) +func (suite *configTestSuite) TestConfigDefault() { + addr := fmt.Sprintf("%s/config", suite.urlPrefix) r := map[string]int{"max-replicas": 5} postData, err := json.Marshal(r) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) l := map[string]interface{}{ "location-labels": "zone,rack", "region-schedule-limit": 10, } postData, err = json.Marshal(l) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) l = map[string]interface{}{ "metric-storage": "http://127.0.0.1:9090", } postData, err = json.Marshal(l) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) - addr = fmt.Sprintf("%s/config/default", s.urlPrefix) + addr = fmt.Sprintf("%s/config/default", suite.urlPrefix) defaultCfg := &config.Config{} - err = tu.ReadGetJSON(c, testDialClient, addr, defaultCfg) - c.Assert(err, IsNil) + err = tu.ReadGetJSON(re, testDialClient, addr, defaultCfg) + suite.NoError(err) - c.Assert(defaultCfg.Replication.MaxReplicas, Equals, uint64(3)) - c.Assert(defaultCfg.Replication.LocationLabels, DeepEquals, typeutil.StringSlice([]string{})) - c.Assert(defaultCfg.Schedule.RegionScheduleLimit, Equals, uint64(2048)) - c.Assert(defaultCfg.PDServerCfg.MetricStorage, Equals, "") + suite.Equal(uint64(3), defaultCfg.Replication.MaxReplicas) + suite.Equal(typeutil.StringSlice([]string{}), defaultCfg.Replication.LocationLabels) + suite.Equal(uint64(2048), defaultCfg.Schedule.RegionScheduleLimit) + suite.Equal("", defaultCfg.PDServerCfg.MetricStorage) } -func (s *testConfigSuite) TestConfigPDServer(c *C) { - addrPost := fmt.Sprintf("%s/config", s.urlPrefix) - +func (suite *configTestSuite) TestConfigPDServer() { + re := suite.Require() + addrPost := fmt.Sprintf("%s/config", suite.urlPrefix) ms := map[string]interface{}{ "metric-storage": "", } postData, err := json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(c)), IsNil) - - addrGet := fmt.Sprintf("%s/config/pd-server", s.urlPrefix) + suite.NoError(err) + 
suite.NoError(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(re))) + addrGet := fmt.Sprintf("%s/config/pd-server", suite.urlPrefix) sc := &config.PDServerConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addrGet, sc), IsNil) - - c.Assert(sc.UseRegionStorage, Equals, bool(true)) - c.Assert(sc.KeyType, Equals, "table") - c.Assert(sc.RuntimeServices, DeepEquals, typeutil.StringSlice([]string{})) - c.Assert(sc.MetricStorage, Equals, "") - c.Assert(sc.DashboardAddress, Equals, "auto") - c.Assert(sc.FlowRoundByDigit, Equals, int(3)) - c.Assert(sc.MinResolvedTSPersistenceInterval, Equals, typeutil.NewDuration(0)) - c.Assert(sc.MaxResetTSGap.Duration, Equals, 24*time.Hour) + suite.NoError(tu.ReadGetJSON(re, testDialClient, addrGet, sc)) + suite.Equal(bool(true), sc.UseRegionStorage) + suite.Equal("table", sc.KeyType) + suite.Equal(typeutil.StringSlice([]string{}), sc.RuntimeServices) + suite.Equal("", sc.MetricStorage) + suite.Equal("auto", sc.DashboardAddress) + suite.Equal(int(3), sc.FlowRoundByDigit) + suite.Equal(typeutil.NewDuration(0), sc.MinResolvedTSPersistenceInterval) + suite.Equal(24*time.Hour, sc.MaxResetTSGap.Duration) } var ttlConfig = map[string]interface{}{ @@ -324,89 +330,94 @@ var invalidTTLConfig = map[string]interface{}{ "schedule.invalid-ttl-config": 0, } -func assertTTLConfig(c *C, options *config.PersistOptions, checker Checker) { - c.Assert(options.GetMaxSnapshotCount(), checker, uint64(999)) - c.Assert(options.IsLocationReplacementEnabled(), checker, false) - c.Assert(options.GetMaxMergeRegionSize(), checker, uint64(999)) - c.Assert(options.GetMaxMergeRegionKeys(), checker, uint64(999)) - c.Assert(options.GetSchedulerMaxWaitingOperator(), checker, uint64(999)) - c.Assert(options.GetLeaderScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetRegionScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetHotRegionScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetReplicaScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetMergeScheduleLimit(), checker, uint64(999)) +func assertTTLConfig( + options *config.PersistOptions, + equality func(interface{}, interface{}, ...interface{}) bool, +) { + equality(uint64(999), options.GetMaxSnapshotCount()) + equality(false, options.IsLocationReplacementEnabled()) + equality(uint64(999), options.GetMaxMergeRegionSize()) + equality(uint64(999), options.GetMaxMergeRegionKeys()) + equality(uint64(999), options.GetSchedulerMaxWaitingOperator()) + equality(uint64(999), options.GetLeaderScheduleLimit()) + equality(uint64(999), options.GetRegionScheduleLimit()) + equality(uint64(999), options.GetHotRegionScheduleLimit()) + equality(uint64(999), options.GetReplicaScheduleLimit()) + equality(uint64(999), options.GetMergeScheduleLimit()) } func createTTLUrl(url string, ttl int) string { return fmt.Sprintf("%s/config?ttlSecond=%d", url, ttl) } -func (s *testConfigSuite) TestConfigTTL(c *C) { +func (suite *configTestSuite) TestConfigTTL() { postData, err := json.Marshal(ttlConfig) - c.Assert(err, IsNil) + suite.NoError(err) // test no config and cleaning up - err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 0), postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - assertTTLConfig(c, s.svr.GetPersistOptions(), Not(Equals)) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 0), postData, tu.StatusOK(re)) + suite.NoError(err) + assertTTLConfig(suite.svr.GetPersistOptions(), suite.NotEqual) // test time goes by - err = 
tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 1), postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - assertTTLConfig(c, s.svr.GetPersistOptions(), Equals) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, tu.StatusOK(re)) + suite.NoError(err) + assertTTLConfig(suite.svr.GetPersistOptions(), suite.Equal) time.Sleep(2 * time.Second) - assertTTLConfig(c, s.svr.GetPersistOptions(), Not(Equals)) + assertTTLConfig(suite.svr.GetPersistOptions(), suite.NotEqual) // test cleaning up - err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 1), postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - assertTTLConfig(c, s.svr.GetPersistOptions(), Equals) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 0), postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - assertTTLConfig(c, s.svr.GetPersistOptions(), Not(Equals)) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, tu.StatusOK(re)) + suite.NoError(err) + assertTTLConfig(suite.svr.GetPersistOptions(), suite.Equal) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 0), postData, tu.StatusOK(re)) + suite.NoError(err) + assertTTLConfig(suite.svr.GetPersistOptions(), suite.NotEqual) postData, err = json.Marshal(invalidTTLConfig) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 1), postData, - tu.StatusNotOK(c), tu.StringEqual(c, "\"unsupported ttl config schedule.invalid-ttl-config\"\n")) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, + tu.StatusNotOK(re), tu.StringEqual(re, "\"unsupported ttl config schedule.invalid-ttl-config\"\n")) + suite.NoError(err) // only set max-merge-region-size mergeConfig := map[string]interface{}{ "schedule.max-merge-region-size": 999, } postData, err = json.Marshal(mergeConfig) - c.Assert(err, IsNil) + suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 1), postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - c.Assert(s.svr.GetPersistOptions().GetMaxMergeRegionSize(), Equals, uint64(999)) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, tu.StatusOK(re)) + suite.NoError(err) + suite.Equal(uint64(999), suite.svr.GetPersistOptions().GetMaxMergeRegionSize()) // max-merge-region-keys should keep consistence with max-merge-region-size. 
- c.Assert(s.svr.GetPersistOptions().GetMaxMergeRegionKeys(), Equals, uint64(999*10000)) + suite.Equal(uint64(999*10000), suite.svr.GetPersistOptions().GetMaxMergeRegionKeys()) } -func (s *testConfigSuite) TestTTLConflict(c *C) { - addr := createTTLUrl(s.urlPrefix, 1) +func (suite *configTestSuite) TestTTLConflict() { + addr := createTTLUrl(suite.urlPrefix, 1) postData, err := json.Marshal(ttlConfig) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - assertTTLConfig(c, s.svr.GetPersistOptions(), Equals) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + assertTTLConfig(suite.svr.GetPersistOptions(), suite.Equal) cfg := map[string]interface{}{"max-snapshot-count": 30} postData, err = json.Marshal(cfg) - c.Assert(err, IsNil) - addr = fmt.Sprintf("%s/config", s.urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(c), tu.StringEqual(c, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) - c.Assert(err, IsNil) - addr = fmt.Sprintf("%s/config/schedule", s.urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(c), tu.StringEqual(c, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) - c.Assert(err, IsNil) + suite.NoError(err) + addr = fmt.Sprintf("%s/config", suite.urlPrefix) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + suite.NoError(err) + addr = fmt.Sprintf("%s/config/schedule", suite.urlPrefix) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + suite.NoError(err) cfg = map[string]interface{}{"schedule.max-snapshot-count": 30} postData, err = json.Marshal(cfg) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(s.urlPrefix, 0), postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 0), postData, tu.StatusOK(re)) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) } diff --git a/server/api/etcd_api_test.go b/server/api/etcd_api_test.go index 76e0dcb3af5..23406bb64cd 100644 --- a/server/api/etcd_api_test.go +++ b/server/api/etcd_api_test.go @@ -16,27 +16,25 @@ package api import ( "encoding/json" + "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" tu "github.com/tikv/pd/pkg/testutil" ) -var _ = Suite(&testEtcdAPISuite{}) - -type testEtcdAPISuite struct{} - -func (s *testEtcdAPISuite) TestGRPCGateway(c *C) { - svr, clean := mustNewServer(c) +func TestGRPCGateway(t *testing.T) { + re := require.New(t) + svr, clean := mustNewServer(re) defer clean() addr := svr.GetConfig().ClientUrls + "/v3/kv/put" putKey := map[string]string{"key": "Zm9v", "value": "YmFy"} v, _ := json.Marshal(putKey) - err := tu.CheckPostJSON(testDialClient, addr, v, tu.StatusOK(c)) - c.Assert(err, IsNil) + err := tu.CheckPostJSON(testDialClient, addr, v, tu.StatusOK(re)) + re.NoError(err) addr = svr.GetConfig().ClientUrls + "/v3/kv/range" getKey := map[string]string{"key": "Zm9v"} v, _ = json.Marshal(getKey) - err = tu.CheckPostJSON(testDialClient, addr, v, tu.StatusOK(c), tu.StringContain(c, "Zm9v")) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, addr, v, tu.StatusOK(re), tu.StringContain(re, "Zm9v")) + re.NoError(err) } diff --git a/server/api/health_test.go b/server/api/health_test.go index ec3ac54be00..6d2caec12cd 100644 --- a/server/api/health_test.go +++ b/server/api/health_test.go @@ -18,38 +18,36 @@ import ( "encoding/json" "io" "strings" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testHealthAPISuite{}) - -type testHealthAPISuite struct{} - -func checkSliceResponse(c *C, body []byte, cfgs []*config.Config, unhealthy string) { +func checkSliceResponse(re *require.Assertions, body []byte, cfgs []*config.Config, unhealthy string) { got := []Health{} - c.Assert(json.Unmarshal(body, &got), IsNil) - c.Assert(len(got), Equals, len(cfgs)) + re.NoError(json.Unmarshal(body, &got)) + re.Len(cfgs, len(got)) for _, h := range got { for _, cfg := range cfgs { if h.Name != cfg.Name { continue } - relaxEqualStings(c, h.ClientUrls, strings.Split(cfg.ClientUrls, ",")) + relaxEqualStings(re, h.ClientUrls, strings.Split(cfg.ClientUrls, ",")) } if h.Name == unhealthy { - c.Assert(h.Health, IsFalse) + re.False(h.Health) continue } - c.Assert(h.Health, IsTrue) + re.True(h.Health) } } -func (s *testHealthAPISuite) TestHealthSlice(c *C) { - cfgs, svrs, clean := mustNewCluster(c, 3) +func TestHealthSlice(t *testing.T) { + re := require.New(t) + cfgs, svrs, clean := mustNewCluster(re, 3) defer clean() var leader, follow *server.Server @@ -60,13 +58,13 @@ func (s *testHealthAPISuite) TestHealthSlice(c *C) { follow = svr } } - mustBootstrapCluster(c, leader) + mustBootstrapCluster(re, leader) addr := leader.GetConfig().ClientUrls + apiPrefix + "/api/v1/health" follow.Close() resp, err := testDialClient.Get(addr) - c.Assert(err, IsNil) + re.NoError(err) defer resp.Body.Close() buf, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) - checkSliceResponse(c, buf, cfgs, follow.GetConfig().Name) + re.NoError(err) + checkSliceResponse(re, buf, cfgs, follow.GetConfig().Name) } diff --git a/server/api/hot_status_test.go b/server/api/hot_status_test.go index 18af694fce6..66a4e29afb7 100644 --- a/server/api/hot_status_test.go +++ b/server/api/hot_status_test.go @@ -17,10 +17,10 @@ package api import ( "encoding/json" "fmt" - "reflect" + "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/suite" "github.com/syndtr/goleveldb/leveldb" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" @@ -29,50 +29,55 @@ import ( "github.com/tikv/pd/server/storage/kv" ) -var _ = Suite(&testHotStatusSuite{}) - -type testHotStatusSuite struct { +type hotStatusTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testHotStatusSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestHotStatusTestSuite(t *testing.T) { + suite.Run(t, new(hotStatusTestSuite)) +} + +func (suite *hotStatusTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/hotspot", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/hotspot", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testHotStatusSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *hotStatusTestSuite) TearDownSuite() { + suite.cleanup() } -func (s testHotStatusSuite) TestGetHotStore(c *C) { +func (suite *hotStatusTestSuite) TestGetHotStore() { stat := HotStoreStats{} - err := tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/stores", &stat) - c.Assert(err, IsNil) + err := tu.ReadGetJSON(suite.Require(), testDialClient, suite.urlPrefix+"/stores", &stat) + suite.NoError(err) } -func (s testHotStatusSuite) TestGetHistoryHotRegionsBasic(c *C) { +func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsBasic() { request := HistoryHotRegionsRequest{ StartTime: 0, EndTime: time.Now().AddDate(0, 2, 0).UnixNano() / int64(time.Millisecond), } data, err := json.Marshal(request) - c.Assert(err, IsNil) - err = tu.CheckGetJSON(testDialClient, s.urlPrefix+"/regions/history", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckGetJSON(testDialClient, suite.urlPrefix+"/regions/history", data, tu.StatusOK(re)) + suite.NoError(err) errRequest := "{\"start_time\":\"err\"}" - err = tu.CheckGetJSON(testDialClient, s.urlPrefix+"/regions/history", []byte(errRequest), tu.StatusNotOK(c)) - c.Assert(err, IsNil) + err = tu.CheckGetJSON(testDialClient, suite.urlPrefix+"/regions/history", []byte(errRequest), tu.StatusNotOK(re)) + suite.NoError(err) } -func (s testHotStatusSuite) TestGetHistoryHotRegionsTimeRange(c *C) { - hotRegionStorage := s.svr.GetHistoryHotRegionStorage() +func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() { + hotRegionStorage := suite.svr.GetHistoryHotRegionStorage() now := time.Now() hotRegions := []*storage.HistoryHotRegion{ { @@ -89,24 +94,24 @@ func (s testHotStatusSuite) TestGetHistoryHotRegionsTimeRange(c *C) { EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond), } check := func(res []byte, statusCode int) { - c.Assert(statusCode, Equals, 200) + suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) for _, region := range historyHotRegions.HistoryHotRegion { - c.Assert(region.UpdateTime, GreaterEqual, request.StartTime) - c.Assert(region.UpdateTime, LessEqual, request.EndTime) + suite.GreaterOrEqual(region.UpdateTime, request.StartTime) + suite.LessOrEqual(region.UpdateTime, request.EndTime) } } err := writeToDB(hotRegionStorage.LevelDBKV, hotRegions) - c.Assert(err, IsNil) + 
suite.NoError(err) data, err := json.Marshal(request) - c.Assert(err, IsNil) - err = tu.CheckGetJSON(testDialClient, s.urlPrefix+"/regions/history", data, check) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckGetJSON(testDialClient, suite.urlPrefix+"/regions/history", data, check) + suite.NoError(err) } -func (s testHotStatusSuite) TestGetHistoryHotRegionsIDAndTypes(c *C) { - hotRegionStorage := s.svr.GetHistoryHotRegionStorage() +func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() { + hotRegionStorage := suite.svr.GetHistoryHotRegionStorage() now := time.Now() hotRegions := []*storage.HistoryHotRegion{ { @@ -174,18 +179,18 @@ func (s testHotStatusSuite) TestGetHistoryHotRegionsIDAndTypes(c *C) { EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond), } check := func(res []byte, statusCode int) { - c.Assert(statusCode, Equals, 200) + suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) - c.Assert(historyHotRegions.HistoryHotRegion, HasLen, 1) - c.Assert(reflect.DeepEqual(historyHotRegions.HistoryHotRegion[0], hotRegions[0]), IsTrue) + suite.Len(historyHotRegions.HistoryHotRegion, 1) + suite.Equal(hotRegions[0], historyHotRegions.HistoryHotRegion[0]) } err := writeToDB(hotRegionStorage.LevelDBKV, hotRegions) - c.Assert(err, IsNil) + suite.NoError(err) data, err := json.Marshal(request) - c.Assert(err, IsNil) - err = tu.CheckGetJSON(testDialClient, s.urlPrefix+"/regions/history", data, check) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckGetJSON(testDialClient, suite.urlPrefix+"/regions/history", data, check) + suite.NoError(err) } func writeToDB(kv *kv.LevelDBKV, hotRegions []*storage.HistoryHotRegion) error { diff --git a/server/api/label_test.go b/server/api/label_test.go index 12933d17e64..b9503871a5a 100644 --- a/server/api/label_test.go +++ b/server/api/label_test.go @@ -17,28 +17,30 @@ package api import ( "context" "fmt" - "strings" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testLabelsStoreSuite{}) -var _ = Suite(&testStrictlyLabelsStoreSuite{}) - -type testLabelsStoreSuite struct { +type labelsStoreTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string stores []*metapb.Store } -func (s *testLabelsStoreSuite) SetUpSuite(c *C) { - s.stores = []*metapb.Store{ +func TestLabelsStoreTestSuite(t *testing.T) { + suite.Run(t, new(labelsStoreTestSuite)) +} + +func (suite *labelsStoreTestSuite) SetupSuite() { + suite.stores = []*metapb.Store{ { Id: 1, Address: "tikv1", @@ -113,58 +115,58 @@ func (s *testLabelsStoreSuite) SetUpSuite(c *C) { }, } - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.StrictlyMatchLabel = false }) - mustWaitLeader(c, []*server.Server{s.svr}) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) - for _, store := range s.stores { - mustPutStore(c, s.svr, store.Id, store.State, store.NodeState, store.Labels) + mustBootstrapCluster(re, suite.svr) + for _, store := range suite.stores { + mustPutStore(re, suite.svr, store.Id, store.State, store.NodeState, store.Labels) } } -func (s *testLabelsStoreSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *labelsStoreTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testLabelsStoreSuite) TestLabelsGet(c *C) { - url := fmt.Sprintf("%s/labels", s.urlPrefix) - labels := make([]*metapb.StoreLabel, 0, len(s.stores)) - err := tu.ReadGetJSON(c, testDialClient, url, &labels) - c.Assert(err, IsNil) +func (suite *labelsStoreTestSuite) TestLabelsGet() { + url := fmt.Sprintf("%s/labels", suite.urlPrefix) + labels := make([]*metapb.StoreLabel, 0, len(suite.stores)) + suite.NoError(tu.ReadGetJSON(suite.Require(), testDialClient, url, &labels)) } -func (s *testLabelsStoreSuite) TestStoresLabelFilter(c *C) { +func (suite *labelsStoreTestSuite) TestStoresLabelFilter() { var table = []struct { name, value string want []*metapb.Store }{ { name: "Zone", - want: s.stores, + want: suite.stores, }, { name: "other", - want: s.stores[3:], + want: suite.stores[3:], }, { name: "zone", value: "Us-west-1", - want: s.stores[:1], + want: suite.stores[:1], }, { name: "Zone", value: "west", - want: s.stores[:2], + want: suite.stores[:2], }, { name: "Zo", value: "Beijing", - want: s.stores[2:3], + want: suite.stores[2:3], }, { name: "ZONE", @@ -172,40 +174,47 @@ func (s *testLabelsStoreSuite) TestStoresLabelFilter(c *C) { want: []*metapb.Store{}, }, } + re := suite.Require() for _, t := range table { - url := fmt.Sprintf("%s/labels/stores?name=%s&value=%s", s.urlPrefix, t.name, t.value) + url := fmt.Sprintf("%s/labels/stores?name=%s&value=%s", suite.urlPrefix, t.name, t.value) info := new(StoresInfo) - err := tu.ReadGetJSON(c, testDialClient, url, info) - c.Assert(err, IsNil) - checkStoresInfo(c, info.Stores, t.want) + err := tu.ReadGetJSON(re, testDialClient, url, info) + suite.NoError(err) + checkStoresInfo(re, info.Stores, t.want) } _, err := newStoresLabelFilter("test", ".[test") - c.Assert(err, 
NotNil) + suite.Error(err) } -type testStrictlyLabelsStoreSuite struct { +type strictlyLabelsStoreTestSuite struct { + suite.Suite svr *server.Server grpcSvr *server.GrpcServer cleanup cleanUpFunc urlPrefix string } -func (s *testStrictlyLabelsStoreSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { +func TestStrictlyLabelsStoreTestSuite(t *testing.T) { + suite.Run(t, new(strictlyLabelsStoreTestSuite)) +} + +func (suite *strictlyLabelsStoreTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.LocationLabels = []string{"zone", "disk"} cfg.Replication.StrictlyMatchLabel = true cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(c, []*server.Server{s.svr}) + mustWaitLeader(re, []*server.Server{suite.svr}) - s.grpcSvr = &server.GrpcServer{Server: s.svr} - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + suite.grpcSvr = &server.GrpcServer{Server: suite.svr} + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testStrictlyLabelsStoreSuite) TestStoreMatch(c *C) { +func (suite *strictlyLabelsStoreTestSuite) TestStoreMatch() { cases := []struct { store *metapb.Store valid bool @@ -268,8 +277,8 @@ func (s *testStrictlyLabelsStoreSuite) TestStoreMatch(c *C) { } for _, t := range cases { - _, err := s.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ - Header: &pdpb.RequestHeader{ClusterId: s.svr.ClusterID()}, + _, err := suite.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ + Header: &pdpb.RequestHeader{ClusterId: suite.svr.ClusterID()}, Store: &metapb.Store{ Id: t.store.Id, Address: fmt.Sprintf("tikv%d", t.store.Id), @@ -279,17 +288,21 @@ func (s *testStrictlyLabelsStoreSuite) TestStoreMatch(c *C) { }, }) if t.valid { - c.Assert(err, IsNil) + suite.NoError(err) } else { - c.Assert(strings.Contains(err.Error(), t.expectError), IsTrue) + suite.Contains(err.Error(), t.expectError) } } // enable placement rules. Report no error any more. 
- c.Assert(tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/config", s.urlPrefix), []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(c)), IsNil) + suite.NoError(tu.CheckPostJSON( + testDialClient, + fmt.Sprintf("%s/config", suite.urlPrefix), + []byte(`{"enable-placement-rules":"true"}`), + tu.StatusOK(suite.Require()))) for _, t := range cases { - _, err := s.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ - Header: &pdpb.RequestHeader{ClusterId: s.svr.ClusterID()}, + _, err := suite.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ + Header: &pdpb.RequestHeader{ClusterId: suite.svr.ClusterID()}, Store: &metapb.Store{ Id: t.store.Id, Address: fmt.Sprintf("tikv%d", t.store.Id), @@ -299,13 +312,13 @@ func (s *testStrictlyLabelsStoreSuite) TestStoreMatch(c *C) { }, }) if t.valid { - c.Assert(err, IsNil) + suite.NoError(err) } else { - c.Assert(strings.Contains(err.Error(), t.expectError), IsTrue) + suite.Contains(err.Error(), t.expectError) } } } -func (s *testStrictlyLabelsStoreSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *strictlyLabelsStoreTestSuite) TearDownSuite() { + suite.cleanup() } diff --git a/server/api/log_test.go b/server/api/log_test.go index afd65bb35b3..f03472b8146 100644 --- a/server/api/log_test.go +++ b/server/api/log_test.go @@ -17,40 +17,45 @@ package api import ( "encoding/json" "fmt" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/log" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" ) -var _ = Suite(&testLogSuite{}) - -type testLogSuite struct { +type logTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testLogSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestLogTestSuite(t *testing.T) { + suite.Run(t, new(logTestSuite)) +} + +func (suite *logTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/admin", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/admin", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testLogSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *logTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testLogSuite) TestSetLogLevel(c *C) { +func (suite *logTestSuite) TestSetLogLevel() { level := "error" data, err := json.Marshal(level) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/log", data, tu.StatusOK(c)) - c.Assert(err, IsNil) - c.Assert(log.GetLevel().String(), Equals, level) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/log", data, tu.StatusOK(suite.Require())) + suite.NoError(err) + suite.Equal(level, log.GetLevel().String()) } diff --git a/server/api/member_test.go b/server/api/member_test.go index bffad22380b..1132010319d 100644 --- a/server/api/member_test.go +++ b/server/api/member_test.go @@ -23,24 +23,28 @@ import ( "net/http" "sort" "strings" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testMemberAPISuite{}) -var _ = Suite(&testResignAPISuite{}) - -type testMemberAPISuite struct { +type memberTestSuite struct { + suite.Suite cfgs []*config.Config servers []*server.Server clean func() } -func (s *testMemberAPISuite) SetUpSuite(c *C) { - s.cfgs, s.servers, s.clean = mustNewCluster(c, 3, func(cfg *config.Config) { +func TestMemberTestSuite(t *testing.T) { + suite.Run(t, new(memberTestSuite)) +} + +func (suite *memberTestSuite) SetupSuite() { + suite.cfgs, suite.servers, suite.clean = mustNewCluster(suite.Require(), 3, func(cfg *config.Config) { cfg.EnableLocalTSO = true cfg.Labels = map[string]string{ config.ZoneLabel: "dc-1", @@ -48,128 +52,132 @@ func (s *testMemberAPISuite) SetUpSuite(c *C) { }) } -func (s *testMemberAPISuite) TearDownSuite(c *C) { - s.clean() +func (suite *memberTestSuite) TearDownSuite() { + suite.clean() } -func relaxEqualStings(c *C, a, b []string) { +func relaxEqualStings(re *require.Assertions, a, b []string) { sort.Strings(a) sortedStringA := strings.Join(a, "") sort.Strings(b) sortedStringB := strings.Join(b, "") - c.Assert(sortedStringA, Equals, sortedStringB) + re.Equal(sortedStringB, sortedStringA) } -func checkListResponse(c *C, body []byte, cfgs []*config.Config) { +func (suite *memberTestSuite) checkListResponse(body []byte, cfgs []*config.Config) { got := make(map[string][]*pdpb.Member) json.Unmarshal(body, &got) - - c.Assert(len(got["members"]), Equals, len(cfgs)) - + suite.Len(cfgs, len(got["members"])) + re := suite.Require() for _, member := range got["members"] { for _, cfg := range cfgs { if member.GetName() != cfg.Name { continue } - c.Assert(member.DcLocation, Equals, "dc-1") - relaxEqualStings(c, member.ClientUrls, strings.Split(cfg.ClientUrls, ",")) - relaxEqualStings(c, member.PeerUrls, strings.Split(cfg.PeerUrls, ",")) + suite.Equal("dc-1", member.DcLocation) + relaxEqualStings(re, member.ClientUrls, strings.Split(cfg.ClientUrls, ",")) + relaxEqualStings(re, member.PeerUrls, strings.Split(cfg.PeerUrls, ",")) } } } -func (s *testMemberAPISuite) TestMemberList(c *C) { - for _, cfg := range s.cfgs { +func (suite *memberTestSuite) TestMemberList() { + for _, cfg := range suite.cfgs { addr := cfg.ClientUrls + apiPrefix + "/api/v1/members" resp, err := testDialClient.Get(addr) - c.Assert(err, IsNil) + suite.NoError(err) buf, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() - checkListResponse(c, buf, s.cfgs) + suite.checkListResponse(buf, suite.cfgs) } } -func (s *testMemberAPISuite) TestMemberLeader(c *C) { - leader := s.servers[0].GetLeader() - addr := s.cfgs[rand.Intn(len(s.cfgs))].ClientUrls + apiPrefix + "/api/v1/leader" +func (suite *memberTestSuite) TestMemberLeader() { + leader := suite.servers[0].GetLeader() + addr := suite.cfgs[rand.Intn(len(suite.cfgs))].ClientUrls + apiPrefix + "/api/v1/leader" resp, err := testDialClient.Get(addr) - c.Assert(err, IsNil) + suite.NoError(err) defer resp.Body.Close() buf, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) var got pdpb.Member - c.Assert(json.Unmarshal(buf, &got), IsNil) - c.Assert(got.GetClientUrls(), DeepEquals, leader.GetClientUrls()) - c.Assert(got.GetMemberId(), Equals, leader.GetMemberId()) + suite.NoError(json.Unmarshal(buf, &got)) + suite.Equal(leader.GetClientUrls(), 
got.GetClientUrls()) + suite.Equal(leader.GetMemberId(), got.GetMemberId()) } -func (s *testMemberAPISuite) TestChangeLeaderPeerUrls(c *C) { - leader := s.servers[0].GetLeader() - addr := s.cfgs[rand.Intn(len(s.cfgs))].ClientUrls + apiPrefix + "/api/v1/leader" +func (suite *memberTestSuite) TestChangeLeaderPeerUrls() { + leader := suite.servers[0].GetLeader() + addr := suite.cfgs[rand.Intn(len(suite.cfgs))].ClientUrls + apiPrefix + "/api/v1/leader" resp, err := testDialClient.Get(addr) - c.Assert(err, IsNil) + suite.NoError(err) defer resp.Body.Close() buf, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) var got pdpb.Member - c.Assert(json.Unmarshal(buf, &got), IsNil) + suite.NoError(json.Unmarshal(buf, &got)) id := got.GetMemberId() peerUrls := got.GetPeerUrls() newPeerUrls := []string{"http://127.0.0.1:1111"} - changeLeaderPeerUrls(c, leader, id, newPeerUrls) - addr = s.cfgs[rand.Intn(len(s.cfgs))].ClientUrls + apiPrefix + "/api/v1/members" + suite.changeLeaderPeerUrls(leader, id, newPeerUrls) + addr = suite.cfgs[rand.Intn(len(suite.cfgs))].ClientUrls + apiPrefix + "/api/v1/members" resp, err = testDialClient.Get(addr) - c.Assert(err, IsNil) + suite.NoError(err) buf, err = io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) resp.Body.Close() got1 := make(map[string]*pdpb.Member) json.Unmarshal(buf, &got1) - c.Assert(got1["leader"].GetPeerUrls(), DeepEquals, newPeerUrls) - c.Assert(got1["etcd_leader"].GetPeerUrls(), DeepEquals, newPeerUrls) + suite.Equal(newPeerUrls, got1["leader"].GetPeerUrls()) + suite.Equal(newPeerUrls, got1["etcd_leader"].GetPeerUrls()) // reset - changeLeaderPeerUrls(c, leader, id, peerUrls) + suite.changeLeaderPeerUrls(leader, id, peerUrls) } -func changeLeaderPeerUrls(c *C, leader *pdpb.Member, id uint64, urls []string) { +func (suite *memberTestSuite) changeLeaderPeerUrls(leader *pdpb.Member, id uint64, urls []string) { data := map[string][]string{"peerURLs": urls} postData, err := json.Marshal(data) - c.Assert(err, IsNil) + suite.NoError(err) req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("%s/v2/members/%s", leader.GetClientUrls()[0], fmt.Sprintf("%x", id)), bytes.NewBuffer(postData)) - c.Assert(err, IsNil) + suite.NoError(err) req.Header.Set("Content-Type", "application/json") resp, err := testDialClient.Do(req) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, 204) + suite.NoError(err) + suite.Equal(204, resp.StatusCode) resp.Body.Close() } -type testResignAPISuite struct { +type resignTestSuite struct { + suite.Suite cfgs []*config.Config servers []*server.Server clean func() } -func (s *testResignAPISuite) SetUpSuite(c *C) { - s.cfgs, s.servers, s.clean = mustNewCluster(c, 1) +func TestResignTestSuite(t *testing.T) { + suite.Run(t, new(resignTestSuite)) +} + +func (suite *resignTestSuite) SetupSuite() { + suite.cfgs, suite.servers, suite.clean = mustNewCluster(suite.Require(), 1) } -func (s *testResignAPISuite) TearDownSuite(c *C) { - s.clean() +func (suite *resignTestSuite) TearDownSuite() { + suite.clean() } -func (s *testResignAPISuite) TestResignMyself(c *C) { - addr := s.cfgs[0].ClientUrls + apiPrefix + "/api/v1/leader/resign" +func (suite *resignTestSuite) TestResignMyself() { + addr := suite.cfgs[0].ClientUrls + apiPrefix + "/api/v1/leader/resign" resp, err := testDialClient.Post(addr, "", nil) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) _, _ = io.Copy(io.Discard, resp.Body) resp.Body.Close() } diff 
--git a/server/api/min_resolved_ts_test.go b/server/api/min_resolved_ts_test.go index 69f4935f184..47c47713bff 100644 --- a/server/api/min_resolved_ts_test.go +++ b/server/api/min_resolved_ts_test.go @@ -16,47 +16,52 @@ package api import ( "fmt" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/cluster" ) -var _ = Suite(&testMinResolvedTSSuite{}) - -type testMinResolvedTSSuite struct { +type minResolvedTSTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testMinResolvedTSSuite) SetUpSuite(c *C) { +func TestMinResolvedTSTestSuite(t *testing.T) { + suite.Run(t, new(minResolvedTSTestSuite)) +} + +func (suite *minResolvedTSTestSuite) SetupSuite() { + re := suite.Require() cluster.DefaultMinResolvedTSPersistenceInterval = time.Microsecond - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustBootstrapCluster(re, suite.svr) + mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) r1 := newTestRegionInfo(7, 1, []byte("a"), []byte("b")) - mustRegionHeartbeat(c, s.svr, r1) + mustRegionHeartbeat(re, suite.svr, r1) r2 := newTestRegionInfo(8, 1, []byte("b"), []byte("c")) - mustRegionHeartbeat(c, s.svr, r2) + mustRegionHeartbeat(re, suite.svr, r2) } -func (s *testMinResolvedTSSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *minResolvedTSTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testMinResolvedTSSuite) TestMinResolvedTS(c *C) { - url := s.urlPrefix + "/min-resolved-ts" - rc := s.svr.GetRaftCluster() +func (suite *minResolvedTSTestSuite) TestMinResolvedTS() { + url := suite.urlPrefix + "/min-resolved-ts" + rc := suite.svr.GetRaftCluster() ts := uint64(233) rc.SetMinResolvedTS(1, ts) @@ -67,18 +72,18 @@ func (s *testMinResolvedTSSuite) TestMinResolvedTS(c *C) { PersistInterval: typeutil.Duration{Duration: 0}, } res, err := testDialClient.Get(url) - c.Assert(err, IsNil) + suite.NoError(err) defer res.Body.Close() listResp := &minResolvedTS{} err = apiutil.ReadJSON(res.Body, listResp) - c.Assert(err, IsNil) - c.Assert(listResp, DeepEquals, result) + suite.NoError(err) + suite.Equal(result, listResp) // run job interval := typeutil.NewDuration(time.Microsecond) - cfg := s.svr.GetRaftCluster().GetOpts().GetPDServerConfig().Clone() + cfg := suite.svr.GetRaftCluster().GetOpts().GetPDServerConfig().Clone() cfg.MinResolvedTSPersistenceInterval = interval - s.svr.GetRaftCluster().GetOpts().SetPDServerConfig(cfg) + suite.svr.GetRaftCluster().GetOpts().SetPDServerConfig(cfg) time.Sleep(time.Millisecond) result = &minResolvedTS{ MinResolvedTS: ts, @@ -86,10 +91,10 @@ func (s *testMinResolvedTSSuite) TestMinResolvedTS(c *C) { PersistInterval: interval, } res, err = testDialClient.Get(url) - c.Assert(err, IsNil) + suite.NoError(err) defer res.Body.Close() listResp = &minResolvedTS{} err = apiutil.ReadJSON(res.Body, listResp) - c.Assert(err, IsNil) - c.Assert(listResp, DeepEquals, result) + 
suite.NoError(err) + suite.Equal(result, listResp) } diff --git a/server/api/operator_test.go b/server/api/operator_test.go index 86d99c5e726..ba08b890b9b 100644 --- a/server/api/operator_test.go +++ b/server/api/operator_test.go @@ -21,12 +21,14 @@ import ( "io" "strconv" "strings" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" "github.com/tikv/pd/pkg/mock/mockhbstream" tu "github.com/tikv/pd/pkg/testutil" @@ -38,34 +40,37 @@ import ( "github.com/tikv/pd/server/versioninfo" ) -var _ = Suite(&testOperatorSuite{}) - -var _ = Suite(&testTransferRegionOperatorSuite{}) - -type testOperatorSuite struct { +type operatorTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testOperatorSuite) SetUpSuite(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)"), IsNil) - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { cfg.Replication.MaxReplicas = 1 }) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestOperatorTestSuite(t *testing.T) { + suite.Run(t, new(operatorTestSuite)) +} + +func (suite *operatorTestSuite) SetupSuite() { + re := suite.Require() + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)")) + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.MaxReplicas = 1 }) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testOperatorSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *operatorTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testOperatorSuite) TestAddRemovePeer(c *C) { - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) - mustPutStore(c, s.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, nil) +func (suite *operatorTestSuite) TestAddRemovePeer() { + re := suite.Require() + mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustPutStore(re, suite.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, nil) peer1 := &metapb.Peer{Id: 1, StoreId: 1} peer2 := &metapb.Peer{Id: 2, StoreId: 2} @@ -78,115 +83,123 @@ func (s *testOperatorSuite) TestAddRemovePeer(c *C) { }, } regionInfo := core.NewRegionInfo(region, peer1) - mustRegionHeartbeat(c, s.svr, regionInfo) - - regionURL := fmt.Sprintf("%s/operators/%d", s.urlPrefix, region.GetId()) - operator := mustReadURL(c, regionURL) - c.Assert(strings.Contains(operator, "operator not found"), IsTrue) - recordURL := fmt.Sprintf("%s/operators/records?from=%s", s.urlPrefix, strconv.FormatInt(time.Now().Unix(), 10)) - records := mustReadURL(c, recordURL) - c.Assert(strings.Contains(records, "operator not found"), IsTrue) - - mustPutStore(c, s.svr, 3, metapb.StoreState_Up, metapb.NodeState_Serving, nil) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusOK(c)) - c.Assert(err, IsNil) - operator = mustReadURL(c, regionURL) - c.Assert(strings.Contains(operator, "add learner peer 1 on 
store 3"), IsTrue) - c.Assert(strings.Contains(operator, "RUNNING"), IsTrue) + mustRegionHeartbeat(re, suite.svr, regionInfo) + + regionURL := fmt.Sprintf("%s/operators/%d", suite.urlPrefix, region.GetId()) + operator := mustReadURL(re, regionURL) + suite.Contains(operator, "operator not found") + recordURL := fmt.Sprintf("%s/operators/records?from=%s", suite.urlPrefix, strconv.FormatInt(time.Now().Unix(), 10)) + records := mustReadURL(re, recordURL) + suite.Contains(records, "operator not found") + + mustPutStore(re, suite.svr, 3, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusOK(re)) + suite.NoError(err) + operator = mustReadURL(re, regionURL) + suite.Contains(operator, "add learner peer 1 on store 3") + suite.Contains(operator, "RUNNING") _, err = apiutil.DoDelete(testDialClient, regionURL) - c.Assert(err, IsNil) - records = mustReadURL(c, recordURL) - c.Assert(strings.Contains(records, "admin-add-peer {add peer: store [3]}"), IsTrue) + suite.NoError(err) + records = mustReadURL(re, recordURL) + suite.Contains(records, "admin-add-peer {add peer: store [3]}") - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"remove-peer", "region_id": 1, "store_id": 2}`), tu.StatusOK(c)) - c.Assert(err, IsNil) - operator = mustReadURL(c, regionURL) - c.Assert(strings.Contains(operator, "RUNNING"), IsTrue) - c.Assert(strings.Contains(operator, "remove peer on store 2"), IsTrue) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"remove-peer", "region_id": 1, "store_id": 2}`), tu.StatusOK(re)) + suite.NoError(err) + operator = mustReadURL(re, regionURL) + suite.Contains(operator, "RUNNING") + suite.Contains(operator, "remove peer on store 2") _, err = apiutil.DoDelete(testDialClient, regionURL) - c.Assert(err, IsNil) - records = mustReadURL(c, recordURL) - c.Assert(strings.Contains(records, "admin-remove-peer {rm peer: store [2]}"), IsTrue) + suite.NoError(err) + records = mustReadURL(re, recordURL) + suite.Contains(records, "admin-remove-peer {rm peer: store [2]}") - mustPutStore(c, s.svr, 4, metapb.StoreState_Up, metapb.NodeState_Serving, nil) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"add-learner", "region_id": 1, "store_id": 4}`), tu.StatusOK(c)) - c.Assert(err, IsNil) - operator = mustReadURL(c, regionURL) - c.Assert(strings.Contains(operator, "add learner peer 2 on store 4"), IsTrue) + mustPutStore(re, suite.svr, 4, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"add-learner", "region_id": 1, "store_id": 4}`), tu.StatusOK(re)) + suite.NoError(err) + operator = mustReadURL(re, regionURL) + suite.Contains(operator, "add learner peer 2 on store 4") // Fail to add peer to tombstone store. 
- err = s.svr.GetRaftCluster().RemoveStore(3, true) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusNotOK(c)) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"transfer-peer", "region_id": 1, "from_store_id": 1, "to_store_id": 3}`), tu.StatusNotOK(c)) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"transfer-region", "region_id": 1, "to_store_ids": [1, 2, 3]}`), tu.StatusNotOK(c)) - c.Assert(err, IsNil) + err = suite.svr.GetRaftCluster().RemoveStore(3, true) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusNotOK(re)) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"transfer-peer", "region_id": 1, "from_store_id": 1, "to_store_id": 3}`), tu.StatusNotOK(re)) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"transfer-region", "region_id": 1, "to_store_ids": [1, 2, 3]}`), tu.StatusNotOK(re)) + suite.NoError(err) // Fail to get operator if from is latest. time.Sleep(time.Second) - records = mustReadURL(c, fmt.Sprintf("%s/operators/records?from=%s", s.urlPrefix, strconv.FormatInt(time.Now().Unix(), 10))) - c.Assert(strings.Contains(records, "operator not found"), IsTrue) + records = mustReadURL(re, fmt.Sprintf("%s/operators/records?from=%s", suite.urlPrefix, strconv.FormatInt(time.Now().Unix(), 10))) + suite.Contains(records, "operator not found") } -func (s *testOperatorSuite) TestMergeRegionOperator(c *C) { +func (suite *operatorTestSuite) TestMergeRegionOperator() { + re := suite.Require() r1 := newTestRegionInfo(10, 1, []byte(""), []byte("b"), core.SetWrittenBytes(1000), core.SetReadBytes(1000), core.SetRegionConfVer(1), core.SetRegionVersion(1)) - mustRegionHeartbeat(c, s.svr, r1) + mustRegionHeartbeat(re, suite.svr, r1) r2 := newTestRegionInfo(20, 1, []byte("b"), []byte("c"), core.SetWrittenBytes(2000), core.SetReadBytes(0), core.SetRegionConfVer(2), core.SetRegionVersion(3)) - mustRegionHeartbeat(c, s.svr, r2) + mustRegionHeartbeat(re, suite.svr, r2) r3 := newTestRegionInfo(30, 1, []byte("c"), []byte(""), core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2)) - mustRegionHeartbeat(c, s.svr, r3) - - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(c)) - c.Assert(err, IsNil) - - s.svr.GetHandler().RemoveOperator(10) - s.svr.GetHandler().RemoveOperator(20) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 20, "target_region_id": 10}`), tu.StatusOK(c)) - c.Assert(err, IsNil) - s.svr.GetHandler().RemoveOperator(10) - s.svr.GetHandler().RemoveOperator(20) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 30}`), - tu.StatusNotOK(c), tu.StringContain(c, "not adjacent")) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), 
[]byte(`{"name":"merge-region", "source_region_id": 30, "target_region_id": 10}`), - tu.StatusNotOK(c), tu.StringContain(c, "not adjacent")) - c.Assert(err, IsNil) + mustRegionHeartbeat(re, suite.svr, r3) + + err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) + suite.NoError(err) + + suite.svr.GetHandler().RemoveOperator(10) + suite.svr.GetHandler().RemoveOperator(20) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 20, "target_region_id": 10}`), tu.StatusOK(re)) + suite.NoError(err) + suite.svr.GetHandler().RemoveOperator(10) + suite.svr.GetHandler().RemoveOperator(20) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 30}`), + tu.StatusNotOK(re), tu.StringContain(re, "not adjacent")) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 30, "target_region_id": 10}`), + tu.StatusNotOK(re), tu.StringContain(re, "not adjacent")) + suite.NoError(err) } -type testTransferRegionOperatorSuite struct { +type transferRegionOperatorTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testTransferRegionOperatorSuite) SetUpSuite(c *C) { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)"), IsNil) - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { cfg.Replication.MaxReplicas = 3 }) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestTransferRegionOperatorTestSuite(t *testing.T) { + suite.Run(t, new(transferRegionOperatorTestSuite)) +} + +func (suite *transferRegionOperatorTestSuite) SetupSuite() { + re := suite.Require() + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)")) + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.MaxReplicas = 3 }) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testTransferRegionOperatorSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *transferRegionOperatorTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testTransferRegionOperatorSuite) TestTransferRegionWithPlacementRule(c *C) { - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{{Key: "key", Value: "1"}}) - mustPutStore(c, s.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{{Key: "key", Value: "2"}}) - mustPutStore(c, s.svr, 3, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{{Key: "key", Value: "3"}}) +func (suite *transferRegionOperatorTestSuite) TestTransferRegionWithPlacementRule() { + re := suite.Require() + mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{{Key: "key", Value: "1"}}) + mustPutStore(re, suite.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{{Key: "key", Value: "2"}}) + mustPutStore(re, suite.svr, 3, 
metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{{Key: "key", Value: "3"}}) hbStream := mockhbstream.NewHeartbeatStream() - s.svr.GetHBStreams().BindStream(1, hbStream) - s.svr.GetHBStreams().BindStream(2, hbStream) - s.svr.GetHBStreams().BindStream(3, hbStream) + suite.svr.GetHBStreams().BindStream(1, hbStream) + suite.svr.GetHBStreams().BindStream(2, hbStream) + suite.svr.GetHBStreams().BindStream(3, hbStream) peer1 := &metapb.Peer{Id: 1, StoreId: 1} peer2 := &metapb.Peer{Id: 2, StoreId: 2} @@ -199,11 +212,11 @@ func (s *testTransferRegionOperatorSuite) TestTransferRegionWithPlacementRule(c Version: 1, }, } - mustRegionHeartbeat(c, s.svr, core.NewRegionInfo(region, peer1)) + mustRegionHeartbeat(re, suite.svr, core.NewRegionInfo(region, peer1)) - regionURL := fmt.Sprintf("%s/operators/%d", s.urlPrefix, region.GetId()) - operator := mustReadURL(c, regionURL) - c.Assert(strings.Contains(operator, "operator not found"), IsTrue) + regionURL := fmt.Sprintf("%s/operators/%d", suite.urlPrefix, region.GetId()) + operator := mustReadURL(re, regionURL) + suite.Contains(operator, "operator not found") tt := []struct { name string @@ -359,39 +372,39 @@ func (s *testTransferRegionOperatorSuite) TestTransferRegionWithPlacementRule(c }, } for _, tc := range tt { - c.Log(tc.name) - s.svr.GetRaftCluster().GetOpts().SetPlacementRuleEnabled(tc.placementRuleEnable) + suite.T().Log(tc.name) + suite.svr.GetRaftCluster().GetOpts().SetPlacementRuleEnabled(tc.placementRuleEnable) if tc.placementRuleEnable { - err := s.svr.GetRaftCluster().GetRuleManager().Initialize( - s.svr.GetRaftCluster().GetOpts().GetMaxReplicas(), - s.svr.GetRaftCluster().GetOpts().GetLocationLabels()) - c.Assert(err, IsNil) + err := suite.svr.GetRaftCluster().GetRuleManager().Initialize( + suite.svr.GetRaftCluster().GetOpts().GetMaxReplicas(), + suite.svr.GetRaftCluster().GetOpts().GetLocationLabels()) + suite.NoError(err) } if len(tc.rules) > 0 { // add customized rule first and then remove default rule - err := s.svr.GetRaftCluster().GetRuleManager().SetRules(tc.rules) - c.Assert(err, IsNil) - err = s.svr.GetRaftCluster().GetRuleManager().DeleteRule("pd", "default") - c.Assert(err, IsNil) + err := suite.svr.GetRaftCluster().GetRuleManager().SetRules(tc.rules) + suite.NoError(err) + err = suite.svr.GetRaftCluster().GetRuleManager().DeleteRule("pd", "default") + suite.NoError(err) } var err error if tc.expectedError == nil { - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), tc.input, tu.StatusOK(c)) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), tc.input, tu.StatusOK(re)) } else { - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", s.urlPrefix), tc.input, - tu.StatusNotOK(c), tu.StringContain(c, tc.expectedError.Error())) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", suite.urlPrefix), tc.input, + tu.StatusNotOK(re), tu.StringContain(re, tc.expectedError.Error())) } - c.Assert(err, IsNil) + suite.NoError(err) if len(tc.expectSteps) > 0 { - operator = mustReadURL(c, regionURL) - c.Assert(strings.Contains(operator, tc.expectSteps), IsTrue) + operator = mustReadURL(re, regionURL) + suite.Contains(operator, tc.expectSteps) } _, err = apiutil.DoDelete(testDialClient, regionURL) - c.Assert(err, IsNil) + suite.NoError(err) } } -func mustPutStore(c *C, svr *server.Server, id uint64, state metapb.StoreState, nodeState metapb.NodeState, labels []*metapb.StoreLabel) { +func mustPutStore(re *require.Assertions, svr 
*server.Server, id uint64, state metapb.StoreState, nodeState metapb.NodeState, labels []*metapb.StoreLabel) { s := &server.GrpcServer{Server: svr} _, err := s.PutStore(context.Background(), &pdpb.PutStoreRequest{ Header: &pdpb.RequestHeader{ClusterId: svr.ClusterID()}, @@ -404,27 +417,27 @@ func mustPutStore(c *C, svr *server.Server, id uint64, state metapb.StoreState, Version: versioninfo.MinSupportedVersion(versioninfo.Version2_0).String(), }, }) - c.Assert(err, IsNil) + re.NoError(err) if state == metapb.StoreState_Up { _, err = s.StoreHeartbeat(context.Background(), &pdpb.StoreHeartbeatRequest{ Header: &pdpb.RequestHeader{ClusterId: svr.ClusterID()}, Stats: &pdpb.StoreStats{StoreId: id}, }) - c.Assert(err, IsNil) + re.NoError(err) } } -func mustRegionHeartbeat(c *C, svr *server.Server, region *core.RegionInfo) { +func mustRegionHeartbeat(re *require.Assertions, svr *server.Server, region *core.RegionInfo) { cluster := svr.GetRaftCluster() err := cluster.HandleRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) } -func mustReadURL(c *C, url string) string { +func mustReadURL(re *require.Assertions, url string) string { res, err := testDialClient.Get(url) - c.Assert(err, IsNil) + re.NoError(err) defer res.Body.Close() data, err := io.ReadAll(res.Body) - c.Assert(err, IsNil) + re.NoError(err) return string(data) } diff --git a/server/api/pprof_test.go b/server/api/pprof_test.go index 326b4d09005..3d80a325758 100644 --- a/server/api/pprof_test.go +++ b/server/api/pprof_test.go @@ -18,41 +18,46 @@ import ( "bytes" "fmt" "io/ioutil" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/server" ) -var _ = Suite(&ProfSuite{}) - -type ProfSuite struct { +type profTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *ProfSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestProfTestSuite(t *testing.T) { + suite.Run(t, new(profTestSuite)) +} + +func (suite *profTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/debug", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/debug", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *ProfSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *profTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *ProfSuite) TestGetZip(c *C) { - rsp, err := testDialClient.Get(s.urlPrefix + "/pprof/zip?" + "seconds=5s") - c.Assert(err, IsNil) +func (suite *profTestSuite) TestGetZip() { + rsp, err := testDialClient.Get(suite.urlPrefix + "/pprof/zip?" + "seconds=5s") + suite.NoError(err) defer rsp.Body.Close() body, err := ioutil.ReadAll(rsp.Body) - c.Assert(err, IsNil) - c.Assert(body, NotNil) + suite.NoError(err) + suite.NotNil(body) zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) - c.Assert(err, IsNil) - c.Assert(zipReader.File, HasLen, 7) + suite.NoError(err) + suite.Len(zipReader.File, 7) } diff --git a/server/api/region.go b/server/api/region.go index a4ae3fe0df9..9756f12e3c2 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -137,12 +137,12 @@ func fromPBReplicationStatus(s *replication_modepb.RegionReplicationStatus) *Rep } } -// NewRegionInfo create a new api RegionInfo. 
-func NewRegionInfo(r *core.RegionInfo) *RegionInfo { +// NewAPIRegionInfo create a new API RegionInfo. +func NewAPIRegionInfo(r *core.RegionInfo) *RegionInfo { return InitRegion(r, &RegionInfo{}) } -// InitRegion init a new api RegionInfo from the core.RegionInfo. +// InitRegion init a new API RegionInfo from the core.RegionInfo. func InitRegion(r *core.RegionInfo, s *RegionInfo) *RegionInfo { if r == nil { return nil @@ -228,7 +228,7 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { } regionInfo := rc.GetRegion(regionID) - h.rd.JSON(w, http.StatusOK, NewRegionInfo(regionInfo)) + h.rd.JSON(w, http.StatusOK, NewAPIRegionInfo(regionInfo)) } // @Tags region @@ -247,7 +247,7 @@ func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { return } regionInfo := rc.GetRegionByKey([]byte(key)) - h.rd.JSON(w, http.StatusOK, NewRegionInfo(regionInfo)) + h.rd.JSON(w, http.StatusOK, NewAPIRegionInfo(regionInfo)) } // @Tags region diff --git a/server/api/region_label_test.go b/server/api/region_label_test.go index 0e6aa7ccf62..0165ec7d37e 100644 --- a/server/api/region_label_test.go +++ b/server/api/region_label_test.go @@ -19,41 +19,47 @@ import ( "fmt" "net/url" "sort" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/schedule/labeler" ) -var _ = Suite(&testRegionLabelSuite{}) - -type testRegionLabelSuite struct { +type regionLabelTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testRegionLabelSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestRegionLabelTestSuite(t *testing.T) { + suite.Run(t, new(regionLabelTestSuite)) +} + +func (suite *regionLabelTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/config/region-label/", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/config/region-label/", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testRegionLabelSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *regionLabelTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testRegionLabelSuite) TestGetSet(c *C) { +func (suite *regionLabelTestSuite) TestGetSet() { + re := suite.Require() var resp []*labeler.LabelRule - err := tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"rules", &resp) - c.Assert(err, IsNil) - c.Assert(resp, HasLen, 0) + err := tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"rules", &resp) + suite.NoError(err) + suite.Len(resp, 0) rules := []*labeler.LabelRule{ {ID: "rule1", Labels: []labeler.RegionLabel{{Key: "k1", Value: "v1"}}, RuleType: "key-range", Data: makeKeyRanges("1234", "5678")}, @@ -63,27 +69,27 @@ func (s *testRegionLabelSuite) TestGetSet(c *C) { ruleIDs := []string{"rule1", "rule2/a/b", "rule3"} for _, rule := range rules { data, _ := json.Marshal(rule) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"rule", data, tu.StatusOK(re)) + suite.NoError(err) } for i, id := range ruleIDs { var rule labeler.LabelRule - err = tu.ReadGetJSON(c, testDialClient, 
s.urlPrefix+"rule/"+url.QueryEscape(id), &rule) - c.Assert(err, IsNil) - c.Assert(&rule, DeepEquals, rules[i]) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"rule/"+url.QueryEscape(id), &rule) + suite.NoError(err) + suite.Equal(rules[i], &rule) } - err = tu.ReadGetJSONWithBody(c, testDialClient, s.urlPrefix+"rules/ids", []byte(`["rule1", "rule3"]`), &resp) - c.Assert(err, IsNil) + err = tu.ReadGetJSONWithBody(re, testDialClient, suite.urlPrefix+"rules/ids", []byte(`["rule1", "rule3"]`), &resp) + suite.NoError(err) expects := []*labeler.LabelRule{rules[0], rules[2]} - c.Assert(resp, DeepEquals, expects) + suite.Equal(expects, resp) - _, err = apiutil.DoDelete(testDialClient, s.urlPrefix+"rule/"+url.QueryEscape("rule2/a/b")) - c.Assert(err, IsNil) - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"rules", &resp) - c.Assert(err, IsNil) + _, err = apiutil.DoDelete(testDialClient, suite.urlPrefix+"rule/"+url.QueryEscape("rule2/a/b")) + suite.NoError(err) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"rules", &resp) + suite.NoError(err) sort.Slice(resp, func(i, j int) bool { return resp[i].ID < resp[j].ID }) - c.Assert(resp, DeepEquals, []*labeler.LabelRule{rules[0], rules[2]}) + suite.Equal([]*labeler.LabelRule{rules[0], rules[2]}, resp) patch := labeler.LabelRulePatch{ SetRules: []*labeler.LabelRule{ @@ -92,12 +98,12 @@ func (s *testRegionLabelSuite) TestGetSet(c *C) { DeleteRules: []string{"rule1"}, } data, _ := json.Marshal(patch) - err = tu.CheckPatchJSON(testDialClient, s.urlPrefix+"rules", data, tu.StatusOK(c)) - c.Assert(err, IsNil) - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"rules", &resp) - c.Assert(err, IsNil) + err = tu.CheckPatchJSON(testDialClient, suite.urlPrefix+"rules", data, tu.StatusOK(re)) + suite.NoError(err) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"rules", &resp) + suite.NoError(err) sort.Slice(resp, func(i, j int) bool { return resp[i].ID < resp[j].ID }) - c.Assert(resp, DeepEquals, []*labeler.LabelRule{rules[1], rules[2]}) + suite.Equal([]*labeler.LabelRule{rules[1], rules[2]}, resp) } func makeKeyRanges(keys ...string) []interface{} { diff --git a/server/api/region_test.go b/server/api/region_test.go index 816fefcc092..168edd0e419 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -25,21 +25,19 @@ import ( "sort" "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule/placement" ) -var _ = Suite(&testRegionStructSuite{}) - -type testRegionStructSuite struct{} - -func (s *testRegionStructSuite) TestPeer(c *C) { +func TestPeer(t *testing.T) { + re := require.New(t) peers := []*metapb.Peer{ {Id: 1, StoreId: 10, Role: metapb.PeerRole_Voter}, {Id: 2, StoreId: 20, Role: metapb.PeerRole_Learner}, @@ -55,13 +53,14 @@ func (s *testRegionStructSuite) TestPeer(c *C) { } data, err := json.Marshal(fromPeerSlice(peers)) - c.Assert(err, IsNil) + re.NoError(err) var ret []map[string]interface{} - c.Assert(json.Unmarshal(data, &ret), IsNil) - c.Assert(ret, DeepEquals, expected) + re.NoError(json.Unmarshal(data, &ret)) + re.Equal(expected, ret) } -func (s *testRegionStructSuite) TestPeerStats(c *C) { +func TestPeerStats(t *testing.T) { + re := require.New(t) peers := []*pdpb.PeerStats{ {Peer: &metapb.Peer{Id: 1, StoreId: 10, Role: metapb.PeerRole_Voter}, DownSeconds: 0}, {Peer: &metapb.Peer{Id: 2, StoreId: 20, Role: metapb.PeerRole_Learner}, DownSeconds: 1}, @@ -77,32 +76,36 @@ func (s *testRegionStructSuite) TestPeerStats(c *C) { } data, err := json.Marshal(fromPeerStatsSlice(peers)) - c.Assert(err, IsNil) + re.NoError(err) var ret []map[string]interface{} - c.Assert(json.Unmarshal(data, &ret), IsNil) - c.Assert(ret, DeepEquals, expected) + re.NoError(json.Unmarshal(data, &ret)) + re.Equal(expected, ret) } -var _ = Suite(&testRegionSuite{}) - -type testRegionSuite struct { +type regionTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testRegionSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestRegionTestSuite(t *testing.T) { + suite.Run(t, new(regionTestSuite)) +} + +func (suite *regionTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testRegionSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *regionTestSuite) TearDownSuite() { + suite.cleanup() } func newTestRegionInfo(regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { @@ -130,7 +133,7 @@ func newTestRegionInfo(regionID, storeID uint64, start, end []byte, opts ...core return region } -func (s *testRegionSuite) TestRegion(c *C) { +func (suite *regionTestSuite) TestRegion() { r := newTestRegionInfo(2, 1, []byte("a"), []byte("b")) buckets := &metapb.Buckets{ RegionId: 2, @@ -138,227 +141,236 @@ func (s *testRegionSuite) TestRegion(c *C) { Version: 1, } r.UpdateBuckets(buckets, r.GetBuckets()) - mustRegionHeartbeat(c, s.svr, r) - url := fmt.Sprintf("%s/region/id/%d", s.urlPrefix, r.GetID()) + re := suite.Require() + mustRegionHeartbeat(re, suite.svr, r) + url := fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, r.GetID()) r1 := &RegionInfo{} r1m := make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, url, r1), IsNil) + 
suite.NoError(tu.ReadGetJSON(re, testDialClient, url, r1)) r1.Adjust() - c.Assert(r1, DeepEquals, NewRegionInfo(r)) - c.Assert(tu.ReadGetJSON(c, testDialClient, url, &r1m), IsNil) - c.Assert(r1m["written_bytes"].(float64), Equals, float64(r.GetBytesWritten())) - c.Assert(r1m["written_keys"].(float64), Equals, float64(r.GetKeysWritten())) - c.Assert(r1m["read_bytes"].(float64), Equals, float64(r.GetBytesRead())) - c.Assert(r1m["read_keys"].(float64), Equals, float64(r.GetKeysRead())) + suite.Equal(NewAPIRegionInfo(r), r1) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, &r1m)) + suite.Equal(float64(r.GetBytesWritten()), r1m["written_bytes"].(float64)) + suite.Equal(float64(r.GetKeysWritten()), r1m["written_keys"].(float64)) + suite.Equal(float64(r.GetBytesRead()), r1m["read_bytes"].(float64)) + suite.Equal(float64(r.GetKeysRead()), r1m["read_keys"].(float64)) keys := r1m["buckets"].([]interface{}) - c.Assert(keys, HasLen, 2) - c.Assert(keys[0].(string), Equals, core.HexRegionKeyStr([]byte("a"))) - c.Assert(keys[1].(string), Equals, core.HexRegionKeyStr([]byte("b"))) + suite.Len(keys, 2) + suite.Equal(core.HexRegionKeyStr([]byte("a")), keys[0].(string)) + suite.Equal(core.HexRegionKeyStr([]byte("b")), keys[1].(string)) - url = fmt.Sprintf("%s/region/key/%s", s.urlPrefix, "a") + url = fmt.Sprintf("%s/region/key/%s", suite.urlPrefix, "a") r2 := &RegionInfo{} - c.Assert(tu.ReadGetJSON(c, testDialClient, url, r2), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, r2)) r2.Adjust() - c.Assert(r2, DeepEquals, NewRegionInfo(r)) + suite.Equal(NewAPIRegionInfo(r), r2) } -func (s *testRegionSuite) TestRegionCheck(c *C) { +func (suite *regionTestSuite) TestRegionCheck() { r := newTestRegionInfo(2, 1, []byte("a"), []byte("b")) downPeer := &metapb.Peer{Id: 13, StoreId: 2} r = r.Clone(core.WithAddPeer(downPeer), core.WithDownPeers([]*pdpb.PeerStats{{Peer: downPeer, DownSeconds: 3600}}), core.WithPendingPeers([]*metapb.Peer{downPeer})) - mustRegionHeartbeat(c, s.svr, r) - url := fmt.Sprintf("%s/region/id/%d", s.urlPrefix, r.GetID()) + re := suite.Require() + mustRegionHeartbeat(re, suite.svr, r) + url := fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, r.GetID()) r1 := &RegionInfo{} - c.Assert(tu.ReadGetJSON(c, testDialClient, url, r1), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, r1)) r1.Adjust() - c.Assert(r1, DeepEquals, NewRegionInfo(r)) + suite.Equal(NewAPIRegionInfo(r), r1) - url = fmt.Sprintf("%s/regions/check/%s", s.urlPrefix, "down-peer") + url = fmt.Sprintf("%s/regions/check/%s", suite.urlPrefix, "down-peer") r2 := &RegionsInfo{} - c.Assert(tu.ReadGetJSON(c, testDialClient, url, r2), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, r2)) r2.Adjust() - c.Assert(r2, DeepEquals, &RegionsInfo{Count: 1, Regions: []RegionInfo{*NewRegionInfo(r)}}) + suite.Equal(&RegionsInfo{Count: 1, Regions: []RegionInfo{*NewAPIRegionInfo(r)}}, r2) - url = fmt.Sprintf("%s/regions/check/%s", s.urlPrefix, "pending-peer") + url = fmt.Sprintf("%s/regions/check/%s", suite.urlPrefix, "pending-peer") r3 := &RegionsInfo{} - c.Assert(tu.ReadGetJSON(c, testDialClient, url, r3), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, r3)) r3.Adjust() - c.Assert(r3, DeepEquals, &RegionsInfo{Count: 1, Regions: []RegionInfo{*NewRegionInfo(r)}}) + suite.Equal(&RegionsInfo{Count: 1, Regions: []RegionInfo{*NewAPIRegionInfo(r)}}, r3) - url = fmt.Sprintf("%s/regions/check/%s", s.urlPrefix, "offline-peer") + url = fmt.Sprintf("%s/regions/check/%s", suite.urlPrefix, "offline-peer") r4 
:= &RegionsInfo{} - c.Assert(tu.ReadGetJSON(c, testDialClient, url, r4), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, r4)) r4.Adjust() - c.Assert(r4, DeepEquals, &RegionsInfo{Count: 0, Regions: []RegionInfo{}}) + suite.Equal(&RegionsInfo{Count: 0, Regions: []RegionInfo{}}, r4) r = r.Clone(core.SetApproximateSize(1)) - mustRegionHeartbeat(c, s.svr, r) - url = fmt.Sprintf("%s/regions/check/%s", s.urlPrefix, "empty-region") + mustRegionHeartbeat(re, suite.svr, r) + url = fmt.Sprintf("%s/regions/check/%s", suite.urlPrefix, "empty-region") r5 := &RegionsInfo{} - c.Assert(tu.ReadGetJSON(c, testDialClient, url, r5), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, r5)) r5.Adjust() - c.Assert(r5, DeepEquals, &RegionsInfo{Count: 1, Regions: []RegionInfo{*NewRegionInfo(r)}}) + suite.Equal(&RegionsInfo{Count: 1, Regions: []RegionInfo{*NewAPIRegionInfo(r)}}, r5) r = r.Clone(core.SetApproximateSize(1)) - mustRegionHeartbeat(c, s.svr, r) - url = fmt.Sprintf("%s/regions/check/%s", s.urlPrefix, "hist-size") + mustRegionHeartbeat(re, suite.svr, r) + url = fmt.Sprintf("%s/regions/check/%s", suite.urlPrefix, "hist-size") r6 := make([]*histItem, 1) - c.Assert(tu.ReadGetJSON(c, testDialClient, url, &r6), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, &r6)) histSizes := []*histItem{{Start: 1, End: 1, Count: 1}} - c.Assert(r6, DeepEquals, histSizes) + suite.Equal(histSizes, r6) r = r.Clone(core.SetApproximateKeys(1000)) - mustRegionHeartbeat(c, s.svr, r) - url = fmt.Sprintf("%s/regions/check/%s", s.urlPrefix, "hist-keys") + mustRegionHeartbeat(re, suite.svr, r) + url = fmt.Sprintf("%s/regions/check/%s", suite.urlPrefix, "hist-keys") r7 := make([]*histItem, 1) - c.Assert(tu.ReadGetJSON(c, testDialClient, url, &r7), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, &r7)) histKeys := []*histItem{{Start: 1000, End: 1999, Count: 1}} - c.Assert(r7, DeepEquals, histKeys) + suite.Equal(histKeys, r7) } -func (s *testRegionSuite) TestRegions(c *C) { +func (suite *regionTestSuite) TestRegions() { rs := []*core.RegionInfo{ newTestRegionInfo(2, 1, []byte("a"), []byte("b")), newTestRegionInfo(3, 1, []byte("b"), []byte("c")), newTestRegionInfo(4, 2, []byte("c"), []byte("d")), } regions := make([]RegionInfo, 0, len(rs)) + re := suite.Require() for _, r := range rs { - regions = append(regions, *NewRegionInfo(r)) - mustRegionHeartbeat(c, s.svr, r) + regions = append(regions, *NewAPIRegionInfo(r)) + mustRegionHeartbeat(re, suite.svr, r) } - url := fmt.Sprintf("%s/regions", s.urlPrefix) + url := fmt.Sprintf("%s/regions", suite.urlPrefix) RegionsInfo := &RegionsInfo{} - err := tu.ReadGetJSON(c, testDialClient, url, RegionsInfo) - c.Assert(err, IsNil) - c.Assert(RegionsInfo.Count, Equals, len(regions)) + err := tu.ReadGetJSON(re, testDialClient, url, RegionsInfo) + suite.NoError(err) + suite.Len(regions, RegionsInfo.Count) sort.Slice(RegionsInfo.Regions, func(i, j int) bool { return RegionsInfo.Regions[i].ID < RegionsInfo.Regions[j].ID }) for i, r := range RegionsInfo.Regions { - c.Assert(r.ID, Equals, regions[i].ID) - c.Assert(r.ApproximateSize, Equals, regions[i].ApproximateSize) - c.Assert(r.ApproximateKeys, Equals, regions[i].ApproximateKeys) + suite.Equal(regions[i].ID, r.ID) + suite.Equal(regions[i].ApproximateSize, r.ApproximateSize) + suite.Equal(regions[i].ApproximateKeys, r.ApproximateKeys) } } -func (s *testRegionSuite) TestStoreRegions(c *C) { +func (suite *regionTestSuite) TestStoreRegions() { + re := suite.Require() r1 := newTestRegionInfo(2, 1, []byte("a"), 
[]byte("b")) r2 := newTestRegionInfo(3, 1, []byte("b"), []byte("c")) r3 := newTestRegionInfo(4, 2, []byte("c"), []byte("d")) - mustRegionHeartbeat(c, s.svr, r1) - mustRegionHeartbeat(c, s.svr, r2) - mustRegionHeartbeat(c, s.svr, r3) + mustRegionHeartbeat(re, suite.svr, r1) + mustRegionHeartbeat(re, suite.svr, r2) + mustRegionHeartbeat(re, suite.svr, r3) regionIDs := []uint64{2, 3} - url := fmt.Sprintf("%s/regions/store/%d", s.urlPrefix, 1) + url := fmt.Sprintf("%s/regions/store/%d", suite.urlPrefix, 1) r4 := &RegionsInfo{} - err := tu.ReadGetJSON(c, testDialClient, url, r4) - c.Assert(err, IsNil) - c.Assert(r4.Count, Equals, len(regionIDs)) + err := tu.ReadGetJSON(re, testDialClient, url, r4) + suite.NoError(err) + suite.Len(regionIDs, r4.Count) sort.Slice(r4.Regions, func(i, j int) bool { return r4.Regions[i].ID < r4.Regions[j].ID }) for i, r := range r4.Regions { - c.Assert(r.ID, Equals, regionIDs[i]) + suite.Equal(regionIDs[i], r.ID) } regionIDs = []uint64{4} - url = fmt.Sprintf("%s/regions/store/%d", s.urlPrefix, 2) + url = fmt.Sprintf("%s/regions/store/%d", suite.urlPrefix, 2) r5 := &RegionsInfo{} - err = tu.ReadGetJSON(c, testDialClient, url, r5) - c.Assert(err, IsNil) - c.Assert(r5.Count, Equals, len(regionIDs)) + err = tu.ReadGetJSON(re, testDialClient, url, r5) + suite.NoError(err) + suite.Len(regionIDs, r5.Count) for i, r := range r5.Regions { - c.Assert(r.ID, Equals, regionIDs[i]) + suite.Equal(regionIDs[i], r.ID) } regionIDs = []uint64{} - url = fmt.Sprintf("%s/regions/store/%d", s.urlPrefix, 3) + url = fmt.Sprintf("%s/regions/store/%d", suite.urlPrefix, 3) r6 := &RegionsInfo{} - err = tu.ReadGetJSON(c, testDialClient, url, r6) - c.Assert(err, IsNil) - c.Assert(r6.Count, Equals, len(regionIDs)) + err = tu.ReadGetJSON(re, testDialClient, url, r6) + suite.NoError(err) + suite.Len(regionIDs, r6.Count) } -func (s *testRegionSuite) TestTopFlow(c *C) { +func (suite *regionTestSuite) TestTopFlow() { + re := suite.Require() r1 := newTestRegionInfo(1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(1000), core.SetReadBytes(1000), core.SetRegionConfVer(1), core.SetRegionVersion(1)) - mustRegionHeartbeat(c, s.svr, r1) + mustRegionHeartbeat(re, suite.svr, r1) r2 := newTestRegionInfo(2, 1, []byte("b"), []byte("c"), core.SetWrittenBytes(2000), core.SetReadBytes(0), core.SetRegionConfVer(2), core.SetRegionVersion(3)) - mustRegionHeartbeat(c, s.svr, r2) + mustRegionHeartbeat(re, suite.svr, r2) r3 := newTestRegionInfo(3, 1, []byte("c"), []byte("d"), core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2)) - mustRegionHeartbeat(c, s.svr, r3) - s.checkTopRegions(c, fmt.Sprintf("%s/regions/writeflow", s.urlPrefix), []uint64{2, 1, 3}) - s.checkTopRegions(c, fmt.Sprintf("%s/regions/readflow", s.urlPrefix), []uint64{1, 3, 2}) - s.checkTopRegions(c, fmt.Sprintf("%s/regions/writeflow?limit=2", s.urlPrefix), []uint64{2, 1}) - s.checkTopRegions(c, fmt.Sprintf("%s/regions/confver", s.urlPrefix), []uint64{3, 2, 1}) - s.checkTopRegions(c, fmt.Sprintf("%s/regions/confver?limit=2", s.urlPrefix), []uint64{3, 2}) - s.checkTopRegions(c, fmt.Sprintf("%s/regions/version", s.urlPrefix), []uint64{2, 3, 1}) - s.checkTopRegions(c, fmt.Sprintf("%s/regions/version?limit=2", s.urlPrefix), []uint64{2, 3}) -} - -func (s *testRegionSuite) TestTopSize(c *C) { + mustRegionHeartbeat(re, suite.svr, r3) + suite.checkTopRegions(fmt.Sprintf("%s/regions/writeflow", suite.urlPrefix), []uint64{2, 1, 3}) + suite.checkTopRegions(fmt.Sprintf("%s/regions/readflow", suite.urlPrefix), 
[]uint64{1, 3, 2}) + suite.checkTopRegions(fmt.Sprintf("%s/regions/writeflow?limit=2", suite.urlPrefix), []uint64{2, 1}) + suite.checkTopRegions(fmt.Sprintf("%s/regions/confver", suite.urlPrefix), []uint64{3, 2, 1}) + suite.checkTopRegions(fmt.Sprintf("%s/regions/confver?limit=2", suite.urlPrefix), []uint64{3, 2}) + suite.checkTopRegions(fmt.Sprintf("%s/regions/version", suite.urlPrefix), []uint64{2, 3, 1}) + suite.checkTopRegions(fmt.Sprintf("%s/regions/version?limit=2", suite.urlPrefix), []uint64{2, 3}) +} + +func (suite *regionTestSuite) TestTopSize() { + re := suite.Require() baseOpt := []core.RegionCreateOption{core.SetRegionConfVer(3), core.SetRegionVersion(3)} opt := core.SetApproximateSize(1000) r1 := newTestRegionInfo(7, 1, []byte("a"), []byte("b"), append(baseOpt, opt)...) - mustRegionHeartbeat(c, s.svr, r1) + mustRegionHeartbeat(re, suite.svr, r1) opt = core.SetApproximateSize(900) r2 := newTestRegionInfo(8, 1, []byte("b"), []byte("c"), append(baseOpt, opt)...) - mustRegionHeartbeat(c, s.svr, r2) + mustRegionHeartbeat(re, suite.svr, r2) opt = core.SetApproximateSize(800) r3 := newTestRegionInfo(9, 1, []byte("c"), []byte("d"), append(baseOpt, opt)...) - mustRegionHeartbeat(c, s.svr, r3) + mustRegionHeartbeat(re, suite.svr, r3) // query with limit - s.checkTopRegions(c, fmt.Sprintf("%s/regions/size?limit=%d", s.urlPrefix, 2), []uint64{7, 8}) + suite.checkTopRegions(fmt.Sprintf("%s/regions/size?limit=%d", suite.urlPrefix, 2), []uint64{7, 8}) } -func (s *testRegionSuite) TestAccelerateRegionsScheduleInRange(c *C) { +func (suite *regionTestSuite) TestAccelerateRegionsScheduleInRange() { + re := suite.Require() r1 := newTestRegionInfo(557, 13, []byte("a1"), []byte("a2")) r2 := newTestRegionInfo(558, 14, []byte("a2"), []byte("a3")) r3 := newTestRegionInfo(559, 15, []byte("a3"), []byte("a4")) - mustRegionHeartbeat(c, s.svr, r1) - mustRegionHeartbeat(c, s.svr, r2) - mustRegionHeartbeat(c, s.svr, r3) + mustRegionHeartbeat(re, suite.svr, r1) + mustRegionHeartbeat(re, suite.svr, r2) + mustRegionHeartbeat(re, suite.svr, r3) body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/accelerate-schedule", s.urlPrefix), []byte(body), tu.StatusOK(c)) - c.Assert(err, IsNil) - idList := s.svr.GetRaftCluster().GetSuspectRegions() - c.Assert(idList, HasLen, 2) + err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/accelerate-schedule", suite.urlPrefix), []byte(body), tu.StatusOK(re)) + suite.NoError(err) + idList := suite.svr.GetRaftCluster().GetSuspectRegions() + suite.Len(idList, 2) } -func (s *testRegionSuite) TestScatterRegions(c *C) { +func (suite *regionTestSuite) TestScatterRegions() { + re := suite.Require() r1 := newTestRegionInfo(601, 13, []byte("b1"), []byte("b2")) r1.GetMeta().Peers = append(r1.GetMeta().Peers, &metapb.Peer{Id: 5, StoreId: 14}, &metapb.Peer{Id: 6, StoreId: 15}) r2 := newTestRegionInfo(602, 13, []byte("b2"), []byte("b3")) r2.GetMeta().Peers = append(r2.GetMeta().Peers, &metapb.Peer{Id: 7, StoreId: 14}, &metapb.Peer{Id: 8, StoreId: 15}) r3 := newTestRegionInfo(603, 13, []byte("b4"), []byte("b4")) r3.GetMeta().Peers = append(r3.GetMeta().Peers, &metapb.Peer{Id: 9, StoreId: 14}, &metapb.Peer{Id: 10, StoreId: 15}) - mustRegionHeartbeat(c, s.svr, r1) - mustRegionHeartbeat(c, s.svr, r2) - mustRegionHeartbeat(c, s.svr, r3) - mustPutStore(c, s.svr, 13, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) - 
mustPutStore(c, s.svr, 14, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) - mustPutStore(c, s.svr, 15, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) - mustPutStore(c, s.svr, 16, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) + mustRegionHeartbeat(re, suite.svr, r1) + mustRegionHeartbeat(re, suite.svr, r2) + mustRegionHeartbeat(re, suite.svr, r3) + mustPutStore(re, suite.svr, 13, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) + mustPutStore(re, suite.svr, 14, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) + mustPutStore(re, suite.svr, 15, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) + mustPutStore(re, suite.svr, 16, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("b1")), hex.EncodeToString([]byte("b3"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", s.urlPrefix), []byte(body), tu.StatusOK(c)) - c.Assert(err, IsNil) - op1 := s.svr.GetRaftCluster().GetOperatorController().GetOperator(601) - op2 := s.svr.GetRaftCluster().GetOperatorController().GetOperator(602) - op3 := s.svr.GetRaftCluster().GetOperatorController().GetOperator(603) + err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", suite.urlPrefix), []byte(body), tu.StatusOK(re)) + suite.NoError(err) + op1 := suite.svr.GetRaftCluster().GetOperatorController().GetOperator(601) + op2 := suite.svr.GetRaftCluster().GetOperatorController().GetOperator(602) + op3 := suite.svr.GetRaftCluster().GetOperatorController().GetOperator(603) // At least one operator used to scatter region - c.Assert(op1 != nil || op2 != nil || op3 != nil, IsTrue) + suite.True(op1 != nil || op2 != nil || op3 != nil) body = `{"regions_id": [601, 602, 603]}` - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", s.urlPrefix), []byte(body), tu.StatusOK(c)) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", suite.urlPrefix), []byte(body), tu.StatusOK(re)) + suite.NoError(err) } -func (s *testRegionSuite) TestSplitRegions(c *C) { +func (suite *regionTestSuite) TestSplitRegions() { + re := suite.Require() r1 := newTestRegionInfo(601, 13, []byte("aaa"), []byte("ggg")) r1.GetMeta().Peers = append(r1.GetMeta().Peers, &metapb.Peer{Id: 5, StoreId: 13}, &metapb.Peer{Id: 6, StoreId: 13}) - mustRegionHeartbeat(c, s.svr, r1) - mustPutStore(c, s.svr, 13, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) + mustRegionHeartbeat(re, suite.svr, r1) + mustPutStore(re, suite.svr, 13, metapb.StoreState_Up, metapb.NodeState_Serving, []*metapb.StoreLabel{}) newRegionID := uint64(11) body := fmt.Sprintf(`{"retry_limit":%v, "split_keys": ["%s","%s","%s"]}`, 3, hex.EncodeToString([]byte("bbb")), @@ -370,27 +382,27 @@ func (s *testRegionSuite) TestSplitRegions(c *C) { NewRegionsID []uint64 `json:"regions-id"` }{} err := json.Unmarshal(res, s) - c.Assert(err, IsNil) - c.Assert(s.ProcessedPercentage, Equals, 100) - c.Assert(s.NewRegionsID, DeepEquals, []uint64{newRegionID}) + suite.NoError(err) + suite.Equal(100, s.ProcessedPercentage) + suite.Equal([]uint64{newRegionID}, s.NewRegionsID) } - c.Assert(failpoint.Enable("github.com/tikv/pd/server/api/splitResponses", fmt.Sprintf("return(%v)", newRegionID)), IsNil) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/split", s.urlPrefix), 
[]byte(body), checkOpt) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/api/splitResponses"), IsNil) - c.Assert(err, IsNil) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/api/splitResponses", fmt.Sprintf("return(%v)", newRegionID))) + err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/split", suite.urlPrefix), []byte(body), checkOpt) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/api/splitResponses")) + suite.NoError(err) } -func (s *testRegionSuite) checkTopRegions(c *C, url string, regionIDs []uint64) { +func (suite *regionTestSuite) checkTopRegions(url string, regionIDs []uint64) { regions := &RegionsInfo{} - err := tu.ReadGetJSON(c, testDialClient, url, regions) - c.Assert(err, IsNil) - c.Assert(regions.Count, Equals, len(regionIDs)) + err := tu.ReadGetJSON(suite.Require(), testDialClient, url, regions) + suite.NoError(err) + suite.Len(regionIDs, regions.Count) for i, r := range regions.Regions { - c.Assert(r.ID, Equals, regionIDs[i]) + suite.Equal(regionIDs[i], r.ID) } } -func (s *testRegionSuite) TestTopN(c *C) { +func (suite *regionTestSuite) TestTopN() { writtenBytes := []uint64{10, 10, 9, 5, 3, 2, 2, 1, 0, 0} for n := 0; n <= len(writtenBytes)+1; n++ { regions := make([]*core.RegionInfo, 0, len(writtenBytes)) @@ -401,138 +413,150 @@ func (s *testRegionSuite) TestTopN(c *C) { } topN := TopNRegions(regions, func(a, b *core.RegionInfo) bool { return a.GetBytesWritten() < b.GetBytesWritten() }, n) if n > len(writtenBytes) { - c.Assert(len(topN), Equals, len(writtenBytes)) + suite.Len(writtenBytes, len(topN)) } else { - c.Assert(topN, HasLen, n) + suite.Len(topN, n) } for i := range topN { - c.Assert(topN[i].GetBytesWritten(), Equals, writtenBytes[i]) + suite.Equal(writtenBytes[i], topN[i].GetBytesWritten()) } } } -var _ = Suite(&testGetRegionSuite{}) - -type testGetRegionSuite struct { +type getRegionTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testGetRegionSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestGetRegionTestSuite(t *testing.T) { + suite.Run(t, new(getRegionTestSuite)) +} + +func (suite *getRegionTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testGetRegionSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *getRegionTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testGetRegionSuite) TestRegionKey(c *C) { +func (suite *getRegionTestSuite) TestRegionKey() { + re := suite.Require() r := newTestRegionInfo(99, 1, []byte{0xFF, 0xFF, 0xAA}, []byte{0xFF, 0xFF, 0xCC}, core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2)) - mustRegionHeartbeat(c, s.svr, r) - url := fmt.Sprintf("%s/region/key/%s", s.urlPrefix, url.QueryEscape(string([]byte{0xFF, 0xFF, 0xBB}))) + mustRegionHeartbeat(re, suite.svr, r) + url := fmt.Sprintf("%s/region/key/%s", suite.urlPrefix, url.QueryEscape(string([]byte{0xFF, 0xFF, 0xBB}))) RegionInfo := &RegionInfo{} - err := tu.ReadGetJSON(c, testDialClient, url, RegionInfo) - c.Assert(err, IsNil) - c.Assert(r.GetID(), Equals, RegionInfo.ID) + err := tu.ReadGetJSON(re, 
testDialClient, url, RegionInfo) + suite.NoError(err) + suite.Equal(RegionInfo.ID, r.GetID()) } -func (s *testGetRegionSuite) TestScanRegionByKeys(c *C) { +func (suite *getRegionTestSuite) TestScanRegionByKeys() { + re := suite.Require() r1 := newTestRegionInfo(2, 1, []byte("a"), []byte("b")) r2 := newTestRegionInfo(3, 1, []byte("b"), []byte("c")) r3 := newTestRegionInfo(4, 2, []byte("c"), []byte("e")) r4 := newTestRegionInfo(5, 2, []byte("x"), []byte("z")) r := newTestRegionInfo(99, 1, []byte{0xFF, 0xFF, 0xAA}, []byte{0xFF, 0xFF, 0xCC}, core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2)) - mustRegionHeartbeat(c, s.svr, r1) - mustRegionHeartbeat(c, s.svr, r2) - mustRegionHeartbeat(c, s.svr, r3) - mustRegionHeartbeat(c, s.svr, r4) - mustRegionHeartbeat(c, s.svr, r) + mustRegionHeartbeat(re, suite.svr, r1) + mustRegionHeartbeat(re, suite.svr, r2) + mustRegionHeartbeat(re, suite.svr, r3) + mustRegionHeartbeat(re, suite.svr, r4) + mustRegionHeartbeat(re, suite.svr, r) - url := fmt.Sprintf("%s/regions/key?key=%s", s.urlPrefix, "b") + url := fmt.Sprintf("%s/regions/key?key=%s", suite.urlPrefix, "b") regionIDs := []uint64{3, 4, 5, 99} regions := &RegionsInfo{} - err := tu.ReadGetJSON(c, testDialClient, url, regions) - c.Assert(err, IsNil) - c.Assert(regionIDs, HasLen, regions.Count) + err := tu.ReadGetJSON(re, testDialClient, url, regions) + suite.NoError(err) + suite.Len(regionIDs, regions.Count) for i, v := range regionIDs { - c.Assert(v, Equals, regions.Regions[i].ID) + suite.Equal(regions.Regions[i].ID, v) } - url = fmt.Sprintf("%s/regions/key?key=%s", s.urlPrefix, "d") + url = fmt.Sprintf("%s/regions/key?key=%s", suite.urlPrefix, "d") regionIDs = []uint64{4, 5, 99} regions = &RegionsInfo{} - err = tu.ReadGetJSON(c, testDialClient, url, regions) - c.Assert(err, IsNil) - c.Assert(regionIDs, HasLen, regions.Count) + err = tu.ReadGetJSON(re, testDialClient, url, regions) + suite.NoError(err) + suite.Len(regionIDs, regions.Count) for i, v := range regionIDs { - c.Assert(v, Equals, regions.Regions[i].ID) + suite.Equal(regions.Regions[i].ID, v) } - url = fmt.Sprintf("%s/regions/key?key=%s", s.urlPrefix, "g") + url = fmt.Sprintf("%s/regions/key?key=%s", suite.urlPrefix, "g") regionIDs = []uint64{5, 99} regions = &RegionsInfo{} - err = tu.ReadGetJSON(c, testDialClient, url, regions) - c.Assert(err, IsNil) - c.Assert(regionIDs, HasLen, regions.Count) + err = tu.ReadGetJSON(re, testDialClient, url, regions) + suite.NoError(err) + suite.Len(regionIDs, regions.Count) for i, v := range regionIDs { - c.Assert(v, Equals, regions.Regions[i].ID) + suite.Equal(regions.Regions[i].ID, v) } - url = fmt.Sprintf("%s/regions/key?end_key=%s", s.urlPrefix, "e") + url = fmt.Sprintf("%s/regions/key?end_key=%s", suite.urlPrefix, "e") regionIDs = []uint64{2, 3, 4} regions = &RegionsInfo{} - err = tu.ReadGetJSON(c, testDialClient, url, regions) - c.Assert(err, IsNil) - c.Assert(len(regionIDs), Equals, regions.Count) + err = tu.ReadGetJSON(re, testDialClient, url, regions) + suite.NoError(err) + suite.Equal(regions.Count, len(regionIDs)) for i, v := range regionIDs { - c.Assert(v, Equals, regions.Regions[i].ID) + suite.Equal(regions.Regions[i].ID, v) } - url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s", s.urlPrefix, "b", "g") + url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s", suite.urlPrefix, "b", "g") regionIDs = []uint64{3, 4} regions = &RegionsInfo{} - err = tu.ReadGetJSON(c, testDialClient, url, regions) - c.Assert(err, IsNil) - c.Assert(len(regionIDs), Equals, 
regions.Count) + err = tu.ReadGetJSON(re, testDialClient, url, regions) + suite.NoError(err) + suite.Equal(regions.Count, len(regionIDs)) for i, v := range regionIDs { - c.Assert(v, Equals, regions.Regions[i].ID) + suite.Equal(regions.Regions[i].ID, v) } - url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s", s.urlPrefix, "b", []byte{0xFF, 0xFF, 0xCC}) + url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s", suite.urlPrefix, "b", []byte{0xFF, 0xFF, 0xCC}) regionIDs = []uint64{3, 4, 5, 99} regions = &RegionsInfo{} - err = tu.ReadGetJSON(c, testDialClient, url, regions) - c.Assert(err, IsNil) - c.Assert(len(regionIDs), Equals, regions.Count) + err = tu.ReadGetJSON(re, testDialClient, url, regions) + suite.NoError(err) + suite.Equal(regions.Count, len(regionIDs)) for i, v := range regionIDs { - c.Assert(v, Equals, regions.Regions[i].ID) + suite.Equal(regions.Regions[i].ID, v) } } // Start a new test suite to prevent from being interfered by other tests. -var _ = Suite(&testGetRegionRangeHolesSuite{}) -type testGetRegionRangeHolesSuite struct { +type getRegionRangeHolesTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testGetRegionRangeHolesSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) +func TestGetRegionRangeHolesTestSuite(t *testing.T) { + suite.Run(t, new(getRegionRangeHolesTestSuite)) } -func (s *testGetRegionRangeHolesSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *getRegionRangeHolesTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + mustBootstrapCluster(re, suite.svr) } -func (s *testGetRegionRangeHolesSuite) TestRegionRangeHoles(c *C) { +func (suite *getRegionRangeHolesTestSuite) TearDownSuite() { + suite.cleanup() +} + +func (suite *getRegionRangeHolesTestSuite) TestRegionRangeHoles() { + re := suite.Require() // Missing r0 with range [0, 0xEA] r1 := newTestRegionInfo(2, 1, []byte{0xEA}, []byte{0xEB}) // Missing r2 with range [0xEB, 0xEC] @@ -540,54 +564,59 @@ func (s *testGetRegionRangeHolesSuite) TestRegionRangeHoles(c *C) { r4 := newTestRegionInfo(4, 2, []byte{0xED}, []byte{0xEE}) // Missing r5 with range [0xEE, 0xFE] r6 := newTestRegionInfo(5, 2, []byte{0xFE}, []byte{0xFF}) - mustRegionHeartbeat(c, s.svr, r1) - mustRegionHeartbeat(c, s.svr, r3) - mustRegionHeartbeat(c, s.svr, r4) - mustRegionHeartbeat(c, s.svr, r6) + mustRegionHeartbeat(re, suite.svr, r1) + mustRegionHeartbeat(re, suite.svr, r3) + mustRegionHeartbeat(re, suite.svr, r4) + mustRegionHeartbeat(re, suite.svr, r6) - url := fmt.Sprintf("%s/regions/range-holes", s.urlPrefix) + url := fmt.Sprintf("%s/regions/range-holes", suite.urlPrefix) rangeHoles := new([][]string) - c.Assert(tu.ReadGetJSON(c, testDialClient, url, rangeHoles), IsNil) - c.Assert(*rangeHoles, DeepEquals, [][]string{ + suite.NoError(tu.ReadGetJSON(re, testDialClient, url, rangeHoles)) + suite.Equal([][]string{ {"", core.HexRegionKeyStr(r1.GetStartKey())}, {core.HexRegionKeyStr(r1.GetEndKey()), core.HexRegionKeyStr(r3.GetStartKey())}, {core.HexRegionKeyStr(r4.GetEndKey()), core.HexRegionKeyStr(r6.GetStartKey())}, {core.HexRegionKeyStr(r6.GetEndKey()), ""}, - }) + }, *rangeHoles) } -var _ = Suite(&testRegionsReplicatedSuite{}) - 
-type testRegionsReplicatedSuite struct { +type regionsReplicatedTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testRegionsReplicatedSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestRegionsReplicatedTestSuite(t *testing.T) { + suite.Run(t, new(regionsReplicatedTestSuite)) +} - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) +func (suite *regionsReplicatedTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - mustBootstrapCluster(c, s.svr) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + + mustBootstrapCluster(re, suite.svr) } -func (s *testRegionsReplicatedSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *regionsReplicatedTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testRegionsReplicatedSuite) TestCheckRegionsReplicated(c *C) { +func (suite *regionsReplicatedTestSuite) TestCheckRegionsReplicated() { + re := suite.Require() // enable placement rule - c.Assert(tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config", []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(c)), IsNil) + suite.NoError(tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config", []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(re))) defer func() { - c.Assert(tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config", []byte(`{"enable-placement-rules":"false"}`), tu.StatusOK(c)), IsNil) + suite.NoError(tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config", []byte(`{"enable-placement-rules":"false"}`), tu.StatusOK(re))) }() // add test region r1 := newTestRegionInfo(2, 1, []byte("a"), []byte("b")) - mustRegionHeartbeat(c, s.svr, r1) + mustRegionHeartbeat(re, suite.svr, r1) // set the bundle bundle := []placement.GroupBundle{ @@ -605,49 +634,48 @@ func (s *testRegionsReplicatedSuite) TestCheckRegionsReplicated(c *C) { status := "" // invalid url - url := fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, s.urlPrefix, "_", "t") - err := tu.CheckGetJSON(testDialClient, url, nil, tu.Status(c, http.StatusBadRequest)) - c.Assert(err, IsNil) + url := fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, suite.urlPrefix, "_", "t") + err := tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) + suite.NoError(err) - url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, s.urlPrefix, hex.EncodeToString(r1.GetStartKey()), "_") - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(c, http.StatusBadRequest)) - c.Assert(err, IsNil) + url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, suite.urlPrefix, hex.EncodeToString(r1.GetStartKey()), "_") + err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) + suite.NoError(err) // correct test - url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, s.urlPrefix, hex.EncodeToString(r1.GetStartKey()), hex.EncodeToString(r1.GetEndKey())) + url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, suite.urlPrefix, hex.EncodeToString(r1.GetStartKey()), hex.EncodeToString(r1.GetEndKey())) // test one rule data, err := json.Marshal(bundle) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config/placement-rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) - - err = tu.ReadGetJSON(c, testDialClient, url, &status) - 
c.Assert(err, IsNil) - c.Assert(status, Equals, "REPLICATED") - - c.Assert(failpoint.Enable("github.com/tikv/pd/server/api/mockPending", "return(true)"), IsNil) - err = tu.ReadGetJSON(c, testDialClient, url, &status) - c.Assert(err, IsNil) - c.Assert(status, Equals, "PENDING") - c.Assert(failpoint.Disable("github.com/tikv/pd/server/api/mockPending"), IsNil) - + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) + + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.Equal("REPLICATED", status) + + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/api/mockPending", "return(true)")) + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.Equal("PENDING", status) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/api/mockPending")) // test multiple rules r1 = newTestRegionInfo(2, 1, []byte("a"), []byte("b")) r1.GetMeta().Peers = append(r1.GetMeta().Peers, &metapb.Peer{Id: 5, StoreId: 1}) - mustRegionHeartbeat(c, s.svr, r1) + mustRegionHeartbeat(re, suite.svr, r1) bundle[0].Rules = append(bundle[0].Rules, &placement.Rule{ ID: "bar", Index: 1, Role: "voter", Count: 1, }) data, err = json.Marshal(bundle) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config/placement-rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) - err = tu.ReadGetJSON(c, testDialClient, url, &status) - c.Assert(err, IsNil) - c.Assert(status, Equals, "REPLICATED") + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.Equal("REPLICATED", status) // test multiple bundles bundle = append(bundle, placement.GroupBundle{ @@ -660,21 +688,21 @@ func (s *testRegionsReplicatedSuite) TestCheckRegionsReplicated(c *C) { }, }) data, err = json.Marshal(bundle) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config/placement-rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) - err = tu.ReadGetJSON(c, testDialClient, url, &status) - c.Assert(err, IsNil) - c.Assert(status, Equals, "INPROGRESS") + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.Equal("INPROGRESS", status) r1 = newTestRegionInfo(2, 1, []byte("a"), []byte("b")) r1.GetMeta().Peers = append(r1.GetMeta().Peers, &metapb.Peer{Id: 5, StoreId: 1}, &metapb.Peer{Id: 6, StoreId: 1}, &metapb.Peer{Id: 7, StoreId: 1}) - mustRegionHeartbeat(c, s.svr, r1) + mustRegionHeartbeat(re, suite.svr, r1) - err = tu.ReadGetJSON(c, testDialClient, url, &status) - c.Assert(err, IsNil) - c.Assert(status, Equals, "REPLICATED") + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + suite.Equal("REPLICATED", status) } // Create n regions (0..n) of n stores (0..n). diff --git a/server/api/rule_test.go b/server/api/rule_test.go index 7e878168906..63af8b19c1c 100644 --- a/server/api/rule_test.go +++ b/server/api/rule_test.go @@ -20,42 +20,47 @@ import ( "fmt" "net/http" "net/url" + "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/schedule/placement" ) -var _ = Suite(&testRuleSuite{}) - -type testRuleSuite struct { +type ruleTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testRuleSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestRuleTestSuite(t *testing.T) { + suite.Run(t, new(ruleTestSuite)) +} + +func (suite *ruleTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/config", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/config", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) - PDServerCfg := s.svr.GetConfig().PDServerCfg + mustBootstrapCluster(re, suite.svr) + PDServerCfg := suite.svr.GetConfig().PDServerCfg PDServerCfg.KeyType = "raw" - err := s.svr.SetPDServerConfig(PDServerCfg) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, s.urlPrefix, []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(c)), IsNil) + err := suite.svr.SetPDServerConfig(PDServerCfg) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, suite.urlPrefix, []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(re))) } -func (s *testRuleSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *ruleTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testRuleSuite) TearDownTest(c *C) { +func (suite *ruleTestSuite) TearDownTest() { def := placement.GroupBundle{ ID: "pd", Rules: []*placement.Rule{ @@ -63,35 +68,35 @@ func (s *testRuleSuite) TearDownTest(c *C) { }, } data, err := json.Marshal([]placement.GroupBundle{def}) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/placement-rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(suite.Require())) + suite.NoError(err) } -func (s *testRuleSuite) TestSet(c *C) { +func (suite *ruleTestSuite) TestSet() { rule := placement.Rule{GroupID: "a", ID: "10", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} successData, err := json.Marshal(rule) - c.Assert(err, IsNil) + suite.NoError(err) oldStartKey, err := hex.DecodeString(rule.StartKeyHex) - c.Assert(err, IsNil) + suite.NoError(err) oldEndKey, err := hex.DecodeString(rule.EndKeyHex) - c.Assert(err, IsNil) + suite.NoError(err) parseErrData := []byte("foo") rule1 := placement.Rule{GroupID: "a", ID: "10", StartKeyHex: "XXXX", EndKeyHex: "3333", Role: "voter", Count: 1} checkErrData, err := json.Marshal(rule1) - c.Assert(err, IsNil) + suite.NoError(err) rule2 := placement.Rule{GroupID: "a", ID: "10", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: -1} setErrData, err := json.Marshal(rule2) - c.Assert(err, IsNil) + suite.NoError(err) rule3 := placement.Rule{GroupID: "a", ID: "10", StartKeyHex: "1111", EndKeyHex: "3333", Role: "follower", Count: 3} updateData, err := json.Marshal(rule3) - c.Assert(err, IsNil) + suite.NoError(err) newStartKey, err := hex.DecodeString(rule.StartKeyHex) - c.Assert(err, IsNil) + suite.NoError(err) newEndKey, err := hex.DecodeString(rule.EndKeyHex) - c.Assert(err, IsNil) + suite.NoError(err) - testcases := 
[]struct { + testCases := []struct { name string rawData []byte success bool @@ -148,42 +153,43 @@ func (s *testRuleSuite) TestSet(c *C) { `, }, } - - for _, testcase := range testcases { - c.Log(testcase.name) + re := suite.Require() + for _, testCase := range testCases { + suite.T().Log(testCase.name) // clear suspect keyRanges to prevent test case from others - s.svr.GetRaftCluster().ClearSuspectKeyRanges() - if testcase.success { - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", testcase.rawData, tu.StatusOK(c)) + suite.svr.GetRaftCluster().ClearSuspectKeyRanges() + if testCase.success { + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) popKeyRangeMap := map[string]struct{}{} - for i := 0; i < len(testcase.popKeyRange)/2; i++ { - v, got := s.svr.GetRaftCluster().PopOneSuspectKeyRange() - c.Assert(got, IsTrue) + for i := 0; i < len(testCase.popKeyRange)/2; i++ { + v, got := suite.svr.GetRaftCluster().PopOneSuspectKeyRange() + suite.True(got) popKeyRangeMap[hex.EncodeToString(v[0])] = struct{}{} popKeyRangeMap[hex.EncodeToString(v[1])] = struct{}{} } - c.Assert(len(popKeyRangeMap), Equals, len(testcase.popKeyRange)) + suite.Len(testCase.popKeyRange, len(popKeyRangeMap)) for k := range popKeyRangeMap { - _, ok := testcase.popKeyRange[k] - c.Assert(ok, IsTrue) + _, ok := testCase.popKeyRange[k] + suite.True(ok) } } else { - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", testcase.rawData, - tu.StatusNotOK(c), - tu.StringEqual(c, testcase.response)) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", testCase.rawData, + tu.StatusNotOK(re), + tu.StringEqual(re, testCase.response)) } - c.Assert(err, IsNil) + suite.NoError(err) } } -func (s *testRuleSuite) TestGet(c *C) { +func (suite *ruleTestSuite) TestGet() { rule := placement.Rule{GroupID: "a", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + suite.NoError(err) - testcases := []struct { + testCases := []struct { name string rule placement.Rule found bool @@ -202,34 +208,35 @@ func (s *testRuleSuite) TestGet(c *C) { code: 404, }, } - for _, testcase := range testcases { - c.Log(testcase.name) + for _, testCase := range testCases { + suite.T().Log(testCase.name) var resp placement.Rule - url := fmt.Sprintf("%s/rule/%s/%s", s.urlPrefix, testcase.rule.GroupID, testcase.rule.ID) - if testcase.found { - err = tu.ReadGetJSON(c, testDialClient, url, &resp) - compareRule(c, &resp, &testcase.rule) + url := fmt.Sprintf("%s/rule/%s/%s", suite.urlPrefix, testCase.rule.GroupID, testCase.rule.ID) + if testCase.found { + err = tu.ReadGetJSON(re, testDialClient, url, &resp) + suite.compareRule(&resp, &testCase.rule) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(c, testcase.code)) + err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) } - c.Assert(err, IsNil) + suite.NoError(err) } } -func (s *testRuleSuite) TestGetAll(c *C) { +func (suite *ruleTestSuite) TestGetAll() { rule := placement.Rule{GroupID: "b", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", data, 
tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + suite.NoError(err) var resp2 []*placement.Rule - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/rules", &resp2) - c.Assert(err, IsNil) - c.Assert(len(resp2), GreaterEqual, 1) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/rules", &resp2) + suite.NoError(err) + suite.GreaterOrEqual(len(resp2), 1) } -func (s *testRuleSuite) TestSetAll(c *C) { +func (suite *ruleTestSuite) TestSetAll() { rule1 := placement.Rule{GroupID: "a", ID: "12", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} rule2 := placement.Rule{GroupID: "b", ID: "12", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} rule3 := placement.Rule{GroupID: "a", ID: "12", StartKeyHex: "XXXX", EndKeyHex: "3333", Role: "voter", Count: 1} @@ -238,27 +245,27 @@ func (s *testRuleSuite) TestSetAll(c *C) { LocationLabels: []string{"host"}} rule6 := placement.Rule{GroupID: "pd", ID: "default", StartKeyHex: "", EndKeyHex: "", Role: "voter", Count: 3} - s.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} - defaultRule := s.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default") + suite.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} + defaultRule := suite.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default") defaultRule.LocationLabels = []string{"host"} - s.svr.GetRaftCluster().GetRuleManager().SetRule(defaultRule) + suite.svr.GetRaftCluster().GetRuleManager().SetRule(defaultRule) successData, err := json.Marshal([]*placement.Rule{&rule1, &rule2}) - c.Assert(err, IsNil) + suite.NoError(err) checkErrData, err := json.Marshal([]*placement.Rule{&rule1, &rule3}) - c.Assert(err, IsNil) + suite.NoError(err) setErrData, err := json.Marshal([]*placement.Rule{&rule1, &rule4}) - c.Assert(err, IsNil) + suite.NoError(err) defaultData, err := json.Marshal([]*placement.Rule{&rule1, &rule5}) - c.Assert(err, IsNil) + suite.NoError(err) recoverData, err := json.Marshal([]*placement.Rule{&rule1, &rule6}) - c.Assert(err, IsNil) + suite.NoError(err) - testcases := []struct { + testCases := []struct { name string rawData []byte success bool @@ -320,38 +327,38 @@ func (s *testRuleSuite) TestSetAll(c *C) { count: 3, }, } - - for _, testcase := range testcases { - c.Log(testcase.name) - - if testcase.success { - err := tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rules", testcase.rawData, tu.StatusOK(c)) - c.Assert(err, IsNil) - if testcase.isDefaultRule { - c.Assert(testcase.count, Equals, int(s.svr.GetPersistOptions().GetReplicationConfig().MaxReplicas)) + re := suite.Require() + for _, testCase := range testCases { + suite.T().Log(testCase.name) + if testCase.success { + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re)) + suite.NoError(err) + if testCase.isDefaultRule { + suite.Equal(int(suite.svr.GetPersistOptions().GetReplicationConfig().MaxReplicas), testCase.count) } } else { - err := tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rules", testcase.rawData, - tu.StringEqual(c, testcase.response)) - c.Assert(err, IsNil) + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules", testCase.rawData, + tu.StringEqual(re, testCase.response)) + suite.NoError(err) } } } -func (s *testRuleSuite) TestGetAllByGroup(c *C) { +func (suite *ruleTestSuite) TestGetAllByGroup() { + re := suite.Require() rule := 
placement.Rule{GroupID: "c", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + suite.NoError(err) rule1 := placement.Rule{GroupID: "c", ID: "30", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err = json.Marshal(rule1) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + suite.NoError(err) - testcases := []struct { + testCases := []struct { name string groupID string count int @@ -368,31 +375,32 @@ func (s *testRuleSuite) TestGetAllByGroup(c *C) { }, } - for _, testcase := range testcases { - c.Log(testcase.name) + for _, testCase := range testCases { + suite.T().Log(testCase.name) var resp []*placement.Rule - url := fmt.Sprintf("%s/rules/group/%s", s.urlPrefix, testcase.groupID) - err = tu.ReadGetJSON(c, testDialClient, url, &resp) - c.Assert(err, IsNil) - c.Assert(resp, HasLen, testcase.count) - if testcase.count == 2 { - compareRule(c, resp[0], &rule) - compareRule(c, resp[1], &rule1) + url := fmt.Sprintf("%s/rules/group/%s", suite.urlPrefix, testCase.groupID) + err = tu.ReadGetJSON(re, testDialClient, url, &resp) + suite.NoError(err) + suite.Len(resp, testCase.count) + if testCase.count == 2 { + suite.compareRule(resp[0], &rule) + suite.compareRule(resp[1], &rule1) } } } -func (s *testRuleSuite) TestGetAllByRegion(c *C) { +func (suite *ruleTestSuite) TestGetAllByRegion() { rule := placement.Rule{GroupID: "e", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + suite.NoError(err) r := newTestRegionInfo(4, 1, []byte{0x22, 0x22}, []byte{0x33, 0x33}) - mustRegionHeartbeat(c, s.svr, r) + mustRegionHeartbeat(re, suite.svr, r) - testcases := []struct { + testCases := []struct { name string regionID string success bool @@ -416,33 +424,34 @@ func (s *testRuleSuite) TestGetAllByRegion(c *C) { code: 404, }, } - for _, testcase := range testcases { - c.Log(testcase.name) + for _, testCase := range testCases { + suite.T().Log(testCase.name) var resp []*placement.Rule - url := fmt.Sprintf("%s/rules/region/%s", s.urlPrefix, testcase.regionID) + url := fmt.Sprintf("%s/rules/region/%s", suite.urlPrefix, testCase.regionID) - if testcase.success { - err = tu.ReadGetJSON(c, testDialClient, url, &resp) + if testCase.success { + err = tu.ReadGetJSON(re, testDialClient, url, &resp) for _, r := range resp { if r.GroupID == "e" { - compareRule(c, r, &rule) + suite.compareRule(r, &rule) } } } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(c, testcase.code)) + err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) } - c.Assert(err, IsNil) + suite.NoError(err) } } -func (s *testRuleSuite) TestGetAllByKey(c *C) { +func (suite *ruleTestSuite) TestGetAllByKey() { rule := placement.Rule{GroupID: "f", ID: "40", StartKeyHex: "8888", EndKeyHex: "9111", Role: 
"voter", Count: 1} data, err := json.Marshal(rule) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + suite.NoError(err) - testcases := []struct { + testCases := []struct { name string key string success bool @@ -469,33 +478,32 @@ func (s *testRuleSuite) TestGetAllByKey(c *C) { respSize: 1, }, } - - for _, testcase := range testcases { - c.Log(testcase.name) + for _, testCase := range testCases { + suite.T().Log(testCase.name) var resp []*placement.Rule - url := fmt.Sprintf("%s/rules/key/%s", s.urlPrefix, testcase.key) - if testcase.success { - err = tu.ReadGetJSON(c, testDialClient, url, &resp) - c.Assert(resp, HasLen, testcase.respSize) + url := fmt.Sprintf("%s/rules/key/%s", suite.urlPrefix, testCase.key) + if testCase.success { + err = tu.ReadGetJSON(re, testDialClient, url, &resp) + suite.Len(resp, testCase.respSize) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(c, testcase.code)) + err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) } - c.Assert(err, IsNil) + suite.NoError(err) } } -func (s *testRuleSuite) TestDelete(c *C) { +func (suite *ruleTestSuite) TestDelete() { rule := placement.Rule{GroupID: "g", ID: "10", StartKeyHex: "8888", EndKeyHex: "9111", Role: "voter", Count: 1} data, err := json.Marshal(rule) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(suite.Require())) + suite.NoError(err) oldStartKey, err := hex.DecodeString(rule.StartKeyHex) - c.Assert(err, IsNil) + suite.NoError(err) oldEndKey, err := hex.DecodeString(rule.EndKeyHex) - c.Assert(err, IsNil) + suite.NoError(err) - testcases := []struct { + testCases := []struct { name string groupID string id string @@ -517,41 +525,41 @@ func (s *testRuleSuite) TestDelete(c *C) { popKeyRange: map[string]struct{}{}, }, } - for _, testcase := range testcases { - c.Log(testcase.name) - url := fmt.Sprintf("%s/rule/%s/%s", s.urlPrefix, testcase.groupID, testcase.id) + for _, testCase := range testCases { + suite.T().Log(testCase.name) + url := fmt.Sprintf("%s/rule/%s/%s", suite.urlPrefix, testCase.groupID, testCase.id) // clear suspect keyRanges to prevent test case from others - s.svr.GetRaftCluster().ClearSuspectKeyRanges() + suite.svr.GetRaftCluster().ClearSuspectKeyRanges() statusCode, err := apiutil.DoDelete(testDialClient, url) - c.Assert(err, IsNil) - c.Assert(statusCode, Equals, http.StatusOK) - if len(testcase.popKeyRange) > 0 { + suite.NoError(err) + suite.Equal(http.StatusOK, statusCode) + if len(testCase.popKeyRange) > 0 { popKeyRangeMap := map[string]struct{}{} - for i := 0; i < len(testcase.popKeyRange)/2; i++ { - v, got := s.svr.GetRaftCluster().PopOneSuspectKeyRange() - c.Assert(got, IsTrue) + for i := 0; i < len(testCase.popKeyRange)/2; i++ { + v, got := suite.svr.GetRaftCluster().PopOneSuspectKeyRange() + suite.True(got) popKeyRangeMap[hex.EncodeToString(v[0])] = struct{}{} popKeyRangeMap[hex.EncodeToString(v[1])] = struct{}{} } - c.Assert(len(popKeyRangeMap), Equals, len(testcase.popKeyRange)) + suite.Len(testCase.popKeyRange, len(popKeyRangeMap)) for k := range popKeyRangeMap { - _, ok := testcase.popKeyRange[k] - c.Assert(ok, IsTrue) + _, ok := testCase.popKeyRange[k] + 
suite.True(ok) } } } } -func compareRule(c *C, r1 *placement.Rule, r2 *placement.Rule) { - c.Assert(r1.GroupID, Equals, r2.GroupID) - c.Assert(r1.ID, Equals, r2.ID) - c.Assert(r1.StartKeyHex, Equals, r2.StartKeyHex) - c.Assert(r1.EndKeyHex, Equals, r2.EndKeyHex) - c.Assert(r1.Role, Equals, r2.Role) - c.Assert(r1.Count, Equals, r2.Count) +func (suite *ruleTestSuite) compareRule(r1 *placement.Rule, r2 *placement.Rule) { + suite.Equal(r2.GroupID, r1.GroupID) + suite.Equal(r2.ID, r1.ID) + suite.Equal(r2.StartKeyHex, r1.StartKeyHex) + suite.Equal(r2.EndKeyHex, r1.EndKeyHex) + suite.Equal(r2.Role, r1.Role) + suite.Equal(r2.Count, r1.Count) } -func (s *testRuleSuite) TestBatch(c *C) { +func (suite *ruleTestSuite) TestBatch() { opt1 := placement.RuleOp{ Action: placement.RuleOpAdd, Rule: &placement.Rule{GroupID: "a", ID: "13", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}, @@ -591,21 +599,21 @@ func (s *testRuleSuite) TestBatch(c *C) { } successData1, err := json.Marshal([]placement.RuleOp{opt1, opt2, opt3}) - c.Assert(err, IsNil) + suite.NoError(err) successData2, err := json.Marshal([]placement.RuleOp{opt5, opt7}) - c.Assert(err, IsNil) + suite.NoError(err) successData3, err := json.Marshal([]placement.RuleOp{opt4, opt6}) - c.Assert(err, IsNil) + suite.NoError(err) checkErrData, err := json.Marshal([]placement.RuleOp{opt8}) - c.Assert(err, IsNil) + suite.NoError(err) setErrData, err := json.Marshal([]placement.RuleOp{opt9}) - c.Assert(err, IsNil) + suite.NoError(err) - testcases := []struct { + testCases := []struct { name string rawData []byte success bool @@ -657,22 +665,23 @@ func (s *testRuleSuite) TestBatch(c *C) { `, }, } - - for _, testcase := range testcases { - c.Log(testcase.name) - if testcase.success { - err := tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rules/batch", testcase.rawData, tu.StatusOK(c)) - c.Assert(err, IsNil) + re := suite.Require() + for _, testCase := range testCases { + suite.T().Log(testCase.name) + if testCase.success { + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re)) + suite.NoError(err) } else { - err := tu.CheckPostJSON(testDialClient, s.urlPrefix+"/rules/batch", testcase.rawData, - tu.StatusNotOK(c), - tu.StringEqual(c, testcase.response)) - c.Assert(err, IsNil) + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules/batch", testCase.rawData, + tu.StatusNotOK(re), + tu.StringEqual(re, testCase.response)) + suite.NoError(err) } } } -func (s *testRuleSuite) TestBundle(c *C) { +func (suite *ruleTestSuite) TestBundle() { + re := suite.Require() // GetAll b1 := placement.GroupBundle{ ID: "pd", @@ -681,10 +690,10 @@ func (s *testRuleSuite) TestBundle(c *C) { }, } var bundles []placement.GroupBundle - err := tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule", &bundles) - c.Assert(err, IsNil) - c.Assert(bundles, HasLen, 1) - compareBundle(c, bundles[0], b1) + err := tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + suite.NoError(err) + suite.Len(bundles, 1) + suite.compareBundle(bundles[0], b1) // Set b2 := placement.GroupBundle{ @@ -696,59 +705,59 @@ func (s *testRuleSuite) TestBundle(c *C) { }, } data, err := json.Marshal(b2) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/placement-rule/foo", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re)) + suite.NoError(err) // Get var bundle 
placement.GroupBundle - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule/foo", &bundle) - c.Assert(err, IsNil) - compareBundle(c, bundle, b2) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule/foo", &bundle) + suite.NoError(err) + suite.compareBundle(bundle, b2) // GetAll again - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule", &bundles) - c.Assert(err, IsNil) - c.Assert(bundles, HasLen, 2) - compareBundle(c, bundles[0], b1) - compareBundle(c, bundles[1], b2) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + suite.NoError(err) + suite.Len(bundles, 2) + suite.compareBundle(bundles[0], b1) + suite.compareBundle(bundles[1], b2) // Delete - _, err = apiutil.DoDelete(testDialClient, s.urlPrefix+"/placement-rule/pd") - c.Assert(err, IsNil) + _, err = apiutil.DoDelete(testDialClient, suite.urlPrefix+"/placement-rule/pd") + suite.NoError(err) // GetAll again - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule", &bundles) - c.Assert(err, IsNil) - c.Assert(bundles, HasLen, 1) - compareBundle(c, bundles[0], b2) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + suite.NoError(err) + suite.Len(bundles, 1) + suite.compareBundle(bundles[0], b2) // SetAll b2.Rules = append(b2.Rules, &placement.Rule{GroupID: "foo", ID: "baz", Index: 2, Role: "follower", Count: 1}) b2.Index, b2.Override = 0, false b3 := placement.GroupBundle{ID: "foobar", Index: 100} data, err = json.Marshal([]placement.GroupBundle{b1, b2, b3}) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/placement-rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) // GetAll again - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule", &bundles) - c.Assert(err, IsNil) - c.Assert(bundles, HasLen, 3) - compareBundle(c, bundles[0], b2) - compareBundle(c, bundles[1], b1) - compareBundle(c, bundles[2], b3) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + suite.NoError(err) + suite.Len(bundles, 3) + suite.compareBundle(bundles[0], b2) + suite.compareBundle(bundles[1], b1) + suite.compareBundle(bundles[2], b3) // Delete using regexp - _, err = apiutil.DoDelete(testDialClient, s.urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp") - c.Assert(err, IsNil) + _, err = apiutil.DoDelete(testDialClient, suite.urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp") + suite.NoError(err) // GetAll again - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule", &bundles) - c.Assert(err, IsNil) - c.Assert(bundles, HasLen, 1) - compareBundle(c, bundles[0], b1) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + suite.NoError(err) + suite.Len(bundles, 1) + suite.compareBundle(bundles[0], b1) // Set id := "rule-without-group-id" @@ -759,24 +768,24 @@ func (s *testRuleSuite) TestBundle(c *C) { }, } data, err = json.Marshal(b4) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re)) + suite.NoError(err) b4.ID = id b4.Rules[0].GroupID = b4.ID // Get - err = tu.ReadGetJSON(c, testDialClient, 
s.urlPrefix+"/placement-rule/"+id, &bundle) - c.Assert(err, IsNil) - compareBundle(c, bundle, b4) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule/"+id, &bundle) + suite.NoError(err) + suite.compareBundle(bundle, b4) // GetAll again - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule", &bundles) - c.Assert(err, IsNil) - c.Assert(bundles, HasLen, 2) - compareBundle(c, bundles[0], b1) - compareBundle(c, bundles[1], b4) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + suite.NoError(err) + suite.Len(bundles, 2) + suite.compareBundle(bundles[0], b1) + suite.compareBundle(bundles[1], b4) // SetAll b5 := placement.GroupBundle{ @@ -787,22 +796,22 @@ func (s *testRuleSuite) TestBundle(c *C) { }, } data, err = json.Marshal([]placement.GroupBundle{b1, b4, b5}) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/placement-rule", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) b5.Rules[0].GroupID = b5.ID // GetAll again - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/placement-rule", &bundles) - c.Assert(err, IsNil) - c.Assert(bundles, HasLen, 3) - compareBundle(c, bundles[0], b1) - compareBundle(c, bundles[1], b4) - compareBundle(c, bundles[2], b5) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + suite.NoError(err) + suite.Len(bundles, 3) + suite.compareBundle(bundles[0], b1) + suite.compareBundle(bundles[1], b4) + suite.compareBundle(bundles[2], b5) } -func (s *testRuleSuite) TestBundleBadRequest(c *C) { +func (suite *ruleTestSuite) TestBundleBadRequest() { testCases := []struct { uri string data string @@ -816,20 +825,20 @@ func (s *testRuleSuite) TestBundleBadRequest(c *C) { {"/placement-rule", `[{"group_id":"foo", "rules": [{"group_id":"bar", "id":"baz", "role":"voter", "count":1}]}]`, false}, } for _, tc := range testCases { - err := tu.CheckPostJSON(testDialClient, s.urlPrefix+tc.uri, []byte(tc.data), + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+tc.uri, []byte(tc.data), func(_ []byte, code int) { - c.Assert(code == http.StatusOK, Equals, tc.ok) + suite.Equal(tc.ok, code == http.StatusOK) }) - c.Assert(err, IsNil) + suite.NoError(err) } } -func compareBundle(c *C, b1, b2 placement.GroupBundle) { - c.Assert(b1.ID, Equals, b2.ID) - c.Assert(b1.Index, Equals, b2.Index) - c.Assert(b1.Override, Equals, b2.Override) - c.Assert(len(b1.Rules), Equals, len(b2.Rules)) +func (suite *ruleTestSuite) compareBundle(b1, b2 placement.GroupBundle) { + suite.Equal(b2.ID, b1.ID) + suite.Equal(b2.Index, b1.Index) + suite.Equal(b2.Override, b1.Override) + suite.Len(b2.Rules, len(b1.Rules)) for i := range b1.Rules { - compareRule(c, b1.Rules[i], b2.Rules[i]) + suite.compareRule(b1.Rules[i], b2.Rules[i]) } } diff --git a/server/api/scheduler_test.go b/server/api/scheduler_test.go index 8c20bdf6182..c4b30595967 100644 --- a/server/api/scheduler_test.go +++ b/server/api/scheduler_test.go @@ -18,11 +18,12 @@ import ( "encoding/json" "fmt" "net/http" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" @@ -30,86 +31,91 @@ import ( _ "github.com/tikv/pd/server/schedulers" ) -var _ = Suite(&testScheduleSuite{}) - -type testScheduleSuite struct { +type scheduleTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testScheduleSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestScheduleTestSuite(t *testing.T) { + suite.Run(t, new(scheduleTestSuite)) +} + +func (suite *scheduleTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/schedulers", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/schedulers", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) - mustPutStore(c, s.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustBootstrapCluster(re, suite.svr) + mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustPutStore(re, suite.svr, 2, metapb.StoreState_Up, metapb.NodeState_Serving, nil) } -func (s *testScheduleSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *scheduleTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testScheduleSuite) TestOriginAPI(c *C) { - addURL := s.urlPrefix +func (suite *scheduleTestSuite) TestOriginAPI() { + addURL := suite.urlPrefix input := make(map[string]interface{}) input["name"] = "evict-leader-scheduler" input["store_id"] = 1 body, err := json.Marshal(input) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addURL, body, tu.StatusOK(c)), IsNil) - rc := s.svr.GetRaftCluster() - c.Assert(rc.GetSchedulers(), HasLen, 1) + suite.NoError(err) + re := suite.Require() + suite.NoError(tu.CheckPostJSON(testDialClient, addURL, body, tu.StatusOK(re))) + rc := suite.svr.GetRaftCluster() + suite.Len(rc.GetSchedulers(), 1) resp := make(map[string]interface{}) - listURL := fmt.Sprintf("%s%s%s/%s/list", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, "evict-leader-scheduler") - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - c.Assert(resp["store-id-ranges"], HasLen, 1) + listURL := fmt.Sprintf("%s%s%s/%s/list", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, "evict-leader-scheduler") + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Len(resp["store-id-ranges"], 1) input1 := make(map[string]interface{}) input1["name"] = "evict-leader-scheduler" input1["store_id"] = 2 body, err = json.Marshal(input1) - c.Assert(err, IsNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedulers/persistFail", "return(true)"), IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addURL, body, tu.StatusNotOK(c)), IsNil) - c.Assert(rc.GetSchedulers(), HasLen, 1) + suite.NoError(err) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedulers/persistFail", "return(true)")) + suite.NoError(tu.CheckPostJSON(testDialClient, addURL, body, tu.StatusNotOK(re))) + suite.Len(rc.GetSchedulers(), 1) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - 
c.Assert(resp["store-id-ranges"], HasLen, 1) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/schedulers/persistFail"), IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addURL, body, tu.StatusOK(c)), IsNil) - c.Assert(rc.GetSchedulers(), HasLen, 1) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Len(resp["store-id-ranges"], 1) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/schedulers/persistFail")) + suite.NoError(tu.CheckPostJSON(testDialClient, addURL, body, tu.StatusOK(re))) + suite.Len(rc.GetSchedulers(), 1) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - c.Assert(resp["store-id-ranges"], HasLen, 2) - deleteURL := fmt.Sprintf("%s/%s", s.urlPrefix, "evict-leader-scheduler-1") + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Len(resp["store-id-ranges"], 2) + deleteURL := fmt.Sprintf("%s/%s", suite.urlPrefix, "evict-leader-scheduler-1") _, err = apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) - c.Assert(rc.GetSchedulers(), HasLen, 1) + suite.NoError(err) + suite.Len(rc.GetSchedulers(), 1) resp1 := make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp1), IsNil) - c.Assert(resp1["store-id-ranges"], HasLen, 1) - deleteURL = fmt.Sprintf("%s/%s", s.urlPrefix, "evict-leader-scheduler-2") - c.Assert(failpoint.Enable("github.com/tikv/pd/server/config/persistFail", "return(true)"), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp1)) + suite.Len(resp1["store-id-ranges"], 1) + deleteURL = fmt.Sprintf("%s/%s", suite.urlPrefix, "evict-leader-scheduler-2") + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/config/persistFail", "return(true)")) statusCode, err := apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) - c.Assert(statusCode, Equals, 500) - c.Assert(rc.GetSchedulers(), HasLen, 1) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/config/persistFail"), IsNil) + suite.NoError(err) + suite.Equal(500, statusCode) + suite.Len(rc.GetSchedulers(), 1) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/config/persistFail")) statusCode, err = apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) - c.Assert(statusCode, Equals, 200) - c.Assert(rc.GetSchedulers(), HasLen, 0) - c.Assert(tu.CheckGetJSON(testDialClient, listURL, nil, tu.Status(c, 404)), IsNil) - + suite.NoError(err) + suite.Equal(200, statusCode) + suite.Len(rc.GetSchedulers(), 0) + suite.NoError(tu.CheckGetJSON(testDialClient, listURL, nil, tu.Status(re, 404))) statusCode, _ = apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(statusCode, Equals, 404) + suite.Equal(404, statusCode) } -func (s *testScheduleSuite) TestAPI(c *C) { +func (suite *scheduleTestSuite) TestAPI() { + re := suite.Require() type arg struct { opt string value interface{} @@ -118,63 +124,63 @@ func (s *testScheduleSuite) TestAPI(c *C) { name string createdName string args []arg - extraTestFunc func(name string, c *C) + extraTestFunc func(name string) }{ { name: "balance-leader-scheduler", - extraTestFunc: func(name string, c *C) { + extraTestFunc: func(name string) { resp := make(map[string]interface{}) - listURL := fmt.Sprintf("%s%s%s/%s/list", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - c.Assert(resp["batch"], Equals, 4.0) + listURL := fmt.Sprintf("%s%s%s/%s/list", suite.svr.GetAddr(), apiPrefix, 
server.SchedulerConfigHandlerPath, name) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Equal(4.0, resp["batch"]) dataMap := make(map[string]interface{}) dataMap["batch"] = 3 - updateURL := fmt.Sprintf("%s%s%s/%s/config", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + updateURL := fmt.Sprintf("%s%s%s/%s/config", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - c.Assert(resp["batch"], Equals, 3.0) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Equal(3.0, resp["batch"]) // update again err = tu.CheckPostJSON(testDialClient, updateURL, body, - tu.StatusOK(c), - tu.StringEqual(c, "\"no changed\"\n")) - c.Assert(err, IsNil) + tu.StatusOK(re), + tu.StringEqual(re, "\"no changed\"\n")) + suite.NoError(err) // update invalidate batch dataMap = map[string]interface{}{} dataMap["batch"] = 100 body, err = json.Marshal(dataMap) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, updateURL, body, - tu.Status(c, http.StatusBadRequest), - tu.StringEqual(c, "\"invalid batch size which should be an integer between 1 and 10\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), + tu.StringEqual(re, "\"invalid batch size which should be an integer between 1 and 10\"\n")) + suite.NoError(err) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - c.Assert(resp["batch"], Equals, 3.0) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Equal(3.0, resp["batch"]) // empty body err = tu.CheckPostJSON(testDialClient, updateURL, nil, - tu.Status(c, http.StatusInternalServerError), - tu.StringEqual(c, "\"unexpected end of JSON input\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusInternalServerError), + tu.StringEqual(re, "\"unexpected end of JSON input\"\n")) + suite.NoError(err) // config item not found dataMap = map[string]interface{}{} dataMap["error"] = 3 body, err = json.Marshal(dataMap) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, updateURL, body, - tu.Status(c, http.StatusBadRequest), - tu.StringEqual(c, "\"config item not found\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), + tu.StringEqual(re, "\"config item not found\"\n")) + suite.NoError(err) }, }, { name: "balance-hot-region-scheduler", - extraTestFunc: func(name string, c *C) { + extraTestFunc: func(name string) { resp := make(map[string]interface{}) - listURL := fmt.Sprintf("%s%s%s/%s/list", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + listURL := fmt.Sprintf("%s%s%s/%s/list", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) expectMap := map[string]float64{ "min-hot-byte-rate": 100, "min-hot-key-rate": 10, @@ -188,25 +194,25 @@ func (s *testScheduleSuite) TestAPI(c *C) { "minor-dec-ratio": 0.99, } for key := range expectMap { - c.Assert(resp[key], DeepEquals, expectMap[key]) + suite.Equal(expectMap[key], resp[key]) } dataMap := 
make(map[string]interface{}) dataMap["max-zombie-rounds"] = 5.0 expectMap["max-zombie-rounds"] = 5.0 - updateURL := fmt.Sprintf("%s%s%s/%s/config", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + updateURL := fmt.Sprintf("%s%s%s/%s/config", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) for key := range expectMap { - c.Assert(resp[key], DeepEquals, expectMap[key]) + suite.Equal(expectMap[key], resp[key]) } // update again err = tu.CheckPostJSON(testDialClient, updateURL, body, - tu.StatusOK(c), - tu.StringEqual(c, "no changed")) - c.Assert(err, IsNil) + tu.StatusOK(re), + tu.StringEqual(re, "no changed")) + suite.NoError(err) }, }, {name: "balance-region-scheduler"}, @@ -216,38 +222,38 @@ func (s *testScheduleSuite) TestAPI(c *C) { name: "grant-leader-scheduler", createdName: "grant-leader-scheduler", args: []arg{{"store_id", 1}}, - extraTestFunc: func(name string, c *C) { + extraTestFunc: func(name string) { resp := make(map[string]interface{}) - listURL := fmt.Sprintf("%s%s%s/%s/list", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + listURL := fmt.Sprintf("%s%s%s/%s/list", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) exceptMap := make(map[string]interface{}) exceptMap["1"] = []interface{}{map[string]interface{}{"end-key": "", "start-key": ""}} - c.Assert(resp["store-id-ranges"], DeepEquals, exceptMap) + suite.Equal(exceptMap, resp["store-id-ranges"]) // using /pd/v1/schedule-config/grant-leader-scheduler/config to add new store to grant-leader-scheduler input := make(map[string]interface{}) input["name"] = "grant-leader-scheduler" input["store_id"] = 2 - updateURL := fmt.Sprintf("%s%s%s/%s/config", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + updateURL := fmt.Sprintf("%s%s%s/%s/config", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(input) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) exceptMap["2"] = []interface{}{map[string]interface{}{"end-key": "", "start-key": ""}} - c.Assert(resp["store-id-ranges"], DeepEquals, exceptMap) + suite.Equal(exceptMap, resp["store-id-ranges"]) // using /pd/v1/schedule-config/grant-leader-scheduler/config to delete exists store from grant-leader-scheduler - deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name, "2") + deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name, "2") _, err = apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) + suite.NoError(err) 
resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) delete(exceptMap, "2") - c.Assert(resp["store-id-ranges"], DeepEquals, exceptMap) + suite.Equal(exceptMap, resp["store-id-ranges"]) statusCode, err := apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) - c.Assert(statusCode, Equals, 404) + suite.NoError(err) + suite.Equal(404, statusCode) }, }, { @@ -255,24 +261,24 @@ func (s *testScheduleSuite) TestAPI(c *C) { createdName: "scatter-range-test", args: []arg{{"start_key", ""}, {"end_key", ""}, {"range_name", "test"}}, // Test the scheduler config handler. - extraTestFunc: func(name string, c *C) { + extraTestFunc: func(name string) { resp := make(map[string]interface{}) - listURL := fmt.Sprintf("%s%s%s/%s/list", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - c.Assert(resp["start-key"], Equals, "") - c.Assert(resp["end-key"], Equals, "") - c.Assert(resp["range-name"], Equals, "test") + listURL := fmt.Sprintf("%s%s%s/%s/list", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Equal("", resp["start-key"]) + suite.Equal("", resp["end-key"]) + suite.Equal("test", resp["range-name"]) resp["start-key"] = "a_00" resp["end-key"] = "a_99" - updateURL := fmt.Sprintf("%s%s%s/%s/config", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + updateURL := fmt.Sprintf("%s%s%s/%s/config", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(resp) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) - c.Assert(resp["start-key"], Equals, "a_00") - c.Assert(resp["end-key"], Equals, "a_99") - c.Assert(resp["range-name"], Equals, "test") + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + suite.Equal("a_00", resp["start-key"]) + suite.Equal("a_99", resp["end-key"]) + suite.Equal("test", resp["range-name"]) }, }, { @@ -280,38 +286,38 @@ func (s *testScheduleSuite) TestAPI(c *C) { createdName: "evict-leader-scheduler", args: []arg{{"store_id", 1}}, // Test the scheduler config handler. 
- extraTestFunc: func(name string, c *C) { + extraTestFunc: func(name string) { resp := make(map[string]interface{}) - listURL := fmt.Sprintf("%s%s%s/%s/list", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + listURL := fmt.Sprintf("%s%s%s/%s/list", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) exceptMap := make(map[string]interface{}) exceptMap["1"] = []interface{}{map[string]interface{}{"end-key": "", "start-key": ""}} - c.Assert(resp["store-id-ranges"], DeepEquals, exceptMap) + suite.Equal(exceptMap, resp["store-id-ranges"]) // using /pd/v1/schedule-config/evict-leader-scheduler/config to add new store to evict-leader-scheduler input := make(map[string]interface{}) input["name"] = "evict-leader-scheduler" input["store_id"] = 2 - updateURL := fmt.Sprintf("%s%s%s/%s/config", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) + updateURL := fmt.Sprintf("%s%s%s/%s/config", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(input) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) exceptMap["2"] = []interface{}{map[string]interface{}{"end-key": "", "start-key": ""}} - c.Assert(resp["store-id-ranges"], DeepEquals, exceptMap) + suite.Equal(exceptMap, resp["store-id-ranges"]) // using /pd/v1/schedule-config/evict-leader-scheduler/config to delete exist store from evict-leader-scheduler - deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", s.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name, "2") + deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name, "2") _, err = apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) + suite.NoError(err) resp = make(map[string]interface{}) - c.Assert(tu.ReadGetJSON(c, testDialClient, listURL, &resp), IsNil) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) delete(exceptMap, "2") - c.Assert(resp["store-id-ranges"], DeepEquals, exceptMap) + suite.Equal(exceptMap, resp["store-id-ranges"]) statusCode, err := apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) - c.Assert(statusCode, Equals, 404) + suite.NoError(err) + suite.Equal(404, statusCode) }, }, } @@ -322,8 +328,8 @@ func (s *testScheduleSuite) TestAPI(c *C) { input[a.opt] = a.value } body, err := json.Marshal(input) - c.Assert(err, IsNil) - s.testPauseOrResume(ca.name, ca.createdName, body, ca.extraTestFunc, c) + suite.NoError(err) + suite.testPauseOrResume(ca.name, ca.createdName, body) } // test pause and resume all schedulers. @@ -337,32 +343,32 @@ func (s *testScheduleSuite) TestAPI(c *C) { input[a.opt] = a.value } body, err := json.Marshal(input) - c.Assert(err, IsNil) - s.addScheduler(ca.name, ca.createdName, body, ca.extraTestFunc, c) + suite.NoError(err) + suite.addScheduler(body) } // test pause all schedulers. 
input := make(map[string]interface{}) input["delay"] = 30 pauseArgs, err := json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/all", pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) - handler := s.svr.GetHandler() + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + suite.NoError(err) + handler := suite.svr.GetHandler() for _, ca := range cases { createdName := ca.createdName if createdName == "" { createdName = ca.name } isPaused, err := handler.IsSchedulerPaused(createdName) - c.Assert(err, IsNil) - c.Assert(isPaused, IsTrue) + suite.NoError(err) + suite.True(isPaused) } input["delay"] = 1 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/all", pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + suite.NoError(err) time.Sleep(time.Second) for _, ca := range cases { createdName := ca.createdName @@ -370,29 +376,29 @@ func (s *testScheduleSuite) TestAPI(c *C) { createdName = ca.name } isPaused, err := handler.IsSchedulerPaused(createdName) - c.Assert(err, IsNil) - c.Assert(isPaused, IsFalse) + suite.NoError(err) + suite.False(isPaused) } // test resume all schedulers. input["delay"] = 30 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/all", pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + suite.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/all", pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + suite.NoError(err) for _, ca := range cases { createdName := ca.createdName if createdName == "" { createdName = ca.name } isPaused, err := handler.IsSchedulerPaused(createdName) - c.Assert(err, IsNil) - c.Assert(isPaused, IsFalse) + suite.NoError(err) + suite.False(isPaused) } // delete schedulers. 
@@ -401,124 +407,114 @@ func (s *testScheduleSuite) TestAPI(c *C) { if createdName == "" { createdName = ca.name } - s.deleteScheduler(createdName, c) + suite.deleteScheduler(createdName) } } -func (s *testScheduleSuite) TestDisable(c *C) { +func (suite *scheduleTestSuite) TestDisable() { name := "shuffle-leader-scheduler" input := make(map[string]interface{}) input["name"] = name body, err := json.Marshal(input) - c.Assert(err, IsNil) - s.addScheduler(name, name, body, nil, c) + suite.NoError(err) + suite.addScheduler(body) - u := fmt.Sprintf("%s%s/api/v1/config/schedule", s.svr.GetAddr(), apiPrefix) + re := suite.Require() + u := fmt.Sprintf("%s%s/api/v1/config/schedule", suite.svr.GetAddr(), apiPrefix) var scheduleConfig config.ScheduleConfig - err = tu.ReadGetJSON(c, testDialClient, u, &scheduleConfig) - c.Assert(err, IsNil) + err = tu.ReadGetJSON(re, testDialClient, u, &scheduleConfig) + suite.NoError(err) originSchedulers := scheduleConfig.Schedulers scheduleConfig.Schedulers = config.SchedulerConfigs{config.SchedulerConfig{Type: "shuffle-leader", Disable: true}} body, err = json.Marshal(scheduleConfig) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(re)) + suite.NoError(err) var schedulers []string - err = tu.ReadGetJSON(c, testDialClient, s.urlPrefix, &schedulers) - c.Assert(err, IsNil) - c.Assert(schedulers, HasLen, 1) - c.Assert(schedulers[0], Equals, name) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix, &schedulers) + suite.NoError(err) + suite.Len(schedulers, 1) + suite.Equal(name, schedulers[0]) - err = tu.ReadGetJSON(c, testDialClient, fmt.Sprintf("%s?status=disabled", s.urlPrefix), &schedulers) - c.Assert(err, IsNil) - c.Assert(schedulers, HasLen, 1) - c.Assert(schedulers[0], Equals, name) + err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s?status=disabled", suite.urlPrefix), &schedulers) + suite.NoError(err) + suite.Len(schedulers, 1) + suite.Equal(name, schedulers[0]) // reset schedule config scheduleConfig.Schedulers = originSchedulers body, err = json.Marshal(scheduleConfig) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(re)) + suite.NoError(err) - s.deleteScheduler(name, c) + suite.deleteScheduler(name) } -func (s *testScheduleSuite) addScheduler(name, createdName string, body []byte, extraTest func(string, *C), c *C) { - if createdName == "" { - createdName = name - } - err := tu.CheckPostJSON(testDialClient, s.urlPrefix, body, tu.StatusOK(c)) - c.Assert(err, IsNil) - - if extraTest != nil { - extraTest(createdName, c) - } +func (suite *scheduleTestSuite) addScheduler(body []byte) { + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix, body, tu.StatusOK(suite.Require())) + suite.NoError(err) } -func (s *testScheduleSuite) deleteScheduler(createdName string, c *C) { - deleteURL := fmt.Sprintf("%s/%s", s.urlPrefix, createdName) +func (suite *scheduleTestSuite) deleteScheduler(createdName string) { + deleteURL := fmt.Sprintf("%s/%s", suite.urlPrefix, createdName) _, err := apiutil.DoDelete(testDialClient, deleteURL) - c.Assert(err, IsNil) + suite.NoError(err) } -func (s *testScheduleSuite) testPauseOrResume(name, createdName string, body []byte, extraTest func(string, *C), c *C) { +func (suite *scheduleTestSuite) testPauseOrResume(name, createdName string, 
body []byte) { if createdName == "" { createdName = name } - err := tu.CheckPostJSON(testDialClient, s.urlPrefix, body, tu.StatusOK(c)) - c.Assert(err, IsNil) - handler := s.svr.GetHandler() + re := suite.Require() + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix, body, tu.StatusOK(re)) + suite.NoError(err) + handler := suite.svr.GetHandler() sches, err := handler.GetSchedulers() - c.Assert(err, IsNil) - c.Assert(sches[0], Equals, createdName) + suite.NoError(err) + suite.Equal(createdName, sches[0]) // test pause. input := make(map[string]interface{}) input["delay"] = 30 pauseArgs, err := json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) isPaused, err := handler.IsSchedulerPaused(createdName) - c.Assert(err, IsNil) - c.Assert(isPaused, IsTrue) + suite.NoError(err) + suite.True(isPaused) input["delay"] = 1 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) pausedAt, err := handler.GetPausedSchedulerDelayAt(createdName) - c.Assert(err, IsNil) + suite.NoError(err) resumeAt, err := handler.GetPausedSchedulerDelayUntil(createdName) - c.Assert(err, IsNil) - c.Assert(resumeAt-pausedAt, Equals, int64(1)) + suite.NoError(err) + suite.Equal(int64(1), resumeAt-pausedAt) time.Sleep(time.Second) isPaused, err = handler.IsSchedulerPaused(createdName) - c.Assert(err, IsNil) - c.Assert(isPaused, IsFalse) + suite.NoError(err) + suite.False(isPaused) // test resume. input = make(map[string]interface{}) input["delay"] = 30 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + suite.NoError(err) isPaused, err = handler.IsSchedulerPaused(createdName) - c.Assert(err, IsNil) - c.Assert(isPaused, IsFalse) - - if extraTest != nil { - extraTest(createdName, c) - } - - s.deleteScheduler(createdName, c) + suite.NoError(err) + suite.False(isPaused) + suite.deleteScheduler(createdName) } diff --git a/server/api/server_test.go b/server/api/server_test.go index a4c6b6de6fb..273f62cab54 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -21,10 +21,11 @@ import ( "sync" "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" @@ -61,33 +62,29 @@ var ( } ) -func TestAPIServer(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } type cleanUpFunc func() -func mustNewServer(c *C, opts ...func(cfg *config.Config)) (*server.Server, cleanUpFunc) { - _, svrs, cleanup := mustNewCluster(c, 1, opts...) +func mustNewServer(re *require.Assertions, opts ...func(cfg *config.Config)) (*server.Server, cleanUpFunc) { + _, svrs, cleanup := mustNewCluster(re, 1, opts...) return svrs[0], cleanup } var zapLogOnce sync.Once -func mustNewCluster(c *C, num int, opts ...func(cfg *config.Config)) ([]*config.Config, []*server.Server, cleanUpFunc) { +func mustNewCluster(re *require.Assertions, num int, opts ...func(cfg *config.Config)) ([]*config.Config, []*server.Server, cleanUpFunc) { ctx, cancel := context.WithCancel(context.Background()) svrs := make([]*server.Server, 0, num) - cfgs := server.NewTestMultiConfig(checkerWithNilAssert(c), num) + cfgs := server.NewTestMultiConfig(checkerWithNilAssert(re), num) ch := make(chan *server.Server, num) for _, cfg := range cfgs { go func(cfg *config.Config) { err := cfg.SetupLogger() - c.Assert(err, IsNil) + re.NoError(err) zapLogOnce.Do(func() { log.ReplaceGlobals(cfg.GetZapLogger(), cfg.GetZapLogProperties()) }) @@ -95,9 +92,9 @@ func mustNewCluster(c *C, num int, opts ...func(cfg *config.Config)) ([]*config. opt(cfg) } s, err := server.CreateServer(ctx, cfg, NewHandler) - c.Assert(err, IsNil) + re.NoError(err) err = s.Run() - c.Assert(err, IsNil) + re.NoError(err) ch <- s }(cfg) } @@ -108,7 +105,7 @@ func mustNewCluster(c *C, num int, opts ...func(cfg *config.Config)) ([]*config. } close(ch) // wait etcd and http servers - mustWaitLeader(c, svrs) + mustWaitLeader(re, svrs) // clean up clean := func() { @@ -124,8 +121,8 @@ func mustNewCluster(c *C, num int, opts ...func(cfg *config.Config)) ([]*config. 
return cfgs, svrs, clean } -func mustWaitLeader(c *C, svrs []*server.Server) { - testutil.WaitUntil(c, func() bool { +func mustWaitLeader(re *require.Assertions, svrs []*server.Server) { + testutil.Eventually(re, func() bool { var leader *pdpb.Member for _, svr := range svrs { l := svr.GetLeader() @@ -141,88 +138,92 @@ func mustWaitLeader(c *C, svrs []*server.Server) { }) } -func mustBootstrapCluster(c *C, s *server.Server) { - grpcPDClient := testutil.MustNewGrpcClient(c, s.GetAddr()) +func mustBootstrapCluster(re *require.Assertions, s *server.Server) { + grpcPDClient := testutil.MustNewGrpcClient(re, s.GetAddr()) req := &pdpb.BootstrapRequest{ Header: testutil.NewRequestHeader(s.ClusterID()), Store: store, Region: region, } resp, err := grpcPDClient.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.GetHeader().GetError().GetType(), Equals, pdpb.ErrorType_OK) + re.NoError(err) + re.Equal(pdpb.ErrorType_OK, resp.GetHeader().GetError().GetType()) } -var _ = Suite(&testServiceSuite{}) - -type testServiceSuite struct { +type serviceTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc } -func (s *testServiceSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestServiceTestSuite(t *testing.T) { + suite.Run(t, new(serviceTestSuite)) +} + +func (suite *serviceTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustBootstrapCluster(re, suite.svr) + mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) } -func (s *testServiceSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *serviceTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testServiceSuite) TestServiceLabels(c *C) { - accessPaths := s.svr.GetServiceLabels("Profile") - c.Assert(accessPaths, HasLen, 1) - c.Assert(accessPaths[0].Path, Equals, "/pd/api/v1/debug/pprof/profile") - c.Assert(accessPaths[0].Method, Equals, "") - serviceLabel := s.svr.GetAPIAccessServiceLabel( +func (suite *serviceTestSuite) TestServiceLabels() { + accessPaths := suite.svr.GetServiceLabels("Profile") + suite.Len(accessPaths, 1) + suite.Equal("/pd/api/v1/debug/pprof/profile", accessPaths[0].Path) + suite.Equal("", accessPaths[0].Method) + serviceLabel := suite.svr.GetAPIAccessServiceLabel( apiutil.NewAccessPath("/pd/api/v1/debug/pprof/profile", "")) - c.Assert(serviceLabel, Equals, "Profile") - serviceLabel = s.svr.GetAPIAccessServiceLabel( + suite.Equal("Profile", serviceLabel) + serviceLabel = suite.svr.GetAPIAccessServiceLabel( apiutil.NewAccessPath("/pd/api/v1/debug/pprof/profile", http.MethodGet)) - c.Assert(serviceLabel, Equals, "Profile") - - accessPaths = s.svr.GetServiceLabels("GetSchedulerConfig") - c.Assert(accessPaths, HasLen, 1) - c.Assert(accessPaths[0].Path, Equals, "/pd/api/v1/scheduler-config") - c.Assert(accessPaths[0].Method, Equals, "") - - accessPaths = s.svr.GetServiceLabels("ResignLeader") - c.Assert(accessPaths, HasLen, 1) - c.Assert(accessPaths[0].Path, Equals, "/pd/api/v1/leader/resign") - c.Assert(accessPaths[0].Method, Equals, http.MethodPost) - serviceLabel = s.svr.GetAPIAccessServiceLabel( + suite.Equal("Profile", serviceLabel) + + accessPaths = suite.svr.GetServiceLabels("GetSchedulerConfig") + suite.Len(accessPaths, 1) + suite.Equal("/pd/api/v1/scheduler-config", 
accessPaths[0].Path) + suite.Equal("", accessPaths[0].Method) + + accessPaths = suite.svr.GetServiceLabels("ResignLeader") + suite.Len(accessPaths, 1) + suite.Equal("/pd/api/v1/leader/resign", accessPaths[0].Path) + suite.Equal(http.MethodPost, accessPaths[0].Method) + serviceLabel = suite.svr.GetAPIAccessServiceLabel( apiutil.NewAccessPath("/pd/api/v1/leader/resign", http.MethodPost)) - c.Assert(serviceLabel, Equals, "ResignLeader") - serviceLabel = s.svr.GetAPIAccessServiceLabel( + suite.Equal("ResignLeader", serviceLabel) + serviceLabel = suite.svr.GetAPIAccessServiceLabel( apiutil.NewAccessPath("/pd/api/v1/leader/resign", http.MethodGet)) - c.Assert(serviceLabel, Equals, "") - serviceLabel = s.svr.GetAPIAccessServiceLabel( + suite.Equal("", serviceLabel) + serviceLabel = suite.svr.GetAPIAccessServiceLabel( apiutil.NewAccessPath("/pd/api/v1/leader/resign", "")) - c.Assert(serviceLabel, Equals, "") + suite.Equal("", serviceLabel) - accessPaths = s.svr.GetServiceLabels("QueryMetric") - c.Assert(accessPaths, HasLen, 4) + accessPaths = suite.svr.GetServiceLabels("QueryMetric") + suite.Len(accessPaths, 4) sort.Slice(accessPaths, func(i, j int) bool { if accessPaths[i].Path == accessPaths[j].Path { return accessPaths[i].Method < accessPaths[j].Method } return accessPaths[i].Path < accessPaths[j].Path }) - c.Assert(accessPaths[0].Path, Equals, "/pd/api/v1/metric/query") - c.Assert(accessPaths[0].Method, Equals, http.MethodGet) - c.Assert(accessPaths[1].Path, Equals, "/pd/api/v1/metric/query") - c.Assert(accessPaths[1].Method, Equals, http.MethodPost) - c.Assert(accessPaths[2].Path, Equals, "/pd/api/v1/metric/query_range") - c.Assert(accessPaths[2].Method, Equals, http.MethodGet) - c.Assert(accessPaths[3].Path, Equals, "/pd/api/v1/metric/query_range") - c.Assert(accessPaths[3].Method, Equals, http.MethodPost) - serviceLabel = s.svr.GetAPIAccessServiceLabel( + suite.Equal("/pd/api/v1/metric/query", accessPaths[0].Path) + suite.Equal(http.MethodGet, accessPaths[0].Method) + suite.Equal("/pd/api/v1/metric/query", accessPaths[1].Path) + suite.Equal(http.MethodPost, accessPaths[1].Method) + suite.Equal("/pd/api/v1/metric/query_range", accessPaths[2].Path) + suite.Equal(http.MethodGet, accessPaths[2].Method) + suite.Equal("/pd/api/v1/metric/query_range", accessPaths[3].Path) + suite.Equal(http.MethodPost, accessPaths[3].Method) + serviceLabel = suite.svr.GetAPIAccessServiceLabel( apiutil.NewAccessPath("/pd/api/v1/metric/query", http.MethodPost)) - c.Assert(serviceLabel, Equals, "QueryMetric") - serviceLabel = s.svr.GetAPIAccessServiceLabel( + suite.Equal("QueryMetric", serviceLabel) + serviceLabel = suite.svr.GetAPIAccessServiceLabel( apiutil.NewAccessPath("/pd/api/v1/metric/query", http.MethodGet)) - c.Assert(serviceLabel, Equals, "QueryMetric") + suite.Equal("QueryMetric", serviceLabel) } diff --git a/server/api/service_gc_safepoint_test.go b/server/api/service_gc_safepoint_test.go index ecfaa76bf55..291bba0fcaf 100644 --- a/server/api/service_gc_safepoint_test.go +++ b/server/api/service_gc_safepoint_test.go @@ -17,42 +17,47 @@ package api import ( "fmt" "net/http" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/storage/endpoint" ) -var _ = Suite(&testServiceGCSafepointSuite{}) - -type testServiceGCSafepointSuite struct { +type serviceGCSafepointTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testServiceGCSafepointSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestServiceGCSafepointTestSuite(t *testing.T) { + suite.Run(t, new(serviceGCSafepointTestSuite)) +} + +func (suite *serviceGCSafepointTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustBootstrapCluster(re, suite.svr) + mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) } -func (s *testServiceGCSafepointSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *serviceGCSafepointTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testServiceGCSafepointSuite) TestServiceGCSafepoint(c *C) { - sspURL := s.urlPrefix + "/gc/safepoint" +func (suite *serviceGCSafepointTestSuite) TestServiceGCSafepoint() { + sspURL := suite.urlPrefix + "/gc/safepoint" - storage := s.svr.GetStorage() + storage := suite.svr.GetStorage() list := &listServiceGCSafepoint{ ServiceGCSafepoints: []*endpoint.ServiceSafePoint{ { @@ -75,23 +80,23 @@ func (s *testServiceGCSafepointSuite) TestServiceGCSafepoint(c *C) { } for _, ssp := range list.ServiceGCSafepoints { err := storage.SaveServiceGCSafePoint(ssp) - c.Assert(err, IsNil) + suite.NoError(err) } storage.SaveGCSafePoint(1) res, err := testDialClient.Get(sspURL) - c.Assert(err, IsNil) + suite.NoError(err) defer res.Body.Close() listResp := &listServiceGCSafepoint{} err = apiutil.ReadJSON(res.Body, listResp) - c.Assert(err, IsNil) - c.Assert(listResp, DeepEquals, list) + suite.NoError(err) + suite.Equal(list, listResp) statusCode, err := apiutil.DoDelete(testDialClient, sspURL+"/a") - c.Assert(err, IsNil) - c.Assert(statusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, statusCode) left, err := storage.LoadAllServiceGCSafePoints() - c.Assert(err, IsNil) - c.Assert(left, DeepEquals, list.ServiceGCSafepoints[1:]) + suite.NoError(err) + suite.Equal(list.ServiceGCSafepoints[1:], left) } diff --git a/server/api/service_middleware_test.go b/server/api/service_middleware_test.go index 6ea0343f53b..ac188dd8759 100644 --- a/server/api/service_middleware_test.go +++ b/server/api/service_middleware_test.go @@ -18,185 +18,191 @@ import ( "encoding/json" "fmt" "net/http" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/ratelimit" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testAuditMiddlewareSuite{}) - -type testAuditMiddlewareSuite struct { +type auditMiddlewareTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testAuditMiddlewareSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { +func TestAuditMiddlewareTestSuite(t *testing.T) { + suite.Run(t, new(auditMiddlewareTestSuite)) +} + +func (suite *auditMiddlewareTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(c, []*server.Server{s.svr}) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) } -func (s *testAuditMiddlewareSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *auditMiddlewareTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testAuditMiddlewareSuite) TestConfigAuditSwitch(c *C) { - addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) - +func (suite *auditMiddlewareTestSuite) TestConfigAuditSwitch() { + addr := fmt.Sprintf("%s/service-middleware/config", suite.urlPrefix) sc := &config.ServiceMiddlewareConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - c.Assert(sc.EnableAudit, Equals, false) + re := suite.Require() + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) + suite.False(sc.EnableAudit) ms := map[string]interface{}{ "enable-audit": "true", "enable-rate-limit": "true", } postData, err := json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re))) sc = &config.ServiceMiddlewareConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - c.Assert(sc.EnableAudit, Equals, true) - c.Assert(sc.EnableRateLimit, Equals, true) + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) + suite.True(sc.EnableAudit) + suite.True(sc.EnableRateLimit) ms = map[string]interface{}{ "audit.enable-audit": "false", "enable-rate-limit": "false", } postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re))) sc = &config.ServiceMiddlewareConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - c.Assert(sc.EnableAudit, Equals, false) - c.Assert(sc.EnableRateLimit, Equals, false) + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) + suite.False(sc.EnableAudit) + suite.False(sc.EnableRateLimit) // test empty ms = map[string]interface{}{} postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c), tu.StringContain(c, "The input is empty.")), IsNil) - + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re), tu.StringContain(re, "The input is empty."))) ms = map[string]interface{}{ "audit": "false", } postData, err = 
json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item audit not found")), IsNil) - - c.Assert(failpoint.Enable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail", "return(true)"), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "config item audit not found"))) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail", "return(true)")) ms = map[string]interface{}{ "audit.enable-audit": "true", } postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest)), IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail"), IsNil) - + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest))) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail")) ms = map[string]interface{}{ "audit.audit": "false", } postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item audit not found")), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "config item audit not found"))) } -var _ = Suite(&testRateLimitConfigSuite{}) - -type testRateLimitConfigSuite struct { +type rateLimitConfigTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testRateLimitConfigSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) - mustBootstrapCluster(c, s.svr) - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", s.svr.GetAddr(), apiPrefix) +func TestRateLimitConfigTestSuite(t *testing.T) { + suite.Run(t, new(rateLimitConfigTestSuite)) } -func (s *testRateLimitConfigSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *rateLimitConfigTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) + mustBootstrapCluster(re, suite.svr) + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", suite.svr.GetAddr(), apiPrefix) } -func (s *testRateLimitConfigSuite) TestUpdateRateLimitConfig(c *C) { - urlPrefix := fmt.Sprintf("%s%s/api/v1/service-middleware/config/rate-limit", s.svr.GetAddr(), apiPrefix) +func (suite *rateLimitConfigTestSuite) TearDownSuite() { + suite.cleanup() +} + +func (suite *rateLimitConfigTestSuite) TestUpdateRateLimitConfig() { + urlPrefix := fmt.Sprintf("%s%s/api/v1/service-middleware/config/rate-limit", suite.svr.GetAddr(), apiPrefix) // test empty type input := make(map[string]interface{}) input["type"] = 123 jsonBody, err := json.Marshal(input) - c.Assert(err, IsNil) - + suite.NoError(err) + re := suite.Require() err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The type is empty.\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"The type is empty.\"\n")) + suite.NoError(err) // test invalid type input = make(map[string]interface{}) input["type"] = "url" jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + 
suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The type is invalid.\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"The type is invalid.\"\n")) + suite.NoError(err) // test empty label input = make(map[string]interface{}) input["type"] = "label" input["label"] = "" jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The label is empty.\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"The label is empty.\"\n")) + suite.NoError(err) // test no label matched input = make(map[string]interface{}) input["type"] = "label" input["label"] = "TestLabel" jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"There is no label matched.\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"There is no label matched.\"\n")) + suite.NoError(err) // test empty path input = make(map[string]interface{}) input["type"] = "path" input["path"] = "" jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"The path is empty.\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"The path is empty.\"\n")) + suite.NoError(err) // test path but no label matched input = make(map[string]interface{}) input["type"] = "path" input["path"] = "/pd/api/v1/test" jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "\"There is no label matched.\"\n")) - c.Assert(err, IsNil) + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"There is no label matched.\"\n")) + suite.NoError(err) // no change input = make(map[string]interface{}) input["type"] = "label" input["label"] = "GetHealthStatus" jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusOK(c), tu.StringEqual(c, "\"No changed.\"\n")) - c.Assert(err, IsNil) + tu.StatusOK(re), tu.StringEqual(re, "\"No changed.\"\n")) + suite.NoError(err) // change concurrency input = make(map[string]interface{}) @@ -205,16 +211,16 @@ func (s *testRateLimitConfigSuite) TestUpdateRateLimitConfig(c *C) { input["method"] = "GET" input["concurrency"] = 100 jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusOK(c), tu.StringContain(c, "Concurrency limiter is changed.")) - c.Assert(err, IsNil) + tu.StatusOK(re), tu.StringContain(re, "Concurrency limiter is changed.")) + suite.NoError(err) input["concurrency"] = 0 jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusOK(c), tu.StringContain(c, "Concurrency limiter is deleted.")) - c.Assert(err, IsNil) + tu.StatusOK(re), tu.StringContain(re, "Concurrency limiter is deleted.")) + suite.NoError(err) // change qps input = 
make(map[string]interface{}) @@ -223,10 +229,10 @@ func (s *testRateLimitConfigSuite) TestUpdateRateLimitConfig(c *C) { input["method"] = "GET" input["qps"] = 100 jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusOK(c), tu.StringContain(c, "QPS rate limiter is changed.")) - c.Assert(err, IsNil) + tu.StatusOK(re), tu.StringContain(re, "QPS rate limiter is changed.")) + suite.NoError(err) input = make(map[string]interface{}) input["type"] = "path" @@ -234,18 +240,18 @@ func (s *testRateLimitConfigSuite) TestUpdateRateLimitConfig(c *C) { input["method"] = "GET" input["qps"] = 0.3 jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusOK(c), tu.StringContain(c, "QPS rate limiter is changed.")) - c.Assert(err, IsNil) - c.Assert(s.svr.GetRateLimitConfig().LimiterConfig["GetHealthStatus"].QPSBurst, Equals, 1) + tu.StatusOK(re), tu.StringContain(re, "QPS rate limiter is changed.")) + suite.NoError(err) + suite.Equal(1, suite.svr.GetRateLimitConfig().LimiterConfig["GetHealthStatus"].QPSBurst) input["qps"] = -1 jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusOK(c), tu.StringContain(c, "QPS rate limiter is deleted.")) - c.Assert(err, IsNil) + tu.StatusOK(re), tu.StringContain(re, "QPS rate limiter is deleted.")) + suite.NoError(err) // change both input = make(map[string]interface{}) @@ -254,19 +260,19 @@ func (s *testRateLimitConfigSuite) TestUpdateRateLimitConfig(c *C) { input["qps"] = 100 input["concurrency"] = 100 jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) result := rateLimitResult{} err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusOK(c), tu.StringContain(c, "Concurrency limiter is changed."), - tu.StringContain(c, "QPS rate limiter is changed."), - tu.ExtractJSON(c, &result), + tu.StatusOK(re), tu.StringContain(re, "Concurrency limiter is changed."), + tu.StringContain(re, "QPS rate limiter is changed."), + tu.ExtractJSON(re, &result), ) - c.Assert(result.LimiterConfig["Profile"].QPS, Equals, 100.) 
- c.Assert(result.LimiterConfig["Profile"].QPSBurst, Equals, 100) - c.Assert(result.LimiterConfig["Profile"].ConcurrencyLimit, Equals, uint64(100)) - c.Assert(err, IsNil) + suite.Equal(100., result.LimiterConfig["Profile"].QPS) + suite.Equal(100, result.LimiterConfig["Profile"].QPSBurst) + suite.Equal(uint64(100), result.LimiterConfig["Profile"].ConcurrencyLimit) + suite.NoError(err) - limiter := s.svr.GetServiceRateLimiter() + limiter := suite.svr.GetServiceRateLimiter() limiter.Update("SetRatelimitConfig", ratelimit.AddLabelAllowList()) // Allow list @@ -276,71 +282,68 @@ func (s *testRateLimitConfigSuite) TestUpdateRateLimitConfig(c *C) { input["qps"] = 100 input["concurrency"] = 100 jsonBody, err = json.Marshal(input) - c.Assert(err, IsNil) + suite.NoError(err) err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, - tu.StatusNotOK(c), tu.StringEqual(c, "\"This service is in allow list whose config can not be changed.\"\n")) - c.Assert(err, IsNil) + tu.StatusNotOK(re), tu.StringEqual(re, "\"This service is in allow list whose config can not be changed.\"\n")) + suite.NoError(err) } -func (s *testRateLimitConfigSuite) TestConfigRateLimitSwitch(c *C) { - addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) - +func (suite *rateLimitConfigTestSuite) TestConfigRateLimitSwitch() { + addr := fmt.Sprintf("%s/service-middleware/config", suite.urlPrefix) sc := &config.ServiceMiddlewareConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - c.Assert(sc.EnableRateLimit, Equals, false) + re := suite.Require() + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) + suite.False(sc.EnableRateLimit) ms := map[string]interface{}{ "enable-rate-limit": "true", } postData, err := json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re))) sc = &config.ServiceMiddlewareConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - c.Assert(sc.EnableRateLimit, Equals, true) + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) + suite.True(sc.EnableRateLimit) ms = map[string]interface{}{ "enable-rate-limit": "false", } postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re))) sc = &config.ServiceMiddlewareConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - c.Assert(sc.EnableRateLimit, Equals, false) + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) + suite.False(sc.EnableRateLimit) // test empty ms = map[string]interface{}{} postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c), tu.StringContain(c, "The input is empty.")), IsNil) - + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re), tu.StringContain(re, "The input is empty."))) ms = map[string]interface{}{ "rate-limit": "false", } postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item rate-limit not found")), IsNil) - - c.Assert(failpoint.Enable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail", "return(true)"), IsNil) + suite.NoError(err) + 
suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "config item rate-limit not found"))) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail", "return(true)")) ms = map[string]interface{}{ "rate-limit.enable-rate-limit": "true", } postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest)), IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail"), IsNil) - + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest))) + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail")) ms = map[string]interface{}{ "rate-limit.rate-limit": "false", } postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item rate-limit not found")), IsNil) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "config item rate-limit not found"))) } -func (s *testRateLimitConfigSuite) TestConfigLimiterConifgByOriginAPI(c *C) { +func (suite *rateLimitConfigTestSuite) TestConfigLimiterConifgByOriginAPI() { // this test case is used to test updating `limiter-config` by origin API simply - addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) + addr := fmt.Sprintf("%s/service-middleware/config", suite.urlPrefix) dimensionConfig := ratelimit.DimensionConfig{QPS: 1} limiterConfig := map[string]interface{}{ "CreateOperator": dimensionConfig, @@ -349,9 +352,10 @@ func (s *testRateLimitConfigSuite) TestConfigLimiterConifgByOriginAPI(c *C) { "limiter-config": limiterConfig, } postData, err := json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + suite.NoError(err) + re := suite.Require() + suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re))) sc := &config.ServiceMiddlewareConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) - c.Assert(sc.RateLimitConfig.LimiterConfig["CreateOperator"].QPS, Equals, 1.) + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) + suite.Equal(1., sc.RateLimitConfig.LimiterConfig["CreateOperator"].QPS) } diff --git a/server/api/stats_test.go b/server/api/stats_test.go index b2fbbf1b7bf..77c35b19679 100644 --- a/server/api/stats_test.go +++ b/server/api/stats_test.go @@ -17,39 +17,44 @@ package api import ( "fmt" "net/url" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/statistics" ) -var _ = Suite(&testStatsSuite{}) - -type testStatsSuite struct { +type statsTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testStatsSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestStatsTestSuite(t *testing.T) { + suite.Run(t, new(statsTestSuite)) +} + +func (suite *statsTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) } -func (s *testStatsSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *statsTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testStatsSuite) TestRegionStats(c *C) { - statsURL := s.urlPrefix + "/stats/region" +func (suite *statsTestSuite) TestRegionStats() { + statsURL := suite.urlPrefix + "/stats/region" epoch := &metapb.RegionEpoch{ ConfVer: 1, Version: 1, @@ -117,8 +122,9 @@ func (s *testStatsSuite) TestRegionStats(c *C) { ), } + re := suite.Require() for _, r := range regions { - mustRegionHeartbeat(c, s.svr, r) + mustRegionHeartbeat(re, suite.svr, r) } // Distribution (L for leader, F for follower): @@ -141,21 +147,21 @@ func (s *testStatsSuite) TestRegionStats(c *C) { StorePeerKeys: map[uint64]int64{1: 201, 2: 50, 3: 50, 4: 170, 5: 151}, } res, err := testDialClient.Get(statsURL) - c.Assert(err, IsNil) + suite.NoError(err) defer res.Body.Close() stats := &statistics.RegionStats{} err = apiutil.ReadJSON(res.Body, stats) - c.Assert(err, IsNil) - c.Assert(stats, DeepEquals, statsAll) + suite.NoError(err) + suite.Equal(statsAll, stats) args := fmt.Sprintf("?start_key=%s&end_key=%s", url.QueryEscape("\x01\x02"), url.QueryEscape("xyz\x00\x00")) res, err = testDialClient.Get(statsURL + args) - c.Assert(err, IsNil) + suite.NoError(err) defer res.Body.Close() stats = &statistics.RegionStats{} err = apiutil.ReadJSON(res.Body, stats) - c.Assert(err, IsNil) - c.Assert(stats, DeepEquals, statsAll) + suite.NoError(err) + suite.Equal(statsAll, stats) stats23 := &statistics.RegionStats{ Count: 2, @@ -172,10 +178,10 @@ func (s *testStatsSuite) TestRegionStats(c *C) { args = fmt.Sprintf("?start_key=%s&end_key=%s", url.QueryEscape("a"), url.QueryEscape("x")) res, err = testDialClient.Get(statsURL + args) - c.Assert(err, IsNil) + suite.NoError(err) defer res.Body.Close() stats = &statistics.RegionStats{} err = apiutil.ReadJSON(res.Body, stats) - c.Assert(err, IsNil) - c.Assert(stats, DeepEquals, stats23) + suite.NoError(err) + suite.Equal(stats23, stats) } diff --git a/server/api/status_test.go b/server/api/status_test.go index 8f9c82069d2..bedf6c0e532 100644 --- a/server/api/status_test.go +++ b/server/api/status_test.go @@ -17,34 +17,32 @@ package api import ( "encoding/json" "io" + "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/versioninfo" ) -var _ = Suite(&testStatusAPISuite{}) - -type testStatusAPISuite struct{} - -func checkStatusResponse(c *C, body []byte) { +func checkStatusResponse(re *require.Assertions, body []byte) { got := status{} - c.Assert(json.Unmarshal(body, &got), IsNil) - c.Assert(got.BuildTS, Equals, versioninfo.PDBuildTS) - c.Assert(got.GitHash, Equals, versioninfo.PDGitHash) - c.Assert(got.Version, Equals, versioninfo.PDReleaseVersion) + re.NoError(json.Unmarshal(body, &got)) + re.Equal(versioninfo.PDBuildTS, got.BuildTS) + re.Equal(versioninfo.PDGitHash, got.GitHash) + re.Equal(versioninfo.PDReleaseVersion, got.Version) } -func (s *testStatusAPISuite) TestStatus(c *C) { - cfgs, _, clean := mustNewCluster(c, 3) +func TestStatus(t *testing.T) { + re := require.New(t) + cfgs, _, clean := mustNewCluster(re, 3) defer clean() for _, cfg := range cfgs { addr := cfg.ClientUrls + apiPrefix + "/api/v1/status" resp, err := testDialClient.Get(addr) - c.Assert(err, IsNil) + re.NoError(err) buf, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) - checkStatusResponse(c, buf) + re.NoError(err) + checkStatusResponse(re, buf) resp.Body.Close() } } diff --git a/server/api/store_test.go b/server/api/store_test.go index 99875c8f6ea..64cb500164d 100644 --- a/server/api/store_test.go +++ b/server/api/store_test.go @@ -21,22 +21,23 @@ import ( "io" "net/http" "net/url" + "testing" "time" "github.com/docker/go-units" "github.com/gogo/protobuf/proto" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testStoreSuite{}) - -type testStoreSuite struct { +type storeTestSuite struct { + suite.Suite svr *server.Server grpcSvr *server.GrpcServer cleanup cleanUpFunc @@ -44,20 +45,24 @@ type testStoreSuite struct { stores []*metapb.Store } -func requestStatusBody(c *C, client *http.Client, method string, url string) int { +func TestStoreTestSuite(t *testing.T) { + suite.Run(t, new(storeTestSuite)) +} + +func (suite *storeTestSuite) requestStatusBody(client *http.Client, method string, url string) int { req, err := http.NewRequest(method, url, nil) - c.Assert(err, IsNil) + suite.NoError(err) resp, err := client.Do(req) - c.Assert(err, IsNil) + suite.NoError(err) _, err = io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) err = resp.Body.Close() - c.Assert(err, IsNil) + suite.NoError(err) return resp.StatusCode } -func (s *testStoreSuite) SetUpSuite(c *C) { - s.stores = []*metapb.Store{ +func (suite *storeTestSuite) SetupSuite() { + suite.stores = []*metapb.Store{ { // metapb.StoreState_Up == 0 Id: 1, @@ -91,26 +96,27 @@ func (s *testStoreSuite) SetUpSuite(c *C) { }, } // TODO: enable placmentrules - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(c, []*server.Server{s.svr}) + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.grpcSvr = &server.GrpcServer{Server: s.svr} - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.grpcSvr = 
&server.GrpcServer{Server: suite.svr} + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) + mustBootstrapCluster(re, suite.svr) - for _, store := range s.stores { - mustPutStore(c, s.svr, store.Id, store.State, store.NodeState, nil) + for _, store := range suite.stores { + mustPutStore(re, suite.svr, store.Id, store.State, store.NodeState, nil) } } -func (s *testStoreSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *storeTestSuite) TearDownSuite() { + suite.cleanup() } -func checkStoresInfo(c *C, ss []*StoreInfo, want []*metapb.Store) { - c.Assert(len(ss), Equals, len(want)) +func checkStoresInfo(re *require.Assertions, ss []*StoreInfo, want []*metapb.Store) { + re.Len(ss, len(want)) mapWant := make(map[uint64]*metapb.Store) for _, s := range want { if _, ok := mapWant[s.Id]; !ok { @@ -122,35 +128,36 @@ func checkStoresInfo(c *C, ss []*StoreInfo, want []*metapb.Store) { expected := proto.Clone(mapWant[obtained.Id]).(*metapb.Store) // Ignore lastHeartbeat obtained.LastHeartbeat, expected.LastHeartbeat = 0, 0 - c.Assert(obtained, DeepEquals, expected) + re.Equal(expected, obtained) } } -func (s *testStoreSuite) TestStoresList(c *C) { - url := fmt.Sprintf("%s/stores", s.urlPrefix) +func (suite *storeTestSuite) TestStoresList() { + url := fmt.Sprintf("%s/stores", suite.urlPrefix) info := new(StoresInfo) - err := tu.ReadGetJSON(c, testDialClient, url, info) - c.Assert(err, IsNil) - checkStoresInfo(c, info.Stores, s.stores[:3]) + re := suite.Require() + err := tu.ReadGetJSON(re, testDialClient, url, info) + suite.NoError(err) + checkStoresInfo(re, info.Stores, suite.stores[:3]) - url = fmt.Sprintf("%s/stores?state=0", s.urlPrefix) + url = fmt.Sprintf("%s/stores?state=0", suite.urlPrefix) info = new(StoresInfo) - err = tu.ReadGetJSON(c, testDialClient, url, info) - c.Assert(err, IsNil) - checkStoresInfo(c, info.Stores, s.stores[:2]) + err = tu.ReadGetJSON(re, testDialClient, url, info) + suite.NoError(err) + checkStoresInfo(re, info.Stores, suite.stores[:2]) - url = fmt.Sprintf("%s/stores?state=1", s.urlPrefix) + url = fmt.Sprintf("%s/stores?state=1", suite.urlPrefix) info = new(StoresInfo) - err = tu.ReadGetJSON(c, testDialClient, url, info) - c.Assert(err, IsNil) - checkStoresInfo(c, info.Stores, s.stores[2:3]) + err = tu.ReadGetJSON(re, testDialClient, url, info) + suite.NoError(err) + checkStoresInfo(re, info.Stores, suite.stores[2:3]) } -func (s *testStoreSuite) TestStoreGet(c *C) { - url := fmt.Sprintf("%s/store/1", s.urlPrefix) - s.grpcSvr.StoreHeartbeat( +func (suite *storeTestSuite) TestStoreGet() { + url := fmt.Sprintf("%s/store/1", suite.urlPrefix) + suite.grpcSvr.StoreHeartbeat( context.Background(), &pdpb.StoreHeartbeatRequest{ - Header: &pdpb.RequestHeader{ClusterId: s.svr.ClusterID()}, + Header: &pdpb.RequestHeader{ClusterId: suite.svr.ClusterID()}, Stats: &pdpb.StoreStats{ StoreId: 1, Capacity: 1798985089024, @@ -160,92 +167,94 @@ func (s *testStoreSuite) TestStoreGet(c *C) { }, ) info := new(StoreInfo) - err := tu.ReadGetJSON(c, testDialClient, url, info) - c.Assert(err, IsNil) + err := tu.ReadGetJSON(suite.Require(), testDialClient, url, info) + suite.NoError(err) capacity, _ := units.RAMInBytes("1.636TiB") available, _ := units.RAMInBytes("1.555TiB") - c.Assert(int64(info.Status.Capacity), Equals, capacity) - c.Assert(int64(info.Status.Available), Equals, available) - checkStoresInfo(c, []*StoreInfo{info}, s.stores[:1]) + suite.Equal(capacity, int64(info.Status.Capacity)) + suite.Equal(available, int64(info.Status.Available)) 
+ checkStoresInfo(suite.Require(), []*StoreInfo{info}, suite.stores[:1]) } -func (s *testStoreSuite) TestStoreLabel(c *C) { - url := fmt.Sprintf("%s/store/1", s.urlPrefix) +func (suite *storeTestSuite) TestStoreLabel() { + url := fmt.Sprintf("%s/store/1", suite.urlPrefix) + re := suite.Require() var info StoreInfo - err := tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) - c.Assert(info.Store.Labels, HasLen, 0) + err := tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) + suite.Len(info.Store.Labels, 0) // Test merge. // enable label match check. labelCheck := map[string]string{"strictly-match-label": "true"} lc, _ := json.Marshal(labelCheck) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config", lc, tu.StatusOK(c)) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config", lc, tu.StatusOK(re)) + suite.NoError(err) // Test set. labels := map[string]string{"zone": "cn", "host": "local"} b, err := json.Marshal(labels) - c.Assert(err, IsNil) + suite.NoError(err) // TODO: supports strictly match check in placement rules err = tu.CheckPostJSON(testDialClient, url+"/label", b, - tu.StatusNotOK(c), - tu.StringContain(c, "key matching the label was not found")) - c.Assert(err, IsNil) + tu.StatusNotOK(re), + tu.StringContain(re, "key matching the label was not found")) + suite.NoError(err) locationLabels := map[string]string{"location-labels": "zone,host"} ll, _ := json.Marshal(locationLabels) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config", ll, tu.StatusOK(c)) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, url+"/label", b, tu.StatusOK(c)) - c.Assert(err, IsNil) - - err = tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) - c.Assert(info.Store.Labels, HasLen, len(labels)) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config", ll, tu.StatusOK(re)) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, url+"/label", b, tu.StatusOK(re)) + suite.NoError(err) + + err = tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) + suite.Len(info.Store.Labels, len(labels)) for _, l := range info.Store.Labels { - c.Assert(labels[l.Key], Equals, l.Value) + suite.Equal(l.Value, labels[l.Key]) } // Test merge. // disable label match check. 
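The tu.StatusOK(re) and tu.StringContain(re, ...) options used throughout this file are now built from a *require.Assertions rather than the old *check.C. As a rough idea of what such response checkers look like, a hypothetical sketch (the option type and function names here are invented; the real pkg/testutil helpers may be shaped differently):

package demo

import (
	"net/http"

	"github.com/stretchr/testify/require"
)

// checkOption is an invented stand-in for the option type that a
// CheckPostJSON-style helper accepts; it receives the response body and code.
type checkOption func(body []byte, code int)

// statusOK sketches the idea behind tu.StatusOK(re): the checker reports
// failures through the captured *require.Assertions.
func statusOK(re *require.Assertions) checkOption {
	return func(_ []byte, code int) {
		re.Equal(http.StatusOK, code)
	}
}

// stringContain sketches tu.StringContain(re, sub): assert that the response
// body contains the given substring.
func stringContain(re *require.Assertions, sub string) checkOption {
	return func(body []byte, _ int) {
		re.Contains(string(body), sub)
	}
}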
labelCheck = map[string]string{"strictly-match-label": "false"} lc, _ = json.Marshal(labelCheck) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/config", lc, tu.StatusOK(c)) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/config", lc, tu.StatusOK(re)) + suite.NoError(err) labels = map[string]string{"zack": "zack1", "Host": "host1"} b, err = json.Marshal(labels) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, url+"/label", b, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, url+"/label", b, tu.StatusOK(re)) + suite.NoError(err) expectLabel := map[string]string{"zone": "cn", "zack": "zack1", "host": "host1"} - err = tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) - c.Assert(info.Store.Labels, HasLen, len(expectLabel)) + err = tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) + suite.Len(info.Store.Labels, len(expectLabel)) for _, l := range info.Store.Labels { - c.Assert(expectLabel[l.Key], Equals, l.Value) + suite.Equal(expectLabel[l.Key], l.Value) } // delete label b, err = json.Marshal(map[string]string{"host": ""}) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, url+"/label", b, tu.StatusOK(c)) - c.Assert(err, IsNil) - err = tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, url+"/label", b, tu.StatusOK(re)) + suite.NoError(err) + err = tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) delete(expectLabel, "host") - c.Assert(info.Store.Labels, HasLen, len(expectLabel)) + suite.Len(info.Store.Labels, len(expectLabel)) for _, l := range info.Store.Labels { - c.Assert(expectLabel[l.Key], Equals, l.Value) + suite.Equal(expectLabel[l.Key], l.Value) } - s.stores[0].Labels = info.Store.Labels + suite.stores[0].Labels = info.Store.Labels } -func (s *testStoreSuite) TestStoreDelete(c *C) { +func (suite *storeTestSuite) TestStoreDelete() { + re := suite.Require() // prepare enough online stores to store replica. 
for id := 1111; id <= 1115; id++ { - mustPutStore(c, s.svr, uint64(id), metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustPutStore(re, suite.svr, uint64(id), metapb.StoreState_Up, metapb.NodeState_Serving, nil) } table := []struct { id int @@ -261,135 +270,136 @@ func (s *testStoreSuite) TestStoreDelete(c *C) { }, } for _, t := range table { - url := fmt.Sprintf("%s/store/%d", s.urlPrefix, t.id) - status := requestStatusBody(c, testDialClient, http.MethodDelete, url) - c.Assert(status, Equals, t.status) + url := fmt.Sprintf("%s/store/%d", suite.urlPrefix, t.id) + status := suite.requestStatusBody(testDialClient, http.MethodDelete, url) + suite.Equal(t.status, status) } // store 6 origin status:offline - url := fmt.Sprintf("%s/store/6", s.urlPrefix) + url := fmt.Sprintf("%s/store/6", suite.urlPrefix) store := new(StoreInfo) - err := tu.ReadGetJSON(c, testDialClient, url, store) - c.Assert(err, IsNil) - c.Assert(store.Store.PhysicallyDestroyed, IsFalse) - c.Assert(store.Store.State, Equals, metapb.StoreState_Offline) + err := tu.ReadGetJSON(re, testDialClient, url, store) + suite.NoError(err) + suite.False(store.Store.PhysicallyDestroyed) + suite.Equal(metapb.StoreState_Offline, store.Store.State) // up store success because it is offline but not physically destroyed - status := requestStatusBody(c, testDialClient, http.MethodPost, fmt.Sprintf("%s/state?state=Up", url)) - c.Assert(status, Equals, http.StatusOK) + status := suite.requestStatusBody(testDialClient, http.MethodPost, fmt.Sprintf("%s/state?state=Up", url)) + suite.Equal(http.StatusOK, status) - status = requestStatusBody(c, testDialClient, http.MethodGet, url) - c.Assert(status, Equals, http.StatusOK) + status = suite.requestStatusBody(testDialClient, http.MethodGet, url) + suite.Equal(http.StatusOK, status) store = new(StoreInfo) - err = tu.ReadGetJSON(c, testDialClient, url, store) - c.Assert(err, IsNil) - c.Assert(store.Store.State, Equals, metapb.StoreState_Up) - c.Assert(store.Store.PhysicallyDestroyed, IsFalse) + err = tu.ReadGetJSON(re, testDialClient, url, store) + suite.NoError(err) + suite.Equal(metapb.StoreState_Up, store.Store.State) + suite.False(store.Store.PhysicallyDestroyed) // offline store with physically destroyed - status = requestStatusBody(c, testDialClient, http.MethodDelete, fmt.Sprintf("%s?force=true", url)) - c.Assert(status, Equals, http.StatusOK) - err = tu.ReadGetJSON(c, testDialClient, url, store) - c.Assert(err, IsNil) - c.Assert(store.Store.State, Equals, metapb.StoreState_Offline) - c.Assert(store.Store.PhysicallyDestroyed, IsTrue) + status = suite.requestStatusBody(testDialClient, http.MethodDelete, fmt.Sprintf("%s?force=true", url)) + suite.Equal(http.StatusOK, status) + err = tu.ReadGetJSON(re, testDialClient, url, store) + suite.NoError(err) + suite.Equal(metapb.StoreState_Offline, store.Store.State) + suite.True(store.Store.PhysicallyDestroyed) // try to up store again failed because it is physically destroyed - status = requestStatusBody(c, testDialClient, http.MethodPost, fmt.Sprintf("%s/state?state=Up", url)) - c.Assert(status, Equals, http.StatusBadRequest) + status = suite.requestStatusBody(testDialClient, http.MethodPost, fmt.Sprintf("%s/state?state=Up", url)) + suite.Equal(http.StatusBadRequest, status) // reset store 6 - s.cleanup() - s.SetUpSuite(c) + suite.cleanup() + suite.SetupSuite() } -func (s *testStoreSuite) TestStoreSetState(c *C) { +func (suite *storeTestSuite) TestStoreSetState() { + re := suite.Require() // prepare enough online stores to store replica. 
for id := 1111; id <= 1115; id++ { - mustPutStore(c, s.svr, uint64(id), metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustPutStore(re, suite.svr, uint64(id), metapb.StoreState_Up, metapb.NodeState_Serving, nil) } - url := fmt.Sprintf("%s/store/1", s.urlPrefix) + url := fmt.Sprintf("%s/store/1", suite.urlPrefix) info := StoreInfo{} - err := tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) - c.Assert(info.Store.State, Equals, metapb.StoreState_Up) + err := tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) + suite.Equal(metapb.StoreState_Up, info.Store.State) // Set to Offline. info = StoreInfo{} - err = tu.CheckPostJSON(testDialClient, url+"/state?state=Offline", nil, tu.StatusOK(c)) - c.Assert(err, IsNil) - err = tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) - c.Assert(info.Store.State, Equals, metapb.StoreState_Offline) + err = tu.CheckPostJSON(testDialClient, url+"/state?state=Offline", nil, tu.StatusOK(re)) + suite.NoError(err) + err = tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) + suite.Equal(metapb.StoreState_Offline, info.Store.State) // store not found info = StoreInfo{} - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/store/10086/state?state=Offline", nil, tu.StatusNotOK(c)) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/store/10086/state?state=Offline", nil, tu.StatusNotOK(re)) + suite.NoError(err) // Invalid state. invalidStates := []string{"Foo", "Tombstone"} for _, state := range invalidStates { info = StoreInfo{} - err = tu.CheckPostJSON(testDialClient, url+"/state?state="+state, nil, tu.StatusNotOK(c)) - c.Assert(err, IsNil) - err := tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) - c.Assert(info.Store.State, Equals, metapb.StoreState_Offline) + err = tu.CheckPostJSON(testDialClient, url+"/state?state="+state, nil, tu.StatusNotOK(re)) + suite.NoError(err) + err := tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) + suite.Equal(metapb.StoreState_Offline, info.Store.State) } // Set back to Up. 
info = StoreInfo{} - err = tu.CheckPostJSON(testDialClient, url+"/state?state=Up", nil, tu.StatusOK(c)) - c.Assert(err, IsNil) - err = tu.ReadGetJSON(c, testDialClient, url, &info) - c.Assert(err, IsNil) - c.Assert(info.Store.State, Equals, metapb.StoreState_Up) - s.cleanup() - s.SetUpSuite(c) + err = tu.CheckPostJSON(testDialClient, url+"/state?state=Up", nil, tu.StatusOK(re)) + suite.NoError(err) + err = tu.ReadGetJSON(re, testDialClient, url, &info) + suite.NoError(err) + suite.Equal(metapb.StoreState_Up, info.Store.State) + suite.cleanup() + suite.SetupSuite() } -func (s *testStoreSuite) TestUrlStoreFilter(c *C) { +func (suite *storeTestSuite) TestUrlStoreFilter() { table := []struct { u string want []*metapb.Store }{ { u: "http://localhost:2379/pd/api/v1/stores", - want: s.stores[:3], + want: suite.stores[:3], }, { u: "http://localhost:2379/pd/api/v1/stores?state=2", - want: s.stores[3:], + want: suite.stores[3:], }, { u: "http://localhost:2379/pd/api/v1/stores?state=0", - want: s.stores[:2], + want: suite.stores[:2], }, { u: "http://localhost:2379/pd/api/v1/stores?state=2&state=1", - want: s.stores[2:], + want: suite.stores[2:], }, } for _, t := range table { uu, err := url.Parse(t.u) - c.Assert(err, IsNil) + suite.NoError(err) f, err := newStoreStateFilter(uu) - c.Assert(err, IsNil) - c.Assert(f.filter(s.stores), DeepEquals, t.want) + suite.NoError(err) + suite.Equal(t.want, f.filter(suite.stores)) } u, err := url.Parse("http://localhost:2379/pd/api/v1/stores?state=foo") - c.Assert(err, IsNil) + suite.NoError(err) _, err = newStoreStateFilter(u) - c.Assert(err, NotNil) + suite.Error(err) u, err = url.Parse("http://localhost:2379/pd/api/v1/stores?state=999999") - c.Assert(err, IsNil) + suite.NoError(err) _, err = newStoreStateFilter(u) - c.Assert(err, NotNil) + suite.Error(err) } -func (s *testStoreSuite) TestDownState(c *C) { +func (suite *storeTestSuite) TestDownState() { store := core.NewStoreInfo( &metapb.Store{ State: metapb.StoreState_Up, @@ -397,27 +407,27 @@ func (s *testStoreSuite) TestDownState(c *C) { core.SetStoreStats(&pdpb.StoreStats{}), core.SetLastHeartbeatTS(time.Now()), ) - storeInfo := newStoreInfo(s.svr.GetScheduleConfig(), store) - c.Assert(storeInfo.Store.StateName, Equals, metapb.StoreState_Up.String()) + storeInfo := newStoreInfo(suite.svr.GetScheduleConfig(), store) + suite.Equal(metapb.StoreState_Up.String(), storeInfo.Store.StateName) newStore := store.Clone(core.SetLastHeartbeatTS(time.Now().Add(-time.Minute * 2))) - storeInfo = newStoreInfo(s.svr.GetScheduleConfig(), newStore) - c.Assert(storeInfo.Store.StateName, Equals, disconnectedName) + storeInfo = newStoreInfo(suite.svr.GetScheduleConfig(), newStore) + suite.Equal(disconnectedName, storeInfo.Store.StateName) newStore = store.Clone(core.SetLastHeartbeatTS(time.Now().Add(-time.Hour * 2))) - storeInfo = newStoreInfo(s.svr.GetScheduleConfig(), newStore) - c.Assert(storeInfo.Store.StateName, Equals, downStateName) + storeInfo = newStoreInfo(suite.svr.GetScheduleConfig(), newStore) + suite.Equal(downStateName, storeInfo.Store.StateName) } -func (s *testStoreSuite) TestGetAllLimit(c *C) { - testcases := []struct { +func (suite *storeTestSuite) TestGetAllLimit() { + testCases := []struct { name string url string expectedStores map[uint64]struct{} }{ { name: "includeTombstone", - url: fmt.Sprintf("%s/stores/limit?include_tombstone=true", s.urlPrefix), + url: fmt.Sprintf("%s/stores/limit?include_tombstone=true", suite.urlPrefix), expectedStores: map[uint64]struct{}{ 1: {}, 4: {}, @@ -427,7 +437,7 @@ func (s 
*testStoreSuite) TestGetAllLimit(c *C) { }, { name: "excludeTombStone", - url: fmt.Sprintf("%s/stores/limit?include_tombstone=false", s.urlPrefix), + url: fmt.Sprintf("%s/stores/limit?include_tombstone=false", suite.urlPrefix), expectedStores: map[uint64]struct{}{ 1: {}, 4: {}, @@ -436,7 +446,7 @@ func (s *testStoreSuite) TestGetAllLimit(c *C) { }, { name: "default", - url: fmt.Sprintf("%s/stores/limit", s.urlPrefix), + url: fmt.Sprintf("%s/stores/limit", suite.urlPrefix), expectedStores: map[uint64]struct{}{ 1: {}, 4: {}, @@ -445,66 +455,68 @@ func (s *testStoreSuite) TestGetAllLimit(c *C) { }, } - for _, testcase := range testcases { - c.Logf(testcase.name) + re := suite.Require() + for _, testCase := range testCases { + suite.T().Logf(testCase.name) info := make(map[uint64]interface{}, 4) - err := tu.ReadGetJSON(c, testDialClient, testcase.url, &info) - c.Assert(err, IsNil) - c.Assert(len(info), Equals, len(testcase.expectedStores)) - for id := range testcase.expectedStores { + err := tu.ReadGetJSON(re, testDialClient, testCase.url, &info) + suite.NoError(err) + suite.Len(testCase.expectedStores, len(info)) + for id := range testCase.expectedStores { _, ok := info[id] - c.Assert(ok, IsTrue) + suite.True(ok) } } } -func (s *testStoreSuite) TestStoreLimitTTL(c *C) { +func (suite *storeTestSuite) TestStoreLimitTTL() { // add peer - url := fmt.Sprintf("%s/store/1/limit?ttlSecond=%v", s.urlPrefix, 5) + url := fmt.Sprintf("%s/store/1/limit?ttlSecond=%v", suite.urlPrefix, 5) data := map[string]interface{}{ "type": "add-peer", "rate": 999, } postData, err := json.Marshal(data) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + re := suite.Require() + err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(re)) + suite.NoError(err) // remove peer data = map[string]interface{}{ "type": "remove-peer", "rate": 998, } postData, err = json.Marshal(data) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(re)) + suite.NoError(err) // all store limit add peer - url = fmt.Sprintf("%s/stores/limit?ttlSecond=%v", s.urlPrefix, 3) + url = fmt.Sprintf("%s/stores/limit?ttlSecond=%v", suite.urlPrefix, 3) data = map[string]interface{}{ "type": "add-peer", "rate": 997, } postData, err = json.Marshal(data) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(re)) + suite.NoError(err) // all store limit remove peer data = map[string]interface{}{ "type": "remove-peer", "rate": 996, } postData, err = json.Marshal(data) - c.Assert(err, IsNil) - err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(c)) - c.Assert(err, IsNil) - - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(1)).AddPeer, Equals, float64(999)) - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(1)).RemovePeer, Equals, float64(998)) - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(2)).AddPeer, Equals, float64(997)) - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(2)).RemovePeer, Equals, float64(996)) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, url, postData, tu.StatusOK(re)) + suite.NoError(err) + + suite.Equal(float64(999), suite.svr.GetPersistOptions().GetStoreLimit(uint64(1)).AddPeer) + 
suite.Equal(float64(998), suite.svr.GetPersistOptions().GetStoreLimit(uint64(1)).RemovePeer) + suite.Equal(float64(997), suite.svr.GetPersistOptions().GetStoreLimit(uint64(2)).AddPeer) + suite.Equal(float64(996), suite.svr.GetPersistOptions().GetStoreLimit(uint64(2)).RemovePeer) time.Sleep(5 * time.Second) - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(1)).AddPeer, Not(Equals), float64(999)) - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(1)).RemovePeer, Not(Equals), float64(998)) - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(2)).AddPeer, Not(Equals), float64(997)) - c.Assert(s.svr.GetPersistOptions().GetStoreLimit(uint64(2)).RemovePeer, Not(Equals), float64(996)) + suite.NotEqual(float64(999), suite.svr.GetPersistOptions().GetStoreLimit(uint64(1)).AddPeer) + suite.NotEqual(float64(998), suite.svr.GetPersistOptions().GetStoreLimit(uint64(1)).RemovePeer) + suite.NotEqual(float64(997), suite.svr.GetPersistOptions().GetStoreLimit(uint64(2)).AddPeer) + suite.NotEqual(float64(996), suite.svr.GetPersistOptions().GetStoreLimit(uint64(2)).RemovePeer) } diff --git a/server/api/trend_test.go b/server/api/trend_test.go index 6e5d565573e..972af465ef9 100644 --- a/server/api/trend_test.go +++ b/server/api/trend_test.go @@ -16,76 +16,73 @@ package api import ( "fmt" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule/operator" ) -var _ = Suite(&testTrendSuite{}) - -type testTrendSuite struct{} - -func (s *testTrendSuite) TestTrend(c *C) { - svr, cleanup := mustNewServer(c) +func TestTrend(t *testing.T) { + re := require.New(t) + svr, cleanup := mustNewServer(re) defer cleanup() - mustWaitLeader(c, []*server.Server{svr}) + mustWaitLeader(re, []*server.Server{svr}) - mustBootstrapCluster(c, svr) + mustBootstrapCluster(re, svr) for i := 1; i <= 3; i++ { - mustPutStore(c, svr, uint64(i), metapb.StoreState_Up, metapb.NodeState_Serving, nil) + mustPutStore(re, svr, uint64(i), metapb.StoreState_Up, metapb.NodeState_Serving, nil) } // Create 3 regions, all peers on store1 and store2, and the leaders are all on store1. - region4 := s.newRegionInfo(4, "", "a", 2, 2, []uint64{1, 2}, nil, 1) - region5 := s.newRegionInfo(5, "a", "b", 2, 2, []uint64{1, 2}, nil, 1) - region6 := s.newRegionInfo(6, "b", "", 2, 2, []uint64{1, 2}, nil, 1) - mustRegionHeartbeat(c, svr, region4) - mustRegionHeartbeat(c, svr, region5) - mustRegionHeartbeat(c, svr, region6) + region4 := newRegionInfo(4, "", "a", 2, 2, []uint64{1, 2}, nil, 1) + region5 := newRegionInfo(5, "a", "b", 2, 2, []uint64{1, 2}, nil, 1) + region6 := newRegionInfo(6, "b", "", 2, 2, []uint64{1, 2}, nil, 1) + mustRegionHeartbeat(re, svr, region4) + mustRegionHeartbeat(re, svr, region5) + mustRegionHeartbeat(re, svr, region6) // Create 3 operators that transfers leader, moves follower, moves leader. - c.Assert(svr.GetHandler().AddTransferLeaderOperator(4, 2), IsNil) - c.Assert(svr.GetHandler().AddTransferPeerOperator(5, 2, 3), IsNil) + re.NoError(svr.GetHandler().AddTransferLeaderOperator(4, 2)) + re.NoError(svr.GetHandler().AddTransferPeerOperator(5, 2, 3)) time.Sleep(1 * time.Second) - c.Assert(svr.GetHandler().AddTransferPeerOperator(6, 1, 3), IsNil) - + re.NoError(svr.GetHandler().AddTransferPeerOperator(6, 1, 3)) // Complete the operators. 
- mustRegionHeartbeat(c, svr, region4.Clone(core.WithLeader(region4.GetStorePeer(2)))) + mustRegionHeartbeat(re, svr, region4.Clone(core.WithLeader(region4.GetStorePeer(2)))) op, err := svr.GetHandler().GetOperator(5) - c.Assert(err, IsNil) - c.Assert(op, NotNil) + re.NoError(err) + re.NotNil(op) newPeerID := op.Step(0).(operator.AddLearner).PeerID region5 = region5.Clone(core.WithAddPeer(&metapb.Peer{Id: newPeerID, StoreId: 3, Role: metapb.PeerRole_Learner}), core.WithIncConfVer()) - mustRegionHeartbeat(c, svr, region5) + mustRegionHeartbeat(re, svr, region5) region5 = region5.Clone(core.WithPromoteLearner(newPeerID), core.WithRemoveStorePeer(2), core.WithIncConfVer()) - mustRegionHeartbeat(c, svr, region5) + mustRegionHeartbeat(re, svr, region5) op, err = svr.GetHandler().GetOperator(6) - c.Assert(err, IsNil) - c.Assert(op, NotNil) + re.NoError(err) + re.NotNil(op) newPeerID = op.Step(0).(operator.AddLearner).PeerID region6 = region6.Clone(core.WithAddPeer(&metapb.Peer{Id: newPeerID, StoreId: 3, Role: metapb.PeerRole_Learner}), core.WithIncConfVer()) - mustRegionHeartbeat(c, svr, region6) + mustRegionHeartbeat(re, svr, region6) region6 = region6.Clone(core.WithPromoteLearner(newPeerID), core.WithLeader(region6.GetStorePeer(2)), core.WithRemoveStorePeer(1), core.WithIncConfVer()) - mustRegionHeartbeat(c, svr, region6) + mustRegionHeartbeat(re, svr, region6) var trend Trend - err = tu.ReadGetJSON(c, testDialClient, fmt.Sprintf("%s%s/api/v1/trend", svr.GetAddr(), apiPrefix), &trend) - c.Assert(err, IsNil) + err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s%s/api/v1/trend", svr.GetAddr(), apiPrefix), &trend) + re.NoError(err) // Check store states. expectLeaderCount := map[uint64]int{1: 1, 2: 2, 3: 0} expectRegionCount := map[uint64]int{1: 2, 2: 2, 3: 2} - c.Assert(trend.Stores, HasLen, 3) + re.Len(trend.Stores, 3) for _, store := range trend.Stores { - c.Assert(store.LeaderCount, Equals, expectLeaderCount[store.ID]) - c.Assert(store.RegionCount, Equals, expectRegionCount[store.ID]) + re.Equal(expectLeaderCount[store.ID], store.LeaderCount) + re.Equal(expectRegionCount[store.ID], store.RegionCount) } // Check history. @@ -94,13 +91,13 @@ func (s *testTrendSuite) TestTrend(c *C) { {From: 1, To: 3, Kind: "region"}: 1, {From: 2, To: 3, Kind: "region"}: 1, } - c.Assert(trend.History.Entries, HasLen, 3) + re.Len(trend.History.Entries, 3) for _, history := range trend.History.Entries { - c.Assert(history.Count, Equals, expectHistory[trendHistoryEntry{From: history.From, To: history.To, Kind: history.Kind}]) + re.Equal(expectHistory[trendHistoryEntry{From: history.From, To: history.To, Kind: history.Kind}], history.Count) } } -func (s *testTrendSuite) newRegionInfo(id uint64, startKey, endKey string, confVer, ver uint64, voters []uint64, learners []uint64, leaderStore uint64) *core.RegionInfo { +func newRegionInfo(id uint64, startKey, endKey string, confVer, ver uint64, voters []uint64, learners []uint64, leaderStore uint64) *core.RegionInfo { var ( peers = make([]*metapb.Peer, 0, len(voters)+len(learners)) leader *metapb.Peer diff --git a/server/api/tso_test.go b/server/api/tso_test.go index af73fd1ded0..2f59a2ecf07 100644 --- a/server/api/tso_test.go +++ b/server/api/tso_test.go @@ -16,44 +16,50 @@ package api import ( "fmt" + "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testTsoSuite{}) - -type testTsoSuite struct { +type tsoTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testTsoSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { +func TestTSOTestSuite(t *testing.T) { + suite.Run(t, new(tsoTestSuite)) +} + +func (suite *tsoTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.EnableLocalTSO = true cfg.Labels[config.ZoneLabel] = "dc-1" }) - mustWaitLeader(c, []*server.Server{s.svr}) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) } -func (s *testTsoSuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *tsoTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testTsoSuite) TestTransferAllocator(c *C) { - tu.WaitUntil(c, func() bool { - s.svr.GetTSOAllocatorManager().ClusterDCLocationChecker() - _, err := s.svr.GetTSOAllocatorManager().GetAllocator("dc-1") +func (suite *tsoTestSuite) TestTransferAllocator() { + re := suite.Require() + tu.Eventually(re, func() bool { + suite.svr.GetTSOAllocatorManager().ClusterDCLocationChecker() + _, err := suite.svr.GetTSOAllocatorManager().GetAllocator("dc-1") return err == nil }, tu.WithRetryTimes(5), tu.WithSleepInterval(3*time.Second)) - addr := s.urlPrefix + "/tso/allocator/transfer/pd1?dcLocation=dc-1" - err := tu.CheckPostJSON(testDialClient, addr, nil, tu.StatusOK(c)) - c.Assert(err, IsNil) + addr := suite.urlPrefix + "/tso/allocator/transfer/pd1?dcLocation=dc-1" + err := tu.CheckPostJSON(testDialClient, addr, nil, tu.StatusOK(re)) + suite.NoError(err) } diff --git a/server/api/unsafe_operation_test.go b/server/api/unsafe_operation_test.go index f9d060d2fe3..62df25c6b68 100644 --- a/server/api/unsafe_operation_test.go +++ b/server/api/unsafe_operation_test.go @@ -17,63 +17,69 @@ package api import ( "encoding/json" "fmt" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/cluster" ) -var _ = Suite(&testUnsafeAPISuite{}) - -type testUnsafeAPISuite struct { +type unsafeOperationTestSuite struct { + suite.Suite svr *server.Server cleanup cleanUpFunc urlPrefix string } -func (s *testUnsafeAPISuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) +func TestUnsafeOperationTestSuite(t *testing.T) { + suite.Run(t, new(unsafeOperationTestSuite)) +} + +func (suite *unsafeOperationTestSuite) SetupSuite() { + re := suite.Require() + suite.svr, suite.cleanup = mustNewServer(re) + mustWaitLeader(re, []*server.Server{suite.svr}) - addr := s.svr.GetAddr() - s.urlPrefix = fmt.Sprintf("%s%s/api/v1/admin/unsafe", addr, apiPrefix) + addr := suite.svr.GetAddr() + suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/admin/unsafe", addr, apiPrefix) - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Offline, metapb.NodeState_Removing, nil) + mustBootstrapCluster(re, suite.svr) + mustPutStore(re, suite.svr, 1, metapb.StoreState_Offline, metapb.NodeState_Removing, nil) } -func (s *testUnsafeAPISuite) TearDownSuite(c *C) { - s.cleanup() +func (suite *unsafeOperationTestSuite) TearDownSuite() { + suite.cleanup() } -func (s *testUnsafeAPISuite) TestRemoveFailedStores(c *C) { +func (suite *unsafeOperationTestSuite) TestRemoveFailedStores() { input := map[string]interface{}{"stores": []uint64{}} data, _ := json.Marshal(input) - err := tu.CheckPostJSON(testDialClient, s.urlPrefix+"/remove-failed-stores", data, tu.StatusNotOK(c), - tu.StringEqual(c, "\"[PD:unsaferecovery:ErrUnsafeRecoveryInvalidInput]invalid input no store specified\"\n")) - c.Assert(err, IsNil) + re := suite.Require() + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/remove-failed-stores", data, tu.StatusNotOK(re), + tu.StringEqual(re, "\"[PD:unsaferecovery:ErrUnsafeRecoveryInvalidInput]invalid input no store specified\"\n")) + suite.NoError(err) input = map[string]interface{}{"stores": []string{"abc", "def"}} data, _ = json.Marshal(input) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/remove-failed-stores", data, tu.StatusNotOK(c), - tu.StringEqual(c, "\"Store ids are invalid\"\n")) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/remove-failed-stores", data, tu.StatusNotOK(re), + tu.StringEqual(re, "\"Store ids are invalid\"\n")) + suite.NoError(err) input = map[string]interface{}{"stores": []uint64{1, 2}} data, _ = json.Marshal(input) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/remove-failed-stores", data, tu.StatusNotOK(c), - tu.StringEqual(c, "\"[PD:unsaferecovery:ErrUnsafeRecoveryInvalidInput]invalid input store 2 doesn't exist\"\n")) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/remove-failed-stores", data, tu.StatusNotOK(re), + tu.StringEqual(re, "\"[PD:unsaferecovery:ErrUnsafeRecoveryInvalidInput]invalid input store 2 doesn't exist\"\n")) + suite.NoError(err) input = map[string]interface{}{"stores": []uint64{1}} data, _ = json.Marshal(input) - err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/remove-failed-stores", data, tu.StatusOK(c)) - c.Assert(err, IsNil) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/remove-failed-stores", data, tu.StatusOK(re)) + suite.NoError(err) // Test show var output []cluster.StageOutput - err = 
tu.ReadGetJSON(c, testDialClient, s.urlPrefix+"/remove-failed-stores/show", &output) - c.Assert(err, IsNil) + err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/remove-failed-stores/show", &output) + suite.NoError(err) } diff --git a/server/api/version_test.go b/server/api/version_test.go index 7cbbab688cf..41254649c34 100644 --- a/server/api/version_test.go +++ b/server/api/version_test.go @@ -19,70 +19,67 @@ import ( "io" "os" "path/filepath" - "strings" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" ) -var _ = Suite(&testVersionSuite{}) - -func checkerWithNilAssert(c *C) *assertutil.Checker { +func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { checker := assertutil.NewChecker() checker.FailNow = func() { - c.FailNow() + re.FailNow("should be nil") } checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, IsNil) + re.Nil(obtained) } return checker } -type testVersionSuite struct{} - -func (s *testVersionSuite) TestGetVersion(c *C) { +func TestGetVersion(t *testing.T) { // TODO: enable it. - c.Skip("Temporary disable. See issue: https://github.com/tikv/pd/issues/1893") + t.Skip("Temporary disable. See issue: https://github.com/tikv/pd/issues/1893") + re := require.New(t) fname := filepath.Join(os.TempDir(), "stdout") old := os.Stdout temp, _ := os.Create(fname) os.Stdout = temp - cfg := server.NewTestSingleConfig(checkerWithNilAssert(c)) + cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) reqCh := make(chan struct{}) go func() { <-reqCh time.Sleep(200 * time.Millisecond) addr := cfg.ClientUrls + apiPrefix + "/api/v1/version" resp, err := testDialClient.Get(addr) - c.Assert(err, IsNil) + re.NoError(err) defer resp.Body.Close() _, err = io.ReadAll(resp.Body) - c.Assert(err, IsNil) + re.NoError(err) }() ctx, cancel := context.WithCancel(context.Background()) ch := make(chan *server.Server) go func(cfg *config.Config) { s, err := server.CreateServer(ctx, cfg, NewHandler) - c.Assert(err, IsNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/memberNil", `return(true)`), IsNil) + re.NoError(err) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/memberNil", `return(true)`)) reqCh <- struct{}{} err = s.Run() - c.Assert(err, IsNil) + re.NoError(err) ch <- s }(cfg) svr := <-ch close(ch) out, _ := os.ReadFile(fname) - c.Assert(strings.Contains(string(out), "PANIC"), IsFalse) + re.NotContains(string(out), "PANIC") // clean up func() { @@ -93,5 +90,5 @@ func (s *testVersionSuite) TestGetVersion(c *C) { cancel() testutil.CleanServer(cfg.DataDir) }() - c.Assert(failpoint.Disable("github.com/tikv/pd/server/memberNil"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/memberNil")) } diff --git a/server/schedule/placement/region_rule_cache_test.go b/server/schedule/placement/region_rule_cache_test.go index 7c578a11e34..a9224c49162 100644 --- a/server/schedule/placement/region_rule_cache_test.go +++ b/server/schedule/placement/region_rule_cache_test.go @@ -181,7 +181,7 @@ func TestRegionRuleFitCache(t *testing.T) { } for _, testCase := range testCases { t.Log(testCase.name) - re.Equal(false, cache.IsUnchanged(testCase.region, testCase.rules, mockStoresNoHeartbeat(3))) + re.False(cache.IsUnchanged(testCase.region, testCase.rules, mockStoresNoHeartbeat(3))) } // Invalid Input4 re.False(cache.IsUnchanged(mockRegion(3, 0), 
addExtraRules(0), nil)) diff --git a/tests/client/client_test.go b/tests/client/client_test.go index c6c4f45ee47..003b4f73c32 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -692,7 +692,7 @@ func (suite *clientTestSuite) SetupSuite() { re := suite.Require() suite.srv, suite.cleanup, err = server.NewTestServer(suite.checkerWithNilAssert()) suite.NoError(err) - suite.grpcPDClient = testutil.MustNewGrpcClientWithTestify(re, suite.srv.GetAddr()) + suite.grpcPDClient = testutil.MustNewGrpcClient(re, suite.srv.GetAddr()) suite.grpcSvr = &server.GrpcServer{Server: suite.srv} suite.mustWaitLeader(map[string]*server.Server{suite.srv.GetAddr(): suite.srv}) diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go index bbfa351cc32..9a6adf566a3 100644 --- a/tests/pdctl/helper.go +++ b/tests/pdctl/helper.go @@ -43,9 +43,9 @@ func ExecuteCommand(root *cobra.Command, args ...string) (output []byte, err err } // CheckStoresInfo is used to check the test results. -// CheckStoresInfo will not check Store.State because this field has been omitted pdctl output +// CheckStoresInfo will not check Store.State because this field has been omitted pd-ctl output func CheckStoresInfo(re *require.Assertions, stores []*api.StoreInfo, want []*api.StoreInfo) { - re.Equal(len(want), len(stores)) + re.Len(stores, len(want)) mapWant := make(map[uint64]*api.StoreInfo) for _, s := range want { if _, ok := mapWant[s.Store.Id]; !ok { @@ -70,7 +70,7 @@ func CheckStoresInfo(re *require.Assertions, stores []*api.StoreInfo, want []*ap // CheckRegionInfo is used to check the test results. func CheckRegionInfo(re *require.Assertions, output *api.RegionInfo, expected *core.RegionInfo) { - region := api.NewRegionInfo(expected) + region := api.NewAPIRegionInfo(expected) output.Adjust() re.Equal(region, output) } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 49f8026e2af..012c09b287e 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -507,7 +507,7 @@ func TestRemovingProgress(t *testing.T) { cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leader.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr()) clusterID := leader.GetClusterID() req := &pdpb.BootstrapRequest{ Header: testutil.NewRequestHeader(clusterID), @@ -624,7 +624,7 @@ func TestPreparingProgress(t *testing.T) { cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leader.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr()) clusterID := leader.GetClusterID() req := &pdpb.BootstrapRequest{ Header: testutil.NewRequestHeader(clusterID), diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 1bf2ecad485..d5f623712ae 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -68,7 +68,7 @@ func TestBootstrap(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() // IsBootstrapped returns false. 
@@ -108,7 +108,7 @@ func TestDamagedRegion(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() @@ -187,7 +187,7 @@ func TestGetPutConfig(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() @@ -407,7 +407,7 @@ func TestRaftClusterRestart(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -437,7 +437,7 @@ func TestRaftClusterMultipleRestart(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) // add an offline store @@ -480,7 +480,7 @@ func TestGetPDMembers(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.GetMembersRequest{Header: testutil.NewRequestHeader(clusterID)} resp, err := grpcPDClient.GetMembers(context.Background(), req) @@ -499,7 +499,7 @@ func TestNotLeader(t *testing.T) { re.NoError(tc.RunInitialServers()) tc.WaitLeader() followerServer := tc.GetServer(tc.GetFollower()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, followerServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, followerServer.GetAddr()) clusterID := followerServer.GetClusterID() req := &pdpb.AllocIDRequest{Header: testutil.NewRequestHeader(clusterID)} resp, err := grpcPDClient.AllocID(context.Background(), req) @@ -523,7 +523,7 @@ func TestStoreVersionChange(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) svr := leaderServer.GetServer() @@ -560,7 +560,7 @@ func TestConcurrentHandleRegion(t *testing.T) { re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} @@ -676,7 +676,7 @@ func TestSetScheduleOpt(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := 
testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -836,7 +836,7 @@ func TestTiFlashWithPlacementRules(t *testing.T) { re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -886,7 +886,7 @@ func TestReplicationModeStatus(t *testing.T) { re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := newBootstrapRequest(clusterID) res, err := grpcPDClient.Bootstrap(context.Background(), req) @@ -984,7 +984,7 @@ func TestOfflineStoreLimit(t *testing.T) { re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1"} @@ -1076,7 +1076,7 @@ func TestUpgradeStoreLimit(t *testing.T) { re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() @@ -1135,7 +1135,7 @@ func TestStaleTermHeartbeat(t *testing.T) { re.NoError(err) tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) storeAddrs := []string{"127.0.1.1:0", "127.0.1.1:1", "127.0.1.1:2"} @@ -1255,7 +1255,7 @@ func TestMinResolvedTS(t *testing.T) { tc.WaitLeader() leaderServer := tc.GetServer(tc.GetLeader()) id := leaderServer.GetAllocator() - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() diff --git a/tests/server/cluster/cluster_work_test.go b/tests/server/cluster/cluster_work_test.go index 5dee7da02cd..6b41a3a92c8 100644 --- a/tests/server/cluster/cluster_work_test.go +++ b/tests/server/cluster/cluster_work_test.go @@ -41,7 +41,7 @@ func TestValidRequestRegion(t *testing.T) { cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() @@ -85,7 +85,7 @@ func TestAskSplit(t *testing.T) { cluster.WaitLeader() leaderServer := 
cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() @@ -133,7 +133,7 @@ func TestSuspectRegions(t *testing.T) { cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) rc := leaderServer.GetRaftCluster() diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go index f375397a9f3..3f333c1722b 100644 --- a/tests/server/config/config_test.go +++ b/tests/server/config/config_test.go @@ -21,7 +21,7 @@ import ( "net/http" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" @@ -35,36 +35,18 @@ var dialClient = &http.Client{ }, } -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testConfigPresistSuite{}) - -type testConfigPresistSuite struct { - cleanup func() - cluster *tests.TestCluster -} - -func (s *testConfigPresistSuite) SetUpSuite(c *C) { +func TestRateLimitConfigReload(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) - s.cleanup = cancel + defer cancel() cluster, err := tests.NewTestCluster(ctx, 3) - c.Assert(err, IsNil) - c.Assert(cluster.RunInitialServers(), IsNil) - c.Assert(cluster.WaitLeader(), Not(HasLen), 0) - s.cluster = cluster -} - -func (s *testConfigPresistSuite) TearDownSuite(c *C) { - s.cleanup() - s.cluster.Destroy() -} - -func (s *testConfigPresistSuite) TestRateLimitConfigReload(c *C) { - leader := s.cluster.GetServer(s.cluster.GetLeader()) + re.NoError(err) + defer cluster.Destroy() + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) + leader := cluster.GetServer(cluster.GetLeader()) - c.Assert(leader.GetServer().GetServiceMiddlewareConfig().RateLimitConfig.LimiterConfig, HasLen, 0) + re.Len(leader.GetServer().GetServiceMiddlewareConfig().RateLimitConfig.LimiterConfig, 0) limitCfg := make(map[string]ratelimit.DimensionConfig) limitCfg["GetRegions"] = ratelimit.DimensionConfig{QPS: 1} @@ -73,29 +55,29 @@ func (s *testConfigPresistSuite) TestRateLimitConfigReload(c *C) { "limiter-config": limitCfg, } data, err := json.Marshal(input) - c.Assert(err, IsNil) + re.NoError(err) req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err := dialClient.Do(req) - c.Assert(err, IsNil) + re.NoError(err) resp.Body.Close() - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), Equals, true) - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, HasLen, 1) + re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) + re.Len(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, 1) oldLeaderName := leader.GetServer().Name() leader.GetServer().GetMember().ResignEtcdLeader(leader.GetServer().Context(), oldLeaderName, "") - mustWaitLeader(c, s.cluster.GetServers()) - leader = s.cluster.GetServer(s.cluster.GetLeader()) + 
mustWaitLeader(re, cluster.GetServers()) + leader = cluster.GetServer(cluster.GetLeader()) - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), Equals, true) - c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, HasLen, 1) + re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) + re.Len(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, 1) } -func mustWaitLeader(c *C, svrs map[string]*tests.TestServer) *server.Server { +func mustWaitLeader(re *require.Assertions, svrs map[string]*tests.TestServer) *server.Server { var leader *server.Server - testutil.WaitUntil(c, func() bool { - for _, s := range svrs { - if !s.GetServer().IsClosed() && s.GetServer().GetMember().IsLeader() { - leader = s.GetServer() + testutil.Eventually(re, func() bool { + for _, svr := range svrs { + if !svr.GetServer().IsClosed() && svr.GetServer().GetMember().IsLeader() { + leader = svr.GetServer() return true } } diff --git a/tests/server/id/id_test.go b/tests/server/id/id_test.go index d9279cd3616..3375cc55adb 100644 --- a/tests/server/id/id_test.go +++ b/tests/server/id/id_test.go @@ -93,7 +93,7 @@ func TestCommand(t *testing.T) { leaderServer := cluster.GetServer(cluster.GetLeader()) req := &pdpb.AllocIDRequest{Header: testutil.NewRequestHeader(leaderServer.GetClusterID())} - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) var last uint64 for i := uint64(0); i < 2*allocStep; i++ { resp, err := grpcPDClient.AllocID(context.Background(), req) diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 1864500df74..552a2d0c221 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -348,7 +348,7 @@ func sendRequest(re *require.Assertions, wg *sync.WaitGroup, done <-chan bool, a default: // We don't need to check the response and error, // just make sure the server will not panic. 
- grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, addr) + grpcPDClient := testutil.MustNewGrpcClient(re, addr) if grpcPDClient != nil { _, _ = grpcPDClient.AllocID(context.Background(), req) } diff --git a/tests/server/tso/consistency_test.go b/tests/server/tso/consistency_test.go index 430160cd5ac..6e1fe54c7c8 100644 --- a/tests/server/tso/consistency_test.go +++ b/tests/server/tso/consistency_test.go @@ -70,7 +70,7 @@ func (suite *tsoConsistencyTestSuite) TestNormalGlobalTSO() { cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(suite.Require(), leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(suite.Require(), leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(clusterID), @@ -143,10 +143,10 @@ func (suite *tsoConsistencyTestSuite) TestSynchronizedGlobalTSO() { suite.leaderServer = cluster.GetServer(cluster.GetLeader()) suite.NotNil(suite.leaderServer) - suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClientWithTestify(re, suite.leaderServer.GetAddr()) + suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { pdName := suite.leaderServer.GetAllocatorLeader(dcLocation).GetName() - suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) + suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClient(re, cluster.GetServer(pdName).GetAddr()) } ctx, cancel := context.WithCancel(context.Background()) @@ -218,10 +218,10 @@ func (suite *tsoConsistencyTestSuite) TestSynchronizedGlobalTSOOverflow() { suite.leaderServer = cluster.GetServer(cluster.GetLeader()) suite.NotNil(suite.leaderServer) - suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClientWithTestify(re, suite.leaderServer.GetAddr()) + suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { pdName := suite.leaderServer.GetAllocatorLeader(dcLocation).GetName() - suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) + suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClient(re, cluster.GetServer(pdName).GetAddr()) } ctx, cancel := context.WithCancel(context.Background()) @@ -250,10 +250,10 @@ func (suite *tsoConsistencyTestSuite) TestLocalAllocatorLeaderChange() { suite.leaderServer = cluster.GetServer(cluster.GetLeader()) suite.NotNil(suite.leaderServer) - suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClientWithTestify(re, suite.leaderServer.GetAddr()) + suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { pdName := suite.leaderServer.GetAllocatorLeader(dcLocation).GetName() - suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) + suite.dcClientMap[dcLocation] = testutil.MustNewGrpcClient(re, cluster.GetServer(pdName).GetAddr()) } ctx, cancel := context.WithCancel(context.Background()) @@ -311,7 +311,7 @@ func (suite *tsoConsistencyTestSuite) TestLocalTSOAfterMemberChanged() { cluster.WaitAllLeaders(re, dcLocationConfig) leaderServer := cluster.GetServer(cluster.GetLeader()) - leaderCli := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + leaderCli := 
testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(cluster.GetCluster().GetId()), Count: tsoCount, @@ -352,7 +352,7 @@ func (suite *tsoConsistencyTestSuite) testTSO(cluster *tests.TestCluster, dcLoca dcClientMap := make(map[string]pdpb.PDClient) for _, dcLocation := range dcLocationConfig { pdName := leaderServer.GetAllocatorLeader(dcLocation).GetName() - dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) + dcClientMap[dcLocation] = testutil.MustNewGrpcClient(re, cluster.GetServer(pdName).GetAddr()) } var wg sync.WaitGroup @@ -412,7 +412,7 @@ func TestFallbackTSOConsistency(t *testing.T) { cluster.WaitLeader() server := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, server.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, server.GetAddr()) svr := server.GetServer() svr.Close() re.NoError(failpoint.Disable("github.com/tikv/pd/server/tso/fallBackSync")) diff --git a/tests/server/tso/global_tso_test.go b/tests/server/tso/global_tso_test.go index 795841b6830..d89b341d7c1 100644 --- a/tests/server/tso/global_tso_test.go +++ b/tests/server/tso/global_tso_test.go @@ -82,7 +82,7 @@ func TestZeroTSOCount(t *testing.T) { cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, leaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.TsoRequest{ @@ -116,7 +116,7 @@ func TestRequestFollower(t *testing.T) { } re.NotNil(followerServer) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, followerServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, followerServer.GetAddr()) clusterID := followerServer.GetClusterID() req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(clusterID), @@ -161,7 +161,7 @@ func TestDelaySyncTimestamp(t *testing.T) { } re.NotNil(nextLeaderServer) - grpcPDClient := testutil.MustNewGrpcClientWithTestify(re, nextLeaderServer.GetAddr()) + grpcPDClient := testutil.MustNewGrpcClient(re, nextLeaderServer.GetAddr()) clusterID := nextLeaderServer.GetClusterID() req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(clusterID), diff --git a/tests/server/tso/tso_test.go b/tests/server/tso/tso_test.go index 8cb6b6d837f..1ac0a285940 100644 --- a/tests/server/tso/tso_test.go +++ b/tests/server/tso/tso_test.go @@ -78,7 +78,7 @@ func requestLocalTSOs(re *require.Assertions, cluster *tests.TestCluster, dcLoca leaderServer := cluster.GetServer(cluster.GetLeader()) for _, dcLocation := range dcLocationConfig { pdName := leaderServer.GetAllocatorLeader(dcLocation).GetName() - dcClientMap[dcLocation] = testutil.MustNewGrpcClientWithTestify(re, cluster.GetServer(pdName).GetAddr()) + dcClientMap[dcLocation] = testutil.MustNewGrpcClient(re, cluster.GetServer(pdName).GetAddr()) } for _, dcLocation := range dcLocationConfig { req := &pdpb.TsoRequest{ From 2aa6049c691be21c67768ad9a74859891a6ac8b3 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 23 Jun 2022 10:50:37 +0800 Subject: [PATCH 72/82] scripts, tests: update the check-test.sh to detect more inefficient assert functions (#5219) ref tikv/pd#4813 Update the check-test.sh to detect more inefficient assert functions. 
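(Illustration only, not part of the patch.) The "inefficient assert functions" the updated grep rule targets are boolean results passed through Equal/NotEqual; the script asks for the dedicated testify helpers instead. A minimal sketch of the before/after, assuming a testify-style test with a hypothetical package example_test, test name TestBooleanAssertStyle, and JSON string variable output:

package example_test

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestBooleanAssertStyle(t *testing.T) {
	re := require.New(t)
	output := "{\"count\": 1}"

	// Flagged by the new check-test.sh rule: a boolean result wrapped in Equal.
	//   re.Equal(false, strings.Contains(output, "\"state\":"))
	// Preferred forms, which also produce clearer failure messages:
	re.NotContains(output, "\"state\":")
	re.True(strings.Contains(output, "\"count\""))
}

The store_test.go hunk below applies exactly this rewrite, replacing re.Equal(false, strings.Contains(...)) with re.NotContains(...).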
Signed-off-by: JmPotato --- scripts/check-test.sh | 8 ++++++++ tests/pdctl/store/store_test.go | 3 +-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/scripts/check-test.sh b/scripts/check-test.sh index f65d506565f..867d234cce6 100755 --- a/scripts/check-test.sh +++ b/scripts/check-test.sh @@ -49,4 +49,12 @@ if [ "$res" ]; then exit 1 fi +res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(Equal|NotEqual)\((t, )?(true|false)" . | sort -u) + +if [ "$res" ]; then + echo "following packages use the inefficient assert function: please replace require.Equal/NotEqual(true, xxx) with require.True/False" + echo "$res" + exit 1 +fi + exit 0 diff --git a/tests/pdctl/store/store_test.go b/tests/pdctl/store/store_test.go index c2c9420d01a..583c432a541 100644 --- a/tests/pdctl/store/store_test.go +++ b/tests/pdctl/store/store_test.go @@ -17,7 +17,6 @@ package store_test import ( "context" "encoding/json" - "strings" "testing" "time" @@ -101,7 +100,7 @@ func TestStore(t *testing.T) { args = []string{"-u", pdAddr, "store", "--state", "Up,Tombstone"} output, err = pdctl.ExecuteCommand(cmd, args...) re.NoError(err) - re.Equal(false, strings.Contains(string(output), "\"state\":")) + re.NotContains(string(output), "\"state\":") storesInfo = new(api.StoresInfo) re.NoError(json.Unmarshal(output, &storesInfo)) From bbe48e480de06cb2673669fa8df1726b16252f8e Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Thu, 23 Jun 2022 12:28:36 +0800 Subject: [PATCH 73/82] server: migrate test framework to testify (#5198) ref tikv/pd#4813, ref tikv/pd#5193 Signed-off-by: lhy1024 Co-authored-by: Ti Chi Robot --- server/server_test.go | 184 ++++++++++++++++++++---------------------- 1 file changed, 88 insertions(+), 96 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 2a31cfb3b1c..58f572fb2df 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -21,7 +21,7 @@ import ( "net/http" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/etcdutil" @@ -32,17 +32,26 @@ import ( "go.uber.org/goleak" ) -func TestServer(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) 
} -func mustWaitLeader(c *C, svrs []*Server) *Server { +type leaderServerTestSuite struct { + suite.Suite + + ctx context.Context + cancel context.CancelFunc + svrs map[string]*Server + leaderPath string +} + +func TestLeaderServerTestSuite(t *testing.T) { + suite.Run(t, new(leaderServerTestSuite)) +} + +func (suite *leaderServerTestSuite) mustWaitLeader(svrs []*Server) *Server { var leader *Server - testutil.WaitUntil(c, func() bool { + testutil.Eventually(suite.Require(), func() bool { for _, s := range svrs { if !s.IsClosed() && s.member.IsLeader() { leader = s @@ -54,65 +63,52 @@ func mustWaitLeader(c *C, svrs []*Server) *Server { return leader } -func checkerWithNilAssert(c *C) *assertutil.Checker { +func (suite *leaderServerTestSuite) checkerWithNilAssert() *assertutil.Checker { checker := assertutil.NewChecker() checker.FailNow = func() { - c.FailNow() + suite.FailNow("should be nil") } checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, IsNil) + suite.Nil(obtained) } return checker } -var _ = Suite(&testLeaderServerSuite{}) +func (suite *leaderServerTestSuite) SetupSuite() { + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.svrs = make(map[string]*Server) -type testLeaderServerSuite struct { - ctx context.Context - cancel context.CancelFunc - svrs map[string]*Server - leaderPath string -} - -func (s *testLeaderServerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - s.svrs = make(map[string]*Server) - - cfgs := NewTestMultiConfig(checkerWithNilAssert(c), 3) + cfgs := NewTestMultiConfig(suite.checkerWithNilAssert(), 3) ch := make(chan *Server, 3) for i := 0; i < 3; i++ { cfg := cfgs[i] go func() { - svr, err := CreateServer(s.ctx, cfg) - c.Assert(err, IsNil) + svr, err := CreateServer(suite.ctx, cfg) + suite.NoError(err) err = svr.Run() - c.Assert(err, IsNil) + suite.NoError(err) ch <- svr }() } for i := 0; i < 3; i++ { svr := <-ch - s.svrs[svr.GetAddr()] = svr - s.leaderPath = svr.GetMember().GetLeaderPath() + suite.svrs[svr.GetAddr()] = svr + suite.leaderPath = svr.GetMember().GetLeaderPath() } } -func (s *testLeaderServerSuite) TearDownSuite(c *C) { - s.cancel() - for _, svr := range s.svrs { +func (suite *leaderServerTestSuite) TearDownSuite() { + suite.cancel() + for _, svr := range suite.svrs { svr.Close() testutil.CleanServer(svr.cfg.DataDir) } } -var _ = Suite(&testServerSuite{}) - -type testServerSuite struct{} - -func newTestServersWithCfgs(ctx context.Context, c *C, cfgs []*config.Config) ([]*Server, CleanupFunc) { +func (suite *leaderServerTestSuite) newTestServersWithCfgs(ctx context.Context, cfgs []*config.Config) ([]*Server, CleanupFunc) { svrs := make([]*Server, 0, len(cfgs)) ch := make(chan *Server) @@ -128,19 +124,19 @@ func newTestServersWithCfgs(ctx context.Context, c *C, cfgs []*config.Config) ([ ch <- svr } }() - c.Assert(err, IsNil) + suite.NoError(err) err = svr.Run() - c.Assert(err, IsNil) + suite.NoError(err) failed = false }(cfg) } for i := 0; i < len(cfgs); i++ { svr := <-ch - c.Assert(svr, NotNil) + suite.NotNil(svr) svrs = append(svrs, svr) } - mustWaitLeader(c, svrs) + suite.mustWaitLeader(svrs) cleanup := func() { for _, svr := range svrs { @@ -154,10 +150,10 @@ func newTestServersWithCfgs(ctx context.Context, c *C, cfgs []*config.Config) ([ return svrs, cleanup } -func (s *testServerSuite) TestCheckClusterID(c *C) { +func (suite *leaderServerTestSuite) TestCheckClusterID() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cfgs := 
NewTestMultiConfig(checkerWithNilAssert(c), 2) + cfgs := NewTestMultiConfig(suite.checkerWithNilAssert(), 2) for i, cfg := range cfgs { cfg.DataDir = fmt.Sprintf("/tmp/test_pd_check_clusterID_%d", i) // Clean up before testing. @@ -170,7 +166,7 @@ func (s *testServerSuite) TestCheckClusterID(c *C) { cfgA, cfgB := cfgs[0], cfgs[1] // Start a standalone cluster. - svrsA, cleanA := newTestServersWithCfgs(ctx, c, []*config.Config{cfgA}) + svrsA, cleanA := suite.newTestServersWithCfgs(ctx, []*config.Config{cfgA}) defer cleanA() // Close it. for _, svr := range svrsA { @@ -178,38 +174,34 @@ func (s *testServerSuite) TestCheckClusterID(c *C) { } // Start another cluster. - _, cleanB := newTestServersWithCfgs(ctx, c, []*config.Config{cfgB}) + _, cleanB := suite.newTestServersWithCfgs(ctx, []*config.Config{cfgB}) defer cleanB() // Start previous cluster, expect an error. cfgA.InitialCluster = originInitial svr, err := CreateServer(ctx, cfgA) - c.Assert(err, IsNil) + suite.NoError(err) etcd, err := embed.StartEtcd(svr.etcdCfg) - c.Assert(err, IsNil) + suite.NoError(err) urlsMap, err := types.NewURLsMap(svr.cfg.InitialCluster) - c.Assert(err, IsNil) + suite.NoError(err) tlsConfig, err := svr.cfg.Security.ToTLSConfig() - c.Assert(err, IsNil) + suite.NoError(err) err = etcdutil.CheckClusterID(etcd.Server.Cluster().ID(), urlsMap, tlsConfig) - c.Assert(err, NotNil) + suite.Error(err) etcd.Close() testutil.CleanServer(cfgA.DataDir) } -var _ = Suite(&testServerHandlerSuite{}) - -type testServerHandlerSuite struct{} - -func (s *testServerHandlerSuite) TestRegisterServerHandler(c *C) { +func (suite *leaderServerTestSuite) TestRegisterServerHandler() { mokHandler := func(ctx context.Context, s *Server) (http.Handler, ServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/apis/mok/v1/hello", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") // test getting ip clientIP := apiutil.GetIPAddrFromHTTPRequest(r) - c.Assert(clientIP, Equals, "127.0.0.1") + suite.Equal("127.0.0.1", clientIP) }) info := ServiceGroup{ Name: "mok", @@ -217,38 +209,38 @@ func (s *testServerHandlerSuite) TestRegisterServerHandler(c *C) { } return mux, info, nil } - cfg := NewTestSingleConfig(checkerWithNilAssert(c)) + cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) - c.Assert(err, IsNil) + suite.NoError(err) _, err = CreateServer(ctx, cfg, mokHandler, mokHandler) // Repeat register. 
- c.Assert(err, NotNil) + suite.Error(err) defer func() { cancel() svr.Close() testutil.CleanServer(svr.cfg.DataDir) }() err = svr.Run() - c.Assert(err, IsNil) + suite.NoError(err) resp, err := http.Get(fmt.Sprintf("%s/pd/apis/mok/v1/hello", svr.GetAddr())) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) defer resp.Body.Close() bodyBytes, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) bodyString := string(bodyBytes) - c.Assert(bodyString, Equals, "Hello World\n") + suite.Equal("Hello World\n", bodyString) } -func (s *testServerHandlerSuite) TestSourceIpForHeaderForwarded(c *C) { +func (suite *leaderServerTestSuite) TestSourceIpForHeaderForwarded() { mokHandler := func(ctx context.Context, s *Server) (http.Handler, ServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/apis/mok/v1/hello", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") // test getting ip clientIP := apiutil.GetIPAddrFromHTTPRequest(r) - c.Assert(clientIP, Equals, "127.0.0.2") + suite.Equal("127.0.0.2", clientIP) }) info := ServiceGroup{ Name: "mok", @@ -256,42 +248,42 @@ func (s *testServerHandlerSuite) TestSourceIpForHeaderForwarded(c *C) { } return mux, info, nil } - cfg := NewTestSingleConfig(checkerWithNilAssert(c)) + cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) - c.Assert(err, IsNil) + suite.NoError(err) _, err = CreateServer(ctx, cfg, mokHandler, mokHandler) // Repeat register. - c.Assert(err, NotNil) + suite.Error(err) defer func() { cancel() svr.Close() testutil.CleanServer(svr.cfg.DataDir) }() err = svr.Run() - c.Assert(err, IsNil) + suite.NoError(err) req, err := http.NewRequest("GET", fmt.Sprintf("%s/pd/apis/mok/v1/hello", svr.GetAddr()), nil) - c.Assert(err, IsNil) + suite.NoError(err) req.Header.Add("X-Forwarded-For", "127.0.0.2") resp, err := http.DefaultClient.Do(req) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) defer resp.Body.Close() bodyBytes, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) bodyString := string(bodyBytes) - c.Assert(bodyString, Equals, "Hello World\n") + suite.Equal("Hello World\n", bodyString) } -func (s *testServerHandlerSuite) TestSourceIpForHeaderXReal(c *C) { +func (suite *leaderServerTestSuite) TestSourceIpForHeaderXReal() { mokHandler := func(ctx context.Context, s *Server) (http.Handler, ServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/apis/mok/v1/hello", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") // test getting ip clientIP := apiutil.GetIPAddrFromHTTPRequest(r) - c.Assert(clientIP, Equals, "127.0.0.2") + suite.Equal("127.0.0.2", clientIP) }) info := ServiceGroup{ Name: "mok", @@ -299,42 +291,42 @@ func (s *testServerHandlerSuite) TestSourceIpForHeaderXReal(c *C) { } return mux, info, nil } - cfg := NewTestSingleConfig(checkerWithNilAssert(c)) + cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) - c.Assert(err, IsNil) + suite.NoError(err) _, err = CreateServer(ctx, cfg, mokHandler, mokHandler) // Repeat register. 
- c.Assert(err, NotNil) + suite.Error(err) defer func() { cancel() svr.Close() testutil.CleanServer(svr.cfg.DataDir) }() err = svr.Run() - c.Assert(err, IsNil) + suite.NoError(err) req, err := http.NewRequest("GET", fmt.Sprintf("%s/pd/apis/mok/v1/hello", svr.GetAddr()), nil) - c.Assert(err, IsNil) + suite.NoError(err) req.Header.Add("X-Real-Ip", "127.0.0.2") resp, err := http.DefaultClient.Do(req) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) defer resp.Body.Close() bodyBytes, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) bodyString := string(bodyBytes) - c.Assert(bodyString, Equals, "Hello World\n") + suite.Equal("Hello World\n", bodyString) } -func (s *testServerHandlerSuite) TestSourceIpForHeaderBoth(c *C) { +func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() { mokHandler := func(ctx context.Context, s *Server) (http.Handler, ServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/apis/mok/v1/hello", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") // test getting ip clientIP := apiutil.GetIPAddrFromHTTPRequest(r) - c.Assert(clientIP, Equals, "127.0.0.2") + suite.Equal("127.0.0.2", clientIP) }) info := ServiceGroup{ Name: "mok", @@ -342,31 +334,31 @@ func (s *testServerHandlerSuite) TestSourceIpForHeaderBoth(c *C) { } return mux, info, nil } - cfg := NewTestSingleConfig(checkerWithNilAssert(c)) + cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) - c.Assert(err, IsNil) + suite.NoError(err) _, err = CreateServer(ctx, cfg, mokHandler, mokHandler) // Repeat register. - c.Assert(err, NotNil) + suite.Error(err) defer func() { cancel() svr.Close() testutil.CleanServer(svr.cfg.DataDir) }() err = svr.Run() - c.Assert(err, IsNil) + suite.NoError(err) req, err := http.NewRequest("GET", fmt.Sprintf("%s/pd/apis/mok/v1/hello", svr.GetAddr()), nil) - c.Assert(err, IsNil) + suite.NoError(err) req.Header.Add("X-Forwarded-For", "127.0.0.2") req.Header.Add("X-Real-Ip", "127.0.0.3") resp, err := http.DefaultClient.Do(req) - c.Assert(err, IsNil) - c.Assert(resp.StatusCode, Equals, http.StatusOK) + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) defer resp.Body.Close() bodyBytes, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) + suite.NoError(err) bodyString := string(bodyBytes) - c.Assert(bodyString, Equals, "Hello World\n") + suite.Equal("Hello World\n", bodyString) } From e4ca2e699a69ec0e8158871003a1877a0ca2c8e5 Mon Sep 17 00:00:00 2001 From: LLThomas Date: Thu, 23 Jun 2022 19:20:37 +0800 Subject: [PATCH 74/82] server/cluster: migrate test framework to testify (#5203) ref tikv/pd#4813 As the title says. Signed-off-by: LLThomas --- server/cluster/cluster_stat_test.go | 62 +- server/cluster/cluster_test.go | 1087 +++++++++-------- server/cluster/cluster_worker_test.go | 42 +- server/cluster/coordinator_test.go | 938 +++++++------- server/cluster/store_limiter_test.go | 37 +- .../unsafe_recovery_controller_test.go | 558 +++++---- 6 files changed, 1409 insertions(+), 1315 deletions(-) diff --git a/server/cluster/cluster_stat_test.go b/server/cluster/cluster_stat_test.go index 0b2a094b5b3..e5352b7ac0a 100644 --- a/server/cluster/cluster_stat_test.go +++ b/server/cluster/cluster_stat_test.go @@ -16,16 +16,12 @@ package cluster import ( "fmt" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testClusterStatSuite{}) - -type testClusterStatSuite struct { -} - func cpu(usage int64) []*pdpb.RecordPair { n := 10 name := "cpu" @@ -39,19 +35,20 @@ func cpu(usage int64) []*pdpb.RecordPair { return pairs } -func (s *testClusterStatSuite) TestCPUEntriesAppend(c *C) { +func TestCPUEntriesAppend(t *testing.T) { + re := require.New(t) N := 10 checkAppend := func(appended bool, usage int64, threads ...string) { entries := NewCPUEntries(N) - c.Assert(entries, NotNil) + re.NotNil(entries) for i := 0; i < N; i++ { entry := &StatEntry{ CpuUsages: cpu(usage), } - c.Assert(entries.Append(entry, threads...), Equals, appended) + re.Equal(appended, entries.Append(entry, threads...)) } - c.Assert(entries.cpu.Get(), Equals, float64(usage)) + re.Equal(float64(usage), entries.cpu.Get()) } checkAppend(true, 20) @@ -59,10 +56,11 @@ func (s *testClusterStatSuite) TestCPUEntriesAppend(c *C) { checkAppend(false, 0, "cup") } -func (s *testClusterStatSuite) TestCPUEntriesCPU(c *C) { +func TestCPUEntriesCPU(t *testing.T) { + re := require.New(t) N := 10 entries := NewCPUEntries(N) - c.Assert(entries, NotNil) + re.NotNil(entries) usages := cpu(20) for i := 0; i < N; i++ { @@ -71,13 +69,14 @@ func (s *testClusterStatSuite) TestCPUEntriesCPU(c *C) { } entries.Append(entry) } - c.Assert(entries.CPU(), Equals, float64(20)) + re.Equal(float64(20), entries.CPU()) } -func (s *testClusterStatSuite) TestStatEntriesAppend(c *C) { +func TestStatEntriesAppend(t *testing.T) { + re := require.New(t) N := 10 cst := NewStatEntries(N) - c.Assert(cst, NotNil) + re.NotNil(cst) ThreadsCollected = []string{"cpu:"} // fill 2*N entries, 2 entries for each store @@ -86,19 +85,20 @@ func (s *testClusterStatSuite) TestStatEntriesAppend(c *C) { StoreId: uint64(i % N), CpuUsages: cpu(20), } - c.Assert(cst.Append(entry), IsTrue) + re.True(cst.Append(entry)) } // use i as the store ID for i := 0; i < N; i++ { - c.Assert(cst.stats[uint64(i)].CPU(), Equals, float64(20)) + re.Equal(float64(20), cst.stats[uint64(i)].CPU()) } } -func (s *testClusterStatSuite) TestStatEntriesCPU(c *C) { +func TestStatEntriesCPU(t *testing.T) { + re := require.New(t) N := 10 cst := NewStatEntries(N) - c.Assert(cst, NotNil) + re.NotNil(cst) // the average cpu usage is 20% usages := cpu(20) @@ -110,14 +110,15 @@ func (s *testClusterStatSuite) TestStatEntriesCPU(c *C) { StoreId: uint64(i % N), CpuUsages: usages, } - c.Assert(cst.Append(entry), IsTrue) + re.True(cst.Append(entry)) } - c.Assert(cst.total, Equals, int64(2*N)) + re.Equal(int64(2*N), cst.total) // the cpu usage of the whole cluster is 20% - c.Assert(cst.CPU(), Equals, float64(20)) + re.Equal(float64(20), cst.CPU()) } -func (s *testClusterStatSuite) TestStatEntriesCPUStale(c *C) { +func TestStatEntriesCPUStale(t *testing.T) { + re := require.New(t) N := 10 cst := NewStatEntries(N) // make all entries stale immediately @@ -132,13 +133,14 @@ func (s *testClusterStatSuite) TestStatEntriesCPUStale(c *C) { } cst.Append(entry) } - c.Assert(cst.CPU(), Equals, float64(0)) + re.Equal(float64(0), cst.CPU()) } -func (s *testClusterStatSuite) TestStatEntriesState(c *C) { +func TestStatEntriesState(t *testing.T) { + re := require.New(t) Load := func(usage int64) *State { cst := NewStatEntries(10) - c.Assert(cst, NotNil) + re.NotNil(cst) usages := cpu(usage) ThreadsCollected = []string{"cpu:"} @@ -152,8 +154,8 @@ func (s *testClusterStatSuite) TestStatEntriesState(c *C) { } return &State{cst} } - 
c.Assert(Load(0).State(), Equals, LoadStateIdle) - c.Assert(Load(5).State(), Equals, LoadStateLow) - c.Assert(Load(10).State(), Equals, LoadStateNormal) - c.Assert(Load(30).State(), Equals, LoadStateHigh) + re.Equal(LoadStateIdle, Load(0).State()) + re.Equal(LoadStateLow, Load(5).State()) + re.Equal(LoadStateNormal, Load(10).State()) + re.Equal(LoadStateHigh, Load(30).State()) } diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 530abff2b87..1d2120744e1 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -19,16 +19,15 @@ import ( "fmt" "math" "math/rand" - "strings" "sync" "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/pkg/progress" @@ -44,29 +43,14 @@ import ( "github.com/tikv/pd/server/versioninfo" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testClusterInfoSuite{}) - -type testClusterInfoSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testClusterInfoSuite) TearDownTest(c *C) { - s.cancel() -} +func TestStoreHeartbeat(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func (s *testClusterInfoSuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testClusterInfoSuite) TestStoreHeartbeat(c *C) { _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) n, np := uint64(3), uint64(3) stores := newTestStores(n, "2.0.0") @@ -74,9 +58,9 @@ func (s *testClusterInfoSuite) TestStoreHeartbeat(c *C) { regions := newTestRegions(n, n, np) for _, region := range regions { - c.Assert(cluster.putRegion(region), IsNil) + re.NoError(cluster.putRegion(region)) } - c.Assert(cluster.core.Regions.GetRegionCount(), Equals, int(n)) + re.Equal(int(n), cluster.core.Regions.GetRegionCount()) for i, store := range stores { storeStats := &pdpb.StoreStats{ @@ -85,30 +69,30 @@ func (s *testClusterInfoSuite) TestStoreHeartbeat(c *C) { Available: 50, RegionCount: 1, } - c.Assert(cluster.HandleStoreHeartbeat(storeStats), NotNil) + re.Error(cluster.HandleStoreHeartbeat(storeStats)) - c.Assert(cluster.putStoreLocked(store), IsNil) - c.Assert(cluster.GetStoreCount(), Equals, i+1) + re.NoError(cluster.putStoreLocked(store)) + re.Equal(i+1, cluster.GetStoreCount()) - c.Assert(store.GetLastHeartbeatTS().UnixNano(), Equals, int64(0)) + re.Equal(int64(0), store.GetLastHeartbeatTS().UnixNano()) - c.Assert(cluster.HandleStoreHeartbeat(storeStats), IsNil) + re.NoError(cluster.HandleStoreHeartbeat(storeStats)) s := cluster.GetStore(store.GetID()) - c.Assert(s.GetLastHeartbeatTS().UnixNano(), Not(Equals), int64(0)) - c.Assert(s.GetStoreStats(), DeepEquals, storeStats) + re.NotEqual(int64(0), s.GetLastHeartbeatTS().UnixNano()) + re.Equal(storeStats, s.GetStoreStats()) storeMetasAfterHeartbeat = append(storeMetasAfterHeartbeat, s.GetMeta()) } - c.Assert(cluster.GetStoreCount(), Equals, int(n)) + re.Equal(int(n), cluster.GetStoreCount()) for i, store := range stores { tmp := 
&metapb.Store{} ok, err := cluster.storage.LoadStore(store.GetID(), tmp) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(tmp, DeepEquals, storeMetasAfterHeartbeat[i]) + re.True(ok) + re.NoError(err) + re.Equal(storeMetasAfterHeartbeat[i], tmp) } hotHeartBeat := &pdpb.StoreStats{ StoreId: 1, @@ -137,56 +121,60 @@ func (s *testClusterInfoSuite) TestStoreHeartbeat(c *C) { }, PeerStats: []*pdpb.PeerStat{}, } - c.Assert(cluster.HandleStoreHeartbeat(hotHeartBeat), IsNil) - c.Assert(cluster.HandleStoreHeartbeat(hotHeartBeat), IsNil) - c.Assert(cluster.HandleStoreHeartbeat(hotHeartBeat), IsNil) + re.NoError(cluster.HandleStoreHeartbeat(hotHeartBeat)) + re.NoError(cluster.HandleStoreHeartbeat(hotHeartBeat)) + re.NoError(cluster.HandleStoreHeartbeat(hotHeartBeat)) time.Sleep(20 * time.Millisecond) storeStats := cluster.hotStat.RegionStats(statistics.Read, 3) - c.Assert(storeStats[1], HasLen, 1) - c.Assert(storeStats[1][0].RegionID, Equals, uint64(1)) + re.Len(storeStats[1], 1) + re.Equal(uint64(1), storeStats[1][0].RegionID) interval := float64(hotHeartBeat.Interval.EndTimestamp - hotHeartBeat.Interval.StartTimestamp) - c.Assert(storeStats[1][0].Loads, HasLen, int(statistics.RegionStatCount)) - c.Assert(storeStats[1][0].Loads[statistics.RegionReadBytes], Equals, float64(hotHeartBeat.PeerStats[0].ReadBytes)/interval) - c.Assert(storeStats[1][0].Loads[statistics.RegionReadKeys], Equals, float64(hotHeartBeat.PeerStats[0].ReadKeys)/interval) - c.Assert(storeStats[1][0].Loads[statistics.RegionReadQuery], Equals, float64(hotHeartBeat.PeerStats[0].QueryStats.Get)/interval) + re.Len(storeStats[1][0].Loads, int(statistics.RegionStatCount)) + re.Equal(float64(hotHeartBeat.PeerStats[0].ReadBytes)/interval, storeStats[1][0].Loads[statistics.RegionReadBytes]) + re.Equal(float64(hotHeartBeat.PeerStats[0].ReadKeys)/interval, storeStats[1][0].Loads[statistics.RegionReadKeys]) + re.Equal(float64(hotHeartBeat.PeerStats[0].QueryStats.Get)/interval, storeStats[1][0].Loads[statistics.RegionReadQuery]) // After cold heartbeat, we won't find region 1 peer in regionStats - c.Assert(cluster.HandleStoreHeartbeat(coldHeartBeat), IsNil) + re.NoError(cluster.HandleStoreHeartbeat(coldHeartBeat)) time.Sleep(20 * time.Millisecond) storeStats = cluster.hotStat.RegionStats(statistics.Read, 1) - c.Assert(storeStats[1], HasLen, 0) + re.Len(storeStats[1], 0) // After hot heartbeat, we can find region 1 peer again - c.Assert(cluster.HandleStoreHeartbeat(hotHeartBeat), IsNil) + re.NoError(cluster.HandleStoreHeartbeat(hotHeartBeat)) time.Sleep(20 * time.Millisecond) storeStats = cluster.hotStat.RegionStats(statistics.Read, 3) - c.Assert(storeStats[1], HasLen, 1) - c.Assert(storeStats[1][0].RegionID, Equals, uint64(1)) + re.Len(storeStats[1], 1) + re.Equal(uint64(1), storeStats[1][0].RegionID) // after several cold heartbeats, and one hot heartbeat, we also can't find region 1 peer - c.Assert(cluster.HandleStoreHeartbeat(coldHeartBeat), IsNil) - c.Assert(cluster.HandleStoreHeartbeat(coldHeartBeat), IsNil) - c.Assert(cluster.HandleStoreHeartbeat(coldHeartBeat), IsNil) + re.NoError(cluster.HandleStoreHeartbeat(coldHeartBeat)) + re.NoError(cluster.HandleStoreHeartbeat(coldHeartBeat)) + re.NoError(cluster.HandleStoreHeartbeat(coldHeartBeat)) time.Sleep(20 * time.Millisecond) storeStats = cluster.hotStat.RegionStats(statistics.Read, 0) - c.Assert(storeStats[1], HasLen, 0) - c.Assert(cluster.HandleStoreHeartbeat(hotHeartBeat), IsNil) + re.Len(storeStats[1], 0) + re.Nil(cluster.HandleStoreHeartbeat(hotHeartBeat)) time.Sleep(20 * 
time.Millisecond) storeStats = cluster.hotStat.RegionStats(statistics.Read, 1) - c.Assert(storeStats[1], HasLen, 1) - c.Assert(storeStats[1][0].RegionID, Equals, uint64(1)) + re.Len(storeStats[1], 1) + re.Equal(uint64(1), storeStats[1][0].RegionID) storeStats = cluster.hotStat.RegionStats(statistics.Read, 3) - c.Assert(storeStats[1], HasLen, 0) + re.Len(storeStats[1], 0) // after 2 hot heartbeats, wo can find region 1 peer again - c.Assert(cluster.HandleStoreHeartbeat(hotHeartBeat), IsNil) - c.Assert(cluster.HandleStoreHeartbeat(hotHeartBeat), IsNil) + re.NoError(cluster.HandleStoreHeartbeat(hotHeartBeat)) + re.NoError(cluster.HandleStoreHeartbeat(hotHeartBeat)) time.Sleep(20 * time.Millisecond) storeStats = cluster.hotStat.RegionStats(statistics.Read, 3) - c.Assert(storeStats[1], HasLen, 1) - c.Assert(storeStats[1][0].RegionID, Equals, uint64(1)) + re.Len(storeStats[1], 1) + re.Equal(uint64(1), storeStats[1][0].RegionID) } -func (s *testClusterInfoSuite) TestFilterUnhealthyStore(c *C) { +func TestFilterUnhealthyStore(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) stores := newTestStores(3, "2.0.0") for _, store := range stores { @@ -196,9 +184,9 @@ func (s *testClusterInfoSuite) TestFilterUnhealthyStore(c *C) { Available: 50, RegionCount: 1, } - c.Assert(cluster.putStoreLocked(store), IsNil) - c.Assert(cluster.HandleStoreHeartbeat(storeStats), IsNil) - c.Assert(cluster.hotStat.GetRollingStoreStats(store.GetID()), NotNil) + re.NoError(cluster.putStoreLocked(store)) + re.NoError(cluster.HandleStoreHeartbeat(storeStats)) + re.NotNil(cluster.hotStat.GetRollingStoreStats(store.GetID())) } for _, store := range stores { @@ -209,17 +197,21 @@ func (s *testClusterInfoSuite) TestFilterUnhealthyStore(c *C) { RegionCount: 1, } newStore := store.Clone(core.TombstoneStore()) - c.Assert(cluster.putStoreLocked(newStore), IsNil) - c.Assert(cluster.HandleStoreHeartbeat(storeStats), IsNil) - c.Assert(cluster.hotStat.GetRollingStoreStats(store.GetID()), IsNil) + re.NoError(cluster.putStoreLocked(newStore)) + re.NoError(cluster.HandleStoreHeartbeat(storeStats)) + re.Nil(cluster.hotStat.GetRollingStoreStats(store.GetID())) } } -func (s *testClusterInfoSuite) TestSetOfflineStore(c *C) { +func TestSetOfflineStore(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) @@ -230,65 +222,69 @@ func (s *testClusterInfoSuite) TestSetOfflineStore(c *C) { // Put 6 stores. 
for _, store := range newTestStores(6, "2.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } // store 1: up -> offline - c.Assert(cluster.RemoveStore(1, false), IsNil) + re.NoError(cluster.RemoveStore(1, false)) store := cluster.GetStore(1) - c.Assert(store.IsRemoving(), IsTrue) - c.Assert(store.IsPhysicallyDestroyed(), IsFalse) + re.True(store.IsRemoving()) + re.False(store.IsPhysicallyDestroyed()) // store 1: set physically to true success - c.Assert(cluster.RemoveStore(1, true), IsNil) + re.NoError(cluster.RemoveStore(1, true)) store = cluster.GetStore(1) - c.Assert(store.IsRemoving(), IsTrue) - c.Assert(store.IsPhysicallyDestroyed(), IsTrue) + re.True(store.IsRemoving()) + re.True(store.IsPhysicallyDestroyed()) // store 2:up -> offline & physically destroyed - c.Assert(cluster.RemoveStore(2, true), IsNil) + re.NoError(cluster.RemoveStore(2, true)) // store 2: set physically destroyed to false failed - c.Assert(cluster.RemoveStore(2, false), NotNil) - c.Assert(cluster.RemoveStore(2, true), IsNil) + re.Error(cluster.RemoveStore(2, false)) + re.NoError(cluster.RemoveStore(2, true)) // store 3: up to offline - c.Assert(cluster.RemoveStore(3, false), IsNil) - c.Assert(cluster.RemoveStore(3, false), IsNil) + re.NoError(cluster.RemoveStore(3, false)) + re.NoError(cluster.RemoveStore(3, false)) cluster.checkStores() // store 1,2,3 should be to tombstone for storeID := uint64(1); storeID <= 3; storeID++ { - c.Assert(cluster.GetStore(storeID).IsRemoved(), IsTrue) + re.True(cluster.GetStore(storeID).IsRemoved()) } // test bury store for storeID := uint64(0); storeID <= 4; storeID++ { store := cluster.GetStore(storeID) if store == nil || store.IsUp() { - c.Assert(cluster.BuryStore(storeID, false), NotNil) + re.Error(cluster.BuryStore(storeID, false)) } else { - c.Assert(cluster.BuryStore(storeID, false), IsNil) + re.NoError(cluster.BuryStore(storeID, false)) } } } -func (s *testClusterInfoSuite) TestSetOfflineWithReplica(c *C) { +func TestSetOfflineWithReplica(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) // Put 4 stores. for _, store := range newTestStores(4, "2.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } - c.Assert(cluster.RemoveStore(2, false), IsNil) + re.NoError(cluster.RemoveStore(2, false)) // should be failed since no enough store to accommodate the extra replica. err = cluster.RemoveStore(3, false) - c.Assert(strings.Contains(err.Error(), string(errs.ErrStoresNotEnough.RFCCode())), IsTrue) - c.Assert(cluster.RemoveStore(3, false), NotNil) + re.Contains(err.Error(), string(errs.ErrStoresNotEnough.RFCCode())) + re.Error(cluster.RemoveStore(3, false)) // should be success since physically-destroyed is true. 
- c.Assert(cluster.RemoveStore(3, true), IsNil) + re.NoError(cluster.RemoveStore(3, true)) } func addEvictLeaderScheduler(cluster *RaftCluster, storeID uint64) (evictScheduler schedule.Scheduler, err error) { @@ -305,62 +301,74 @@ func addEvictLeaderScheduler(cluster *RaftCluster, storeID uint64) (evictSchedul return } -func (s *testClusterInfoSuite) TestSetOfflineStoreWithEvictLeader(c *C) { +func TestSetOfflineStoreWithEvictLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) + re.NoError(err) opt.SetMaxReplicas(1) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) // Put 3 stores. for _, store := range newTestStores(3, "2.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } _, err = addEvictLeaderScheduler(cluster, 1) - c.Assert(err, IsNil) - c.Assert(cluster.RemoveStore(2, false), IsNil) + re.NoError(err) + re.NoError(cluster.RemoveStore(2, false)) // should be failed since there is only 1 store left and it is the evict-leader store. err = cluster.RemoveStore(3, false) - c.Assert(err, NotNil) - c.Assert(strings.Contains(err.Error(), string(errs.ErrNoStoreForRegionLeader.RFCCode())), IsTrue) - c.Assert(cluster.RemoveScheduler(schedulers.EvictLeaderName), IsNil) - c.Assert(cluster.RemoveStore(3, false), IsNil) + re.Error(err) + re.Contains(err.Error(), string(errs.ErrNoStoreForRegionLeader.RFCCode())) + re.NoError(cluster.RemoveScheduler(schedulers.EvictLeaderName)) + re.NoError(cluster.RemoveStore(3, false)) } -func (s *testClusterInfoSuite) TestForceBuryStore(c *C) { +func TestForceBuryStore(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) // Put 2 stores. 
stores := newTestStores(2, "5.3.0") stores[1] = stores[1].Clone(core.SetLastHeartbeatTS(time.Now())) for _, store := range stores { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } - c.Assert(cluster.BuryStore(uint64(1), true), IsNil) - c.Assert(cluster.BuryStore(uint64(2), true), NotNil) - c.Assert(errors.ErrorEqual(cluster.BuryStore(uint64(3), true), errs.ErrStoreNotFound.FastGenByArgs(uint64(3))), IsTrue) + re.NoError(cluster.BuryStore(uint64(1), true)) + re.Error(cluster.BuryStore(uint64(2), true)) + re.True(errors.ErrorEqual(cluster.BuryStore(uint64(3), true), errs.ErrStoreNotFound.FastGenByArgs(uint64(3)))) } -func (s *testClusterInfoSuite) TestReuseAddress(c *C) { +func TestReuseAddress(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) // Put 4 stores. for _, store := range newTestStores(4, "2.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } // store 1: up // store 2: offline - c.Assert(cluster.RemoveStore(2, false), IsNil) + re.NoError(cluster.RemoveStore(2, false)) // store 3: offline and physically destroyed - c.Assert(cluster.RemoveStore(3, true), IsNil) + re.NoError(cluster.RemoveStore(3, true)) // store 4: tombstone - c.Assert(cluster.RemoveStore(4, true), IsNil) - c.Assert(cluster.BuryStore(4, false), IsNil) + re.NoError(cluster.RemoveStore(4, true)) + re.NoError(cluster.BuryStore(4, false)) for id := uint64(1); id <= 4; id++ { storeInfo := cluster.GetStore(id) @@ -375,9 +383,9 @@ func (s *testClusterInfoSuite) TestReuseAddress(c *C) { if storeInfo.IsPhysicallyDestroyed() || storeInfo.IsRemoved() { // try to start a new store with the same address with store which is physically destryed or tombstone should be success - c.Assert(cluster.PutStore(newStore), IsNil) + re.NoError(cluster.PutStore(newStore)) } else { - c.Assert(cluster.PutStore(newStore), NotNil) + re.Error(cluster.PutStore(newStore)) } } } @@ -386,11 +394,15 @@ func getTestDeployPath(storeID uint64) string { return fmt.Sprintf("test/store%d", storeID) } -func (s *testClusterInfoSuite) TestUpStore(c *C) { +func TestUpStore(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) @@ -401,43 +413,47 @@ func (s *testClusterInfoSuite) TestUpStore(c 
*C) { // Put 5 stores. for _, store := range newTestStores(5, "5.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } // set store 1 offline - c.Assert(cluster.RemoveStore(1, false), IsNil) + re.NoError(cluster.RemoveStore(1, false)) // up a offline store should be success. - c.Assert(cluster.UpStore(1), IsNil) + re.NoError(cluster.UpStore(1)) // set store 2 offline and physically destroyed - c.Assert(cluster.RemoveStore(2, true), IsNil) - c.Assert(cluster.UpStore(2), NotNil) + re.NoError(cluster.RemoveStore(2, true)) + re.Error(cluster.UpStore(2)) // bury store 2 cluster.checkStores() // store is tombstone err = cluster.UpStore(2) - c.Assert(errors.ErrorEqual(err, errs.ErrStoreRemoved.FastGenByArgs(2)), IsTrue) + re.True(errors.ErrorEqual(err, errs.ErrStoreRemoved.FastGenByArgs(2))) // store 3 is up - c.Assert(cluster.UpStore(3), IsNil) + re.NoError(cluster.UpStore(3)) // store 4 not exist err = cluster.UpStore(10) - c.Assert(errors.ErrorEqual(err, errs.ErrStoreNotFound.FastGenByArgs(4)), IsTrue) + re.True(errors.ErrorEqual(err, errs.ErrStoreNotFound.FastGenByArgs(4))) } -func (s *testClusterInfoSuite) TestRemovingProcess(c *C) { +func TestRemovingProcess(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.SetPrepared() // Put 5 stores. stores := newTestStores(5, "5.0.0") for _, store := range stores { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } regions := newTestRegions(100, 5, 1) var regionInStore1 []*core.RegionInfo @@ -446,19 +462,19 @@ func (s *testClusterInfoSuite) TestRemovingProcess(c *C) { region = region.Clone(core.SetApproximateSize(100)) regionInStore1 = append(regionInStore1, region) } - c.Assert(cluster.putRegion(region), IsNil) + re.NoError(cluster.putRegion(region)) } - c.Assert(len(regionInStore1), Equals, 20) + re.Len(regionInStore1, 20) cluster.progressManager = progress.NewManager() cluster.RemoveStore(1, false) cluster.checkStores() process := "removing-1" // no region moving p, l, cs, err := cluster.progressManager.Status(process) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.0) - c.Assert(l, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.0, p) + re.Equal(math.MaxFloat64, l) + re.Equal(0.0, cs) i := 0 // simulate region moving by deleting region from store 1 for _, region := range regionInStore1 { @@ -470,22 +486,26 @@ func (s *testClusterInfoSuite) TestRemovingProcess(c *C) { } cluster.checkStores() p, l, cs, err = cluster.progressManager.Status(process) - c.Assert(err, IsNil) + re.NoError(err) // In above we delete 5 region from store 1, the total count of region in store 1 is 20. // process = 5 / 20 = 0.25 - c.Assert(p, Equals, 0.25) + re.Equal(0.25, p) // Each region is 100MB, we use more than 1s to move 5 region. 
// speed = 5 * 100MB / 20s = 25MB/s - c.Assert(cs, Equals, 25.0) + re.Equal(25.0, cs) // left second = 15 * 100MB / 25s = 60s - c.Assert(l, Equals, 60.0) + re.Equal(60.0, l) } -func (s *testClusterInfoSuite) TestDeleteStoreUpdatesClusterVersion(c *C) { +func TestDeleteStoreUpdatesClusterVersion(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) @@ -496,48 +516,56 @@ func (s *testClusterInfoSuite) TestDeleteStoreUpdatesClusterVersion(c *C) { // Put 3 new 4.0.9 stores. for _, store := range newTestStores(3, "4.0.9") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } - c.Assert(cluster.GetClusterVersion(), Equals, "4.0.9") + re.Equal("4.0.9", cluster.GetClusterVersion()) // Upgrade 2 stores to 5.0.0. for _, store := range newTestStores(2, "5.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } - c.Assert(cluster.GetClusterVersion(), Equals, "4.0.9") + re.Equal("4.0.9", cluster.GetClusterVersion()) // Bury the other store. 
- c.Assert(cluster.RemoveStore(3, true), IsNil) + re.NoError(cluster.RemoveStore(3, true)) cluster.checkStores() - c.Assert(cluster.GetClusterVersion(), Equals, "5.0.0") + re.Equal("5.0.0", cluster.GetClusterVersion()) } -func (s *testClusterInfoSuite) TestStoreClusterVersion(c *C) { +func TestStoreClusterVersion(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) stores := newTestStores(3, "5.0.0") s1, s2, s3 := stores[0].GetMeta(), stores[1].GetMeta(), stores[2].GetMeta() s1.Version = "5.0.1" s2.Version = "5.0.3" s3.Version = "5.0.5" - c.Assert(cluster.PutStore(s2), IsNil) - c.Assert(cluster.GetClusterVersion(), Equals, s2.Version) + re.NoError(cluster.PutStore(s2)) + re.Equal(s2.Version, cluster.GetClusterVersion()) - c.Assert(cluster.PutStore(s1), IsNil) + re.NoError(cluster.PutStore(s1)) // the cluster version should be 5.0.1(the min one) - c.Assert(cluster.GetClusterVersion(), Equals, s1.Version) + re.Equal(s1.Version, cluster.GetClusterVersion()) - c.Assert(cluster.PutStore(s3), IsNil) + re.NoError(cluster.PutStore(s3)) // the cluster version should be 5.0.1(the min one) - c.Assert(cluster.GetClusterVersion(), Equals, s1.Version) + re.Equal(s1.Version, cluster.GetClusterVersion()) } -func (s *testClusterInfoSuite) TestRegionHeartbeatHotStat(c *C) { +func TestRegionHeartbeatHotStat(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) newTestStores(4, "2.0.0") peers := []*metapb.Peer{ { @@ -568,34 +596,38 @@ func (s *testClusterInfoSuite) TestRegionHeartbeatHotStat(c *C) { core.SetWrittenBytes(30000*10), core.SetWrittenKeys(300000*10)) err = cluster.processRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) // wait HotStat to update items time.Sleep(1 * time.Second) stats := cluster.hotStat.RegionStats(statistics.Write, 0) - c.Assert(stats[1], HasLen, 1) - c.Assert(stats[2], HasLen, 1) - c.Assert(stats[3], HasLen, 1) + re.Len(stats[1], 1) + re.Len(stats[2], 1) + re.Len(stats[3], 1) newPeer := &metapb.Peer{ Id: 4, StoreId: 4, } region = region.Clone(core.WithRemoveStorePeer(2), core.WithAddPeer(newPeer)) err = cluster.processRegionHeartbeat(region) - c.Assert(err, IsNil) + re.NoError(err) // wait HotStat to update items time.Sleep(1 * time.Second) stats = cluster.hotStat.RegionStats(statistics.Write, 0) - c.Assert(stats[1], HasLen, 1) - c.Assert(stats[2], HasLen, 0) - c.Assert(stats[3], HasLen, 1) - c.Assert(stats[4], HasLen, 1) + re.Len(stats[1], 1) + re.Len(stats[2], 0) + re.Len(stats[3], 1) + re.Len(stats[4], 1) } -func (s *testClusterInfoSuite) TestBucketHeartbeat(c *C) { +func TestBucketHeartbeat(t *testing.T) { + re := require.New(t) + ctx, cancel := 
context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) // case1: region is not exist buckets := &metapb.Buckets{ @@ -603,54 +635,58 @@ func (s *testClusterInfoSuite) TestBucketHeartbeat(c *C) { Version: 1, Keys: [][]byte{{'1'}, {'2'}}, } - c.Assert(cluster.processReportBuckets(buckets), NotNil) + re.Error(cluster.processReportBuckets(buckets)) // case2: bucket can be processed after the region update. stores := newTestStores(3, "2.0.0") n, np := uint64(2), uint64(2) regions := newTestRegions(n, n, np) for _, store := range stores { - c.Assert(cluster.putStoreLocked(store), IsNil) + re.NoError(cluster.putStoreLocked(store)) } - c.Assert(cluster.processRegionHeartbeat(regions[0]), IsNil) - c.Assert(cluster.processRegionHeartbeat(regions[1]), IsNil) - c.Assert(cluster.GetRegion(uint64(1)).GetBuckets(), IsNil) - c.Assert(cluster.processReportBuckets(buckets), IsNil) - c.Assert(cluster.GetRegion(uint64(1)).GetBuckets(), DeepEquals, buckets) + re.NoError(cluster.processRegionHeartbeat(regions[0])) + re.NoError(cluster.processRegionHeartbeat(regions[1])) + re.Nil(cluster.GetRegion(uint64(1)).GetBuckets()) + re.NoError(cluster.processReportBuckets(buckets)) + re.Equal(buckets, cluster.GetRegion(uint64(1)).GetBuckets()) // case3: the bucket version is same. - c.Assert(cluster.processReportBuckets(buckets), IsNil) + re.NoError(cluster.processReportBuckets(buckets)) // case4: the bucket version is changed. newBuckets := &metapb.Buckets{ RegionId: 1, Version: 3, Keys: [][]byte{{'1'}, {'2'}}, } - c.Assert(cluster.processReportBuckets(newBuckets), IsNil) - c.Assert(cluster.GetRegion(uint64(1)).GetBuckets(), DeepEquals, newBuckets) + re.NoError(cluster.processReportBuckets(newBuckets)) + re.Equal(newBuckets, cluster.GetRegion(uint64(1)).GetBuckets()) // case5: region update should inherit buckets. 
newRegion := regions[1].Clone(core.WithIncConfVer(), core.SetBuckets(nil)) cluster.storeConfigManager = config.NewTestStoreConfigManager(nil) config := cluster.storeConfigManager.GetStoreConfig() config.Coprocessor.EnableRegionBucket = true - c.Assert(cluster.processRegionHeartbeat(newRegion), IsNil) - c.Assert(cluster.GetRegion(uint64(1)).GetBuckets().GetKeys(), HasLen, 2) + re.NoError(cluster.processRegionHeartbeat(newRegion)) + re.Len(cluster.GetRegion(uint64(1)).GetBuckets().GetKeys(), 2) // case6: disable region bucket in config.Coprocessor.EnableRegionBucket = false newRegion2 := regions[1].Clone(core.WithIncConfVer(), core.SetBuckets(nil)) - c.Assert(cluster.processRegionHeartbeat(newRegion2), IsNil) - c.Assert(cluster.GetRegion(uint64(1)).GetBuckets(), IsNil) - c.Assert(cluster.GetRegion(uint64(1)).GetBuckets().GetKeys(), HasLen, 0) + re.NoError(cluster.processRegionHeartbeat(newRegion2)) + re.Nil(cluster.GetRegion(uint64(1)).GetBuckets()) + re.Len(cluster.GetRegion(uint64(1)).GetBuckets().GetKeys(), 0) } -func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { +func TestRegionHeartbeat(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) n, np := uint64(3), uint64(3) @@ -658,32 +694,32 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { regions := newTestRegions(n, n, np) for _, store := range stores { - c.Assert(cluster.putStoreLocked(store), IsNil) + re.NoError(cluster.putStoreLocked(store)) } for i, region := range regions { // region does not exist. - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) // region is the same, not updated. - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) origin := region // region is updated. region = origin.Clone(core.WithIncVersion()) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) // region is stale (Version). 
stale := origin.Clone(core.WithIncConfVer()) - c.Assert(cluster.processRegionHeartbeat(stale), NotNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.Error(cluster.processRegionHeartbeat(stale)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) // region is updated. region = origin.Clone( @@ -691,15 +727,15 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { core.WithIncConfVer(), ) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) // region is stale (ConfVer). stale = origin.Clone(core.WithIncConfVer()) - c.Assert(cluster.processRegionHeartbeat(stale), NotNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.Error(cluster.processRegionHeartbeat(stale)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) // Add a down peer. region = region.Clone(core.WithDownPeers([]*pdpb.PeerStats{ @@ -709,70 +745,70 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { }, })) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Add a pending peer. region = region.Clone(core.WithPendingPeers([]*metapb.Peer{region.GetPeers()[rand.Intn(len(region.GetPeers()))]})) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Clear down peers. region = region.Clone(core.WithDownPeers(nil)) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Clear pending peers. region = region.Clone(core.WithPendingPeers(nil)) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Remove peers. origin = region region = origin.Clone(core.SetPeers(region.GetPeers()[:1])) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) // Add peers. region = origin regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) - checkRegionsKV(c, cluster.storage, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) + checkRegionsKV(re, cluster.storage, regions[:i+1]) // Change leader. 
region = region.Clone(core.WithLeader(region.GetPeers()[1])) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Change ApproximateSize. region = region.Clone(core.SetApproximateSize(144)) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Change ApproximateKeys. region = region.Clone(core.SetApproximateKeys(144000)) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Change bytes written. region = region.Clone(core.SetWrittenBytes(24000)) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) // Change bytes read. region = region.Clone(core.SetReadBytes(1080000)) regions[i] = region - c.Assert(cluster.processRegionHeartbeat(region), IsNil) - checkRegions(c, cluster.core.Regions, regions[:i+1]) + re.NoError(cluster.processRegionHeartbeat(region)) + checkRegions(re, cluster.core.Regions, regions[:i+1]) } regionCounts := make(map[uint64]int) @@ -782,31 +818,31 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { } } for id, count := range regionCounts { - c.Assert(cluster.GetStoreRegionCount(id), Equals, count) + re.Equal(count, cluster.GetStoreRegionCount(id)) } for _, region := range cluster.GetRegions() { - checkRegion(c, region, regions[region.GetID()]) + checkRegion(re, region, regions[region.GetID()]) } for _, region := range cluster.GetMetaRegions() { - c.Assert(region, DeepEquals, regions[region.GetId()].GetMeta()) + re.Equal(regions[region.GetId()].GetMeta(), region) } for _, region := range regions { for _, store := range cluster.GetRegionStores(region) { - c.Assert(region.GetStorePeer(store.GetID()), NotNil) + re.NotNil(region.GetStorePeer(store.GetID())) } for _, store := range cluster.GetFollowerStores(region) { peer := region.GetStorePeer(store.GetID()) - c.Assert(peer.GetId(), Not(Equals), region.GetLeader().GetId()) + re.NotEqual(region.GetLeader().GetId(), peer.GetId()) } } for _, store := range cluster.core.Stores.GetStores() { - c.Assert(store.GetLeaderCount(), Equals, cluster.core.Regions.GetStoreLeaderCount(store.GetID())) - c.Assert(store.GetRegionCount(), Equals, cluster.core.Regions.GetStoreRegionCount(store.GetID())) - c.Assert(store.GetLeaderSize(), Equals, cluster.core.Regions.GetStoreLeaderRegionSize(store.GetID())) - c.Assert(store.GetRegionSize(), Equals, cluster.core.Regions.GetStoreRegionSize(store.GetID())) + re.Equal(cluster.core.Regions.GetStoreLeaderCount(store.GetID()), store.GetLeaderCount()) + re.Equal(cluster.core.Regions.GetStoreRegionCount(store.GetID()), store.GetRegionCount()) + re.Equal(cluster.core.Regions.GetStoreLeaderRegionSize(store.GetID()), store.GetLeaderSize()) + re.Equal(cluster.core.Regions.GetStoreRegionSize(store.GetID()), store.GetRegionSize()) } // Test with storage. 
@@ -814,9 +850,9 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { for _, region := range regions { tmp := &metapb.Region{} ok, err := storage.LoadRegion(region.GetID(), tmp) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(tmp, DeepEquals, region.GetMeta()) + re.True(ok) + re.NoError(err) + re.Equal(region.GetMeta(), tmp) } // Check overlap with stale version @@ -826,45 +862,49 @@ func (s *testClusterInfoSuite) TestRegionHeartbeat(c *C) { core.WithNewRegionID(10000), core.WithDecVersion(), ) - c.Assert(cluster.processRegionHeartbeat(overlapRegion), NotNil) + re.Error(cluster.processRegionHeartbeat(overlapRegion)) region := &metapb.Region{} ok, err := storage.LoadRegion(regions[n-1].GetID(), region) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(region, DeepEquals, regions[n-1].GetMeta()) + re.True(ok) + re.NoError(err) + re.Equal(regions[n-1].GetMeta(), region) ok, err = storage.LoadRegion(regions[n-2].GetID(), region) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(region, DeepEquals, regions[n-2].GetMeta()) + re.True(ok) + re.NoError(err) + re.Equal(regions[n-2].GetMeta(), region) ok, err = storage.LoadRegion(overlapRegion.GetID(), region) - c.Assert(ok, IsFalse) - c.Assert(err, IsNil) + re.False(ok) + re.NoError(err) // Check overlap overlapRegion = regions[n-1].Clone( core.WithStartKey(regions[n-2].GetStartKey()), core.WithNewRegionID(regions[n-1].GetID()+1), ) - c.Assert(cluster.processRegionHeartbeat(overlapRegion), IsNil) + re.NoError(cluster.processRegionHeartbeat(overlapRegion)) region = &metapb.Region{} ok, err = storage.LoadRegion(regions[n-1].GetID(), region) - c.Assert(ok, IsFalse) - c.Assert(err, IsNil) + re.False(ok) + re.NoError(err) ok, err = storage.LoadRegion(regions[n-2].GetID(), region) - c.Assert(ok, IsFalse) - c.Assert(err, IsNil) + re.False(ok) + re.NoError(err) ok, err = storage.LoadRegion(overlapRegion.GetID(), region) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(region, DeepEquals, overlapRegion.GetMeta()) + re.True(ok) + re.NoError(err) + re.Equal(overlapRegion.GetMeta(), region) } } -func (s *testClusterInfoSuite) TestRegionFlowChanged(c *C) { +func TestRegionFlowChanged(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} processRegions := func(regions []*core.RegionInfo) { for _, r := range regions { @@ -878,14 +918,18 @@ func (s *testClusterInfoSuite) TestRegionFlowChanged(c *C) { regions[0] = region.Clone(core.SetReadBytes(1000)) processRegions(regions) newRegion := cluster.GetRegion(region.GetID()) - c.Assert(newRegion.GetBytesRead(), Equals, uint64(1000)) + re.Equal(uint64(1000), newRegion.GetBytesRead()) } -func (s *testClusterInfoSuite) TestRegionSizeChanged(c *C) { +func TestRegionSizeChanged(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, 
mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.regionStats = statistics.NewRegionStatistics(cluster.GetOpts(), cluster.ruleManager, cluster.storeConfigManager) region := newTestRegions(1, 3, 3)[0] cluster.opt.GetMaxMergeRegionKeys() @@ -899,7 +943,7 @@ func (s *testClusterInfoSuite) TestRegionSizeChanged(c *C) { ) cluster.processRegionHeartbeat(region) regionID := region.GetID() - c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsTrue) + re.True(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) // Test ApproximateSize and ApproximateKeys change. region = region.Clone( core.WithLeader(region.GetPeers()[2]), @@ -908,53 +952,61 @@ func (s *testClusterInfoSuite) TestRegionSizeChanged(c *C) { core.SetFromHeartbeat(true), ) cluster.processRegionHeartbeat(region) - c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsFalse) + re.False(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) // Test MaxMergeRegionSize and MaxMergeRegionKeys change. cluster.opt.SetMaxMergeRegionSize((uint64(curMaxMergeSize + 2))) cluster.opt.SetMaxMergeRegionKeys((uint64(curMaxMergeKeys + 2))) cluster.processRegionHeartbeat(region) - c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsTrue) + re.True(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) cluster.opt.SetMaxMergeRegionSize((uint64(curMaxMergeSize))) cluster.opt.SetMaxMergeRegionKeys((uint64(curMaxMergeKeys))) cluster.processRegionHeartbeat(region) - c.Assert(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion), IsFalse) + re.False(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) } -func (s *testClusterInfoSuite) TestConcurrentReportBucket(c *C) { +func TestConcurrentReportBucket(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} - heartbeatRegions(c, cluster, regions) - c.Assert(cluster.GetRegion(0), NotNil) + heartbeatRegions(re, cluster, regions) + re.NotNil(cluster.GetRegion(0)) bucket1 := &metapb.Buckets{RegionId: 0, Version: 3} bucket2 := &metapb.Buckets{RegionId: 0, Version: 2} var wg sync.WaitGroup wg.Add(1) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/cluster/concurrentBucketHeartbeat", "return(true)"), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/concurrentBucketHeartbeat", "return(true)")) go func() { defer wg.Done() cluster.processReportBuckets(bucket1) }() time.Sleep(100 * time.Millisecond) - 
c.Assert(failpoint.Disable("github.com/tikv/pd/server/cluster/concurrentBucketHeartbeat"), IsNil) - c.Assert(cluster.processReportBuckets(bucket2), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/concurrentBucketHeartbeat")) + re.NoError(cluster.processReportBuckets(bucket2)) wg.Wait() - c.Assert(cluster.GetRegion(0).GetBuckets(), DeepEquals, bucket1) + re.Equal(bucket1, cluster.GetRegion(0).GetBuckets()) } -func (s *testClusterInfoSuite) TestConcurrentRegionHeartbeat(c *C) { +func TestConcurrentRegionHeartbeat(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} regions = core.SplitRegions(regions) - heartbeatRegions(c, cluster, regions) + heartbeatRegions(re, cluster, regions) // Merge regions manually source, target := regions[0], regions[1] @@ -968,25 +1020,29 @@ func (s *testClusterInfoSuite) TestConcurrentRegionHeartbeat(c *C) { var wg sync.WaitGroup wg.Add(1) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/cluster/concurrentRegionHeartbeat", "return(true)"), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/concurrentRegionHeartbeat", "return(true)")) go func() { defer wg.Done() cluster.processRegionHeartbeat(source) }() time.Sleep(100 * time.Millisecond) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/cluster/concurrentRegionHeartbeat"), IsNil) - c.Assert(cluster.processRegionHeartbeat(target), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/concurrentRegionHeartbeat")) + re.NoError(cluster.processRegionHeartbeat(target)) wg.Wait() - checkRegion(c, cluster.GetRegionByKey([]byte{}), target) + checkRegion(re, cluster.GetRegionByKey([]byte{}), target) } -func (s *testClusterInfoSuite) TestRegionLabelIsolationLevel(c *C) { +func TestRegionLabelIsolationLevel(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() cfg := opt.GetReplicationConfig() cfg.LocationLabels = []string{"zone"} opt.SetReplicationConfig(cfg) - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) for i := uint64(1); i <= 4; i++ { var labels []*metapb.StoreLabel @@ -1001,7 +1057,7 @@ func (s *testClusterInfoSuite) TestRegionLabelIsolationLevel(c *C) { State: metapb.StoreState_Up, Labels: labels, } - c.Assert(cluster.putStoreLocked(core.NewStoreInfo(store)), IsNil) + re.NoError(cluster.putStoreLocked(core.NewStoreInfo(store))) } peers := make([]*metapb.Peer, 0, 4) @@ -1022,52 +1078,56 @@ func (s *testClusterInfoSuite) TestRegionLabelIsolationLevel(c *C) { EndKey: []byte{byte(2)}, } r := core.NewRegionInfo(region, peers[0]) - c.Assert(cluster.putRegion(r), IsNil) + re.NoError(cluster.putRegion(r)) 
cluster.updateRegionsLabelLevelStats([]*core.RegionInfo{r}) counter := cluster.labelLevelStats.GetLabelCounter() - c.Assert(counter["none"], Equals, 0) - c.Assert(counter["zone"], Equals, 1) + re.Equal(0, counter["none"]) + re.Equal(1, counter["zone"]) } -func heartbeatRegions(c *C, cluster *RaftCluster, regions []*core.RegionInfo) { +func heartbeatRegions(re *require.Assertions, cluster *RaftCluster, regions []*core.RegionInfo) { // Heartbeat and check region one by one. for _, r := range regions { - c.Assert(cluster.processRegionHeartbeat(r), IsNil) + re.NoError(cluster.processRegionHeartbeat(r)) - checkRegion(c, cluster.GetRegion(r.GetID()), r) - checkRegion(c, cluster.GetRegionByKey(r.GetStartKey()), r) + checkRegion(re, cluster.GetRegion(r.GetID()), r) + checkRegion(re, cluster.GetRegionByKey(r.GetStartKey()), r) if len(r.GetEndKey()) > 0 { end := r.GetEndKey()[0] - checkRegion(c, cluster.GetRegionByKey([]byte{end - 1}), r) + checkRegion(re, cluster.GetRegionByKey([]byte{end - 1}), r) } } // Check all regions after handling all heartbeats. for _, r := range regions { - checkRegion(c, cluster.GetRegion(r.GetID()), r) - checkRegion(c, cluster.GetRegionByKey(r.GetStartKey()), r) + checkRegion(re, cluster.GetRegion(r.GetID()), r) + checkRegion(re, cluster.GetRegionByKey(r.GetStartKey()), r) if len(r.GetEndKey()) > 0 { end := r.GetEndKey()[0] - checkRegion(c, cluster.GetRegionByKey([]byte{end - 1}), r) + checkRegion(re, cluster.GetRegionByKey([]byte{end - 1}), r) result := cluster.GetRegionByKey([]byte{end + 1}) - c.Assert(result.GetID(), Not(Equals), r.GetID()) + re.NotEqual(r.GetID(), result.GetID()) } } } -func (s *testClusterInfoSuite) TestHeartbeatSplit(c *C) { +func TestHeartbeatSplit(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) // 1: [nil, nil) region1 := core.NewRegionInfo(&metapb.Region{Id: 1, RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) - c.Assert(cluster.processRegionHeartbeat(region1), IsNil) - checkRegion(c, cluster.GetRegionByKey([]byte("foo")), region1) + re.NoError(cluster.processRegionHeartbeat(region1)) + checkRegion(re, cluster.GetRegionByKey([]byte("foo")), region1) // split 1 to 2: [nil, m) 1: [m, nil), sync 2 first. region1 = region1.Clone( @@ -1075,13 +1135,13 @@ func (s *testClusterInfoSuite) TestHeartbeatSplit(c *C) { core.WithIncVersion(), ) region2 := core.NewRegionInfo(&metapb.Region{Id: 2, EndKey: []byte("m"), RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) - c.Assert(cluster.processRegionHeartbeat(region2), IsNil) - checkRegion(c, cluster.GetRegionByKey([]byte("a")), region2) + re.NoError(cluster.processRegionHeartbeat(region2)) + checkRegion(re, cluster.GetRegionByKey([]byte("a")), region2) // [m, nil) is missing before r1's heartbeat. 
- c.Assert(cluster.GetRegionByKey([]byte("z")), IsNil) + re.Nil(cluster.GetRegionByKey([]byte("z"))) - c.Assert(cluster.processRegionHeartbeat(region1), IsNil) - checkRegion(c, cluster.GetRegionByKey([]byte("z")), region1) + re.NoError(cluster.processRegionHeartbeat(region1)) + checkRegion(re, cluster.GetRegionByKey([]byte("z")), region1) // split 1 to 3: [m, q) 1: [q, nil), sync 1 first. region1 = region1.Clone( @@ -1089,20 +1149,24 @@ func (s *testClusterInfoSuite) TestHeartbeatSplit(c *C) { core.WithIncVersion(), ) region3 := core.NewRegionInfo(&metapb.Region{Id: 3, StartKey: []byte("m"), EndKey: []byte("q"), RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) - c.Assert(cluster.processRegionHeartbeat(region1), IsNil) - checkRegion(c, cluster.GetRegionByKey([]byte("z")), region1) - checkRegion(c, cluster.GetRegionByKey([]byte("a")), region2) + re.NoError(cluster.processRegionHeartbeat(region1)) + checkRegion(re, cluster.GetRegionByKey([]byte("z")), region1) + checkRegion(re, cluster.GetRegionByKey([]byte("a")), region2) // [m, q) is missing before r3's heartbeat. - c.Assert(cluster.GetRegionByKey([]byte("n")), IsNil) - c.Assert(cluster.processRegionHeartbeat(region3), IsNil) - checkRegion(c, cluster.GetRegionByKey([]byte("n")), region3) + re.Nil(cluster.GetRegionByKey([]byte("n"))) + re.NoError(cluster.processRegionHeartbeat(region3)) + checkRegion(re, cluster.GetRegionByKey([]byte("n")), region3) } -func (s *testClusterInfoSuite) TestRegionSplitAndMerge(c *C) { +func TestRegionSplitAndMerge(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) regions := []*core.RegionInfo{core.NewTestRegionInfo([]byte{}, []byte{})} @@ -1112,13 +1176,13 @@ func (s *testClusterInfoSuite) TestRegionSplitAndMerge(c *C) { // Split. for i := 0; i < n; i++ { regions = core.SplitRegions(regions) - heartbeatRegions(c, cluster, regions) + heartbeatRegions(re, cluster, regions) } // Merge. for i := 0; i < n; i++ { regions = core.MergeRegions(regions) - heartbeatRegions(c, cluster, regions) + heartbeatRegions(re, cluster, regions) } // Split twice and merge once. 
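The hunks in this file all apply the same mechanical conversion: a gocheck suite method taking `*check.C`, with `SetUpTest`/`TearDownTest` managing a shared context, becomes a plain `testing.T` function that builds its own `require.Assertions` and cancels its own context. A minimal, self-contained sketch of that target shape follows; `TestExample`, `doWork`, and the package name are placeholders for illustration, not code from this patch.

package cluster_test // placeholder package name, not part of this patch

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

// doWork stands in for the cluster helpers the real tests exercise.
func doWork(ctx context.Context) (int, error) {
	if err := ctx.Err(); err != nil {
		return 0, err
	}
	return 42, nil
}

// TestExample shows the post-migration shape used throughout this patch:
// no suite type and no SetUpTest/TearDownTest; each test owns its
// require.Assertions and its own cancellable context.
func TestExample(t *testing.T) {
	re := require.New(t)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // replaces the TearDownTest that used to call s.cancel()

	v, err := doWork(ctx)
	re.NoError(err) // was: c.Assert(err, IsNil)
	re.Equal(42, v) // was: c.Assert(v, Equals, 42)
	re.True(v > 0)  // was: c.Assert(v > 0, IsTrue)
}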
@@ -1128,15 +1192,19 @@ func (s *testClusterInfoSuite) TestRegionSplitAndMerge(c *C) { } else { regions = core.SplitRegions(regions) } - heartbeatRegions(c, cluster, regions) + heartbeatRegions(re, cluster, regions) } } -func (s *testClusterInfoSuite) TestOfflineAndMerge(c *C) { +func TestOfflineAndMerge(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) @@ -1145,11 +1213,11 @@ func (s *testClusterInfoSuite) TestOfflineAndMerge(c *C) { } } cluster.regionStats = statistics.NewRegionStatistics(cluster.GetOpts(), cluster.ruleManager, cluster.storeConfigManager) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + cluster.coordinator = newCoordinator(ctx, cluster, nil) // Put 4 stores. for _, store := range newTestStores(4, "5.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } peers := []*metapb.Peer{ @@ -1174,35 +1242,39 @@ func (s *testClusterInfoSuite) TestOfflineAndMerge(c *C) { regions := []*core.RegionInfo{origin} // store 1: up -> offline - c.Assert(cluster.RemoveStore(1, false), IsNil) + re.NoError(cluster.RemoveStore(1, false)) store := cluster.GetStore(1) - c.Assert(store.IsRemoving(), IsTrue) + re.True(store.IsRemoving()) // Split. n := 7 for i := 0; i < n; i++ { regions = core.SplitRegions(regions) } - heartbeatRegions(c, cluster, regions) - c.Assert(cluster.GetOfflineRegionStatsByType(statistics.OfflinePeer), HasLen, len(regions)) + heartbeatRegions(re, cluster, regions) + re.Len(cluster.GetOfflineRegionStatsByType(statistics.OfflinePeer), len(regions)) // Merge. 
for i := 0; i < n; i++ { regions = core.MergeRegions(regions) - heartbeatRegions(c, cluster, regions) - c.Assert(cluster.GetOfflineRegionStatsByType(statistics.OfflinePeer), HasLen, len(regions)) + heartbeatRegions(re, cluster, regions) + re.Len(cluster.GetOfflineRegionStatsByType(statistics.OfflinePeer), len(regions)) } } -func (s *testClusterInfoSuite) TestSyncConfig(c *C) { +func TestSyncConfig(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - tc := newTestCluster(s.ctx, opt) + re.NoError(err) + tc := newTestCluster(ctx, opt) stores := newTestStores(5, "2.0.0") for _, s := range stores { - c.Assert(tc.putStoreLocked(s), IsNil) + re.NoError(tc.putStoreLocked(s)) } - c.Assert(tc.getUpStores(), HasLen, 5) + re.Len(tc.getUpStores(), 5) testdata := []struct { whiteList []string @@ -1220,20 +1292,24 @@ func (s *testClusterInfoSuite) TestSyncConfig(c *C) { for _, v := range testdata { tc.storeConfigManager = config.NewTestStoreConfigManager(v.whiteList) - c.Assert(tc.GetStoreConfig().GetRegionMaxSize(), Equals, uint64(144)) - c.Assert(syncConfig(tc.storeConfigManager, tc.GetStores()), Equals, v.updated) - c.Assert(tc.GetStoreConfig().GetRegionMaxSize(), Equals, v.maxRegionSize) + re.Equal(uint64(144), tc.GetStoreConfig().GetRegionMaxSize()) + re.Equal(v.updated, syncConfig(tc.storeConfigManager, tc.GetStores())) + re.Equal(v.maxRegionSize, tc.GetStoreConfig().GetRegionMaxSize()) } } -func (s *testClusterInfoSuite) TestUpdateStorePendingPeerCount(c *C) { +func TestUpdateStorePendingPeerCount(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - tc := newTestCluster(s.ctx, opt) - tc.RaftCluster.coordinator = newCoordinator(s.ctx, tc.RaftCluster, nil) + re.NoError(err) + tc := newTestCluster(ctx, opt) + tc.RaftCluster.coordinator = newCoordinator(ctx, tc.RaftCluster, nil) stores := newTestStores(5, "2.0.0") for _, s := range stores { - c.Assert(tc.putStoreLocked(s), IsNil) + re.NoError(tc.putStoreLocked(s)) } peers := []*metapb.Peer{ { @@ -1254,14 +1330,16 @@ func (s *testClusterInfoSuite) TestUpdateStorePendingPeerCount(c *C) { }, } origin := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers[:3]}, peers[0], core.WithPendingPeers(peers[1:3])) - c.Assert(tc.processRegionHeartbeat(origin), IsNil) - checkPendingPeerCount([]int{0, 1, 1, 0}, tc.RaftCluster, c) + re.NoError(tc.processRegionHeartbeat(origin)) + checkPendingPeerCount([]int{0, 1, 1, 0}, tc.RaftCluster, re) newRegion := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers[1:]}, peers[1], core.WithPendingPeers(peers[3:4])) - c.Assert(tc.processRegionHeartbeat(newRegion), IsNil) - checkPendingPeerCount([]int{0, 0, 0, 1}, tc.RaftCluster, c) + re.NoError(tc.processRegionHeartbeat(newRegion)) + checkPendingPeerCount([]int{0, 0, 0, 1}, tc.RaftCluster, re) } -func (s *testClusterInfoSuite) TestTopologyWeight(c *C) { +func TestTopologyWeight(t *testing.T) { + re := require.New(t) + labels := []string{"zone", "rack", "host"} zones := []string{"z1", "z2", "z3"} racks := []string{"r1", "r2", "r3"} @@ -1287,17 +1365,21 @@ func (s *testClusterInfoSuite) TestTopologyWeight(c *C) { } } - c.Assert(getStoreTopoWeight(testStore, stores, labels), Equals, 1.0/3/3/4) + re.Equal(1.0/3/3/4, getStoreTopoWeight(testStore, stores, labels)) } -func (s *testClusterInfoSuite) TestCalculateStoreSize1(c *C) { 
+func TestCalculateStoreSize1(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) + re.NoError(err) cfg := opt.GetReplicationConfig() cfg.EnablePlacementRules = true opt.SetReplicationConfig(cfg) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.regionStats = statistics.NewRegionStatistics(cluster.GetOpts(), cluster.ruleManager, cluster.storeConfigManager) // Put 10 stores. @@ -1324,7 +1406,7 @@ func (s *testClusterInfoSuite) TestCalculateStoreSize1(c *C) { }, }...) s := store.Clone(core.SetStoreLabels(labels)) - c.Assert(cluster.PutStore(s.GetMeta()), IsNil) + re.NoError(cluster.PutStore(s.GetMeta())) } cluster.ruleManager.SetRule( @@ -1354,29 +1436,33 @@ func (s *testClusterInfoSuite) TestCalculateStoreSize1(c *C) { regions := newTestRegions(100, 10, 5) for _, region := range regions { - c.Assert(cluster.putRegion(region), IsNil) + re.NoError(cluster.putRegion(region)) } stores := cluster.GetStores() store := cluster.GetStore(1) // 100 * 100 * 2 (placement rule) / 4 (host) * 0.9 = 4500 - c.Assert(cluster.getThreshold(stores, store), Equals, 4500.0) + re.Equal(4500.0, cluster.getThreshold(stores, store)) cluster.opt.SetPlacementRuleEnabled(false) cluster.opt.SetLocationLabels([]string{"zone", "rack", "host"}) // 30000 (total region size) / 3 (zone) / 4 (host) * 0.9 = 2250 - c.Assert(cluster.getThreshold(stores, store), Equals, 2250.0) + re.Equal(2250.0, cluster.getThreshold(stores, store)) } -func (s *testClusterInfoSuite) TestCalculateStoreSize2(c *C) { +func TestCalculateStoreSize2(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) + re.NoError(err) cfg := opt.GetReplicationConfig() cfg.EnablePlacementRules = true opt.SetReplicationConfig(cfg) opt.SetMaxReplicas(3) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, nil) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, nil) cluster.regionStats = statistics.NewRegionStatistics(cluster.GetOpts(), cluster.ruleManager, cluster.storeConfigManager) // Put 10 stores. @@ -1401,7 +1487,7 @@ func (s *testClusterInfoSuite) TestCalculateStoreSize2(c *C) { } labels = append(labels, []*metapb.StoreLabel{{Key: "rack", Value: "r1"}, {Key: "host", Value: "h1"}}...) 
s := store.Clone(core.SetStoreLabels(labels)) - c.Assert(cluster.PutStore(s.GetMeta()), IsNil) + re.NoError(cluster.PutStore(s.GetMeta())) } cluster.ruleManager.SetRule( @@ -1431,129 +1517,115 @@ func (s *testClusterInfoSuite) TestCalculateStoreSize2(c *C) { regions := newTestRegions(100, 10, 5) for _, region := range regions { - c.Assert(cluster.putRegion(region), IsNil) + re.NoError(cluster.putRegion(region)) } stores := cluster.GetStores() store := cluster.GetStore(1) // 100 * 100 * 4 (total region size) / 2 (dc) / 2 (logic) / 3 (host) * 0.9 = 3000 - c.Assert(cluster.getThreshold(stores, store), Equals, 3000.0) + re.Equal(3000.0, cluster.getThreshold(stores, store)) } -var _ = Suite(&testStoresInfoSuite{}) - -type testStoresInfoSuite struct{} - -func (s *testStoresInfoSuite) TestStores(c *C) { +func TestStores(t *testing.T) { + re := require.New(t) n := uint64(10) cache := core.NewStoresInfo() stores := newTestStores(n, "2.0.0") for i, store := range stores { id := store.GetID() - c.Assert(cache.GetStore(id), IsNil) - c.Assert(cache.PauseLeaderTransfer(id), NotNil) + re.Nil(cache.GetStore(id)) + re.Error(cache.PauseLeaderTransfer(id)) cache.SetStore(store) - c.Assert(cache.GetStore(id), DeepEquals, store) - c.Assert(cache.GetStoreCount(), Equals, i+1) - c.Assert(cache.PauseLeaderTransfer(id), IsNil) - c.Assert(cache.GetStore(id).AllowLeaderTransfer(), IsFalse) - c.Assert(cache.PauseLeaderTransfer(id), NotNil) + re.Equal(store, cache.GetStore(id)) + re.Equal(i+1, cache.GetStoreCount()) + re.NoError(cache.PauseLeaderTransfer(id)) + re.False(cache.GetStore(id).AllowLeaderTransfer()) + re.Error(cache.PauseLeaderTransfer(id)) cache.ResumeLeaderTransfer(id) - c.Assert(cache.GetStore(id).AllowLeaderTransfer(), IsTrue) + re.True(cache.GetStore(id).AllowLeaderTransfer()) } - c.Assert(cache.GetStoreCount(), Equals, int(n)) + re.Equal(int(n), cache.GetStoreCount()) for _, store := range cache.GetStores() { - c.Assert(store, DeepEquals, stores[store.GetID()-1]) + re.Equal(stores[store.GetID()-1], store) } for _, store := range cache.GetMetaStores() { - c.Assert(store, DeepEquals, stores[store.GetId()-1].GetMeta()) + re.Equal(stores[store.GetId()-1].GetMeta(), store) } - c.Assert(cache.GetStoreCount(), Equals, int(n)) + re.Equal(int(n), cache.GetStoreCount()) } -var _ = Suite(&testRegionsInfoSuite{}) - -type testRegionsInfoSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testRegionsInfoSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testRegionsInfoSuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} +func Test(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func (s *testRegionsInfoSuite) Test(c *C) { n, np := uint64(10), uint64(3) regions := newTestRegions(n, n, np) _, opts, err := newTestScheduleConfig() - c.Assert(err, IsNil) - tc := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opts, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + tc := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opts, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) cache := tc.core.Regions for i := uint64(0); i < n; i++ { region := regions[i] regionKey := []byte{byte(i)} - c.Assert(cache.GetRegion(i), IsNil) - c.Assert(cache.GetRegionByKey(regionKey), IsNil) - checkRegions(c, cache, regions[0:i]) + re.Nil(cache.GetRegion(i)) + re.Nil(cache.GetRegionByKey(regionKey)) + checkRegions(re, cache, regions[0:i]) cache.SetRegion(region) - 
checkRegion(c, cache.GetRegion(i), region) - checkRegion(c, cache.GetRegionByKey(regionKey), region) - checkRegions(c, cache, regions[0:(i+1)]) + checkRegion(re, cache.GetRegion(i), region) + checkRegion(re, cache.GetRegionByKey(regionKey), region) + checkRegions(re, cache, regions[0:(i+1)]) // previous region if i == 0 { - c.Assert(cache.GetPrevRegionByKey(regionKey), IsNil) + re.Nil(cache.GetPrevRegionByKey(regionKey)) } else { - checkRegion(c, cache.GetPrevRegionByKey(regionKey), regions[i-1]) + checkRegion(re, cache.GetPrevRegionByKey(regionKey), regions[i-1]) } // Update leader to peer np-1. newRegion := region.Clone(core.WithLeader(region.GetPeers()[np-1])) regions[i] = newRegion cache.SetRegion(newRegion) - checkRegion(c, cache.GetRegion(i), newRegion) - checkRegion(c, cache.GetRegionByKey(regionKey), newRegion) - checkRegions(c, cache, regions[0:(i+1)]) + checkRegion(re, cache.GetRegion(i), newRegion) + checkRegion(re, cache.GetRegionByKey(regionKey), newRegion) + checkRegions(re, cache, regions[0:(i+1)]) cache.RemoveRegion(region) - c.Assert(cache.GetRegion(i), IsNil) - c.Assert(cache.GetRegionByKey(regionKey), IsNil) - checkRegions(c, cache, regions[0:i]) + re.Nil(cache.GetRegion(i)) + re.Nil(cache.GetRegionByKey(regionKey)) + checkRegions(re, cache, regions[0:i]) // Reset leader to peer 0. newRegion = region.Clone(core.WithLeader(region.GetPeers()[0])) regions[i] = newRegion cache.SetRegion(newRegion) - checkRegion(c, cache.GetRegion(i), newRegion) - checkRegions(c, cache, regions[0:(i+1)]) - checkRegion(c, cache.GetRegionByKey(regionKey), newRegion) + checkRegion(re, cache.GetRegion(i), newRegion) + checkRegions(re, cache, regions[0:(i+1)]) + checkRegion(re, cache.GetRegionByKey(regionKey), newRegion) } for i := uint64(0); i < n; i++ { region := tc.RandLeaderRegion(i, []core.KeyRange{core.NewKeyRange("", "")}, schedule.IsRegionHealthy) - c.Assert(region.GetLeader().GetStoreId(), Equals, i) + re.Equal(i, region.GetLeader().GetStoreId()) region = tc.RandFollowerRegion(i, []core.KeyRange{core.NewKeyRange("", "")}, schedule.IsRegionHealthy) - c.Assert(region.GetLeader().GetStoreId(), Not(Equals), i) + re.NotEqual(i, region.GetLeader().GetStoreId()) - c.Assert(region.GetStorePeer(i), NotNil) + re.NotNil(region.GetStorePeer(i)) } // check overlaps // clone it otherwise there are two items with the same key in the tree overlapRegion := regions[n-1].Clone(core.WithStartKey(regions[n-2].GetStartKey())) cache.SetRegion(overlapRegion) - c.Assert(cache.GetRegion(n-2), IsNil) - c.Assert(cache.GetRegion(n-1), NotNil) + re.Nil(cache.GetRegion(n - 2)) + re.NotNil(cache.GetRegion(n - 1)) // All regions will be filtered out if they have pending peers. 
for i := uint64(0); i < n; i++ { @@ -1562,71 +1634,36 @@ func (s *testRegionsInfoSuite) Test(c *C) { newRegion := region.Clone(core.WithPendingPeers(region.GetPeers())) cache.SetRegion(newRegion) } - c.Assert(tc.RandLeaderRegion(i, []core.KeyRange{core.NewKeyRange("", "")}, schedule.IsRegionHealthy), IsNil) + re.Nil(tc.RandLeaderRegion(i, []core.KeyRange{core.NewKeyRange("", "")}, schedule.IsRegionHealthy)) } for i := uint64(0); i < n; i++ { - c.Assert(tc.RandFollowerRegion(i, []core.KeyRange{core.NewKeyRange("", "")}, schedule.IsRegionHealthy), IsNil) + re.Nil(tc.RandFollowerRegion(i, []core.KeyRange{core.NewKeyRange("", "")}, schedule.IsRegionHealthy)) } } -var _ = Suite(&testClusterUtilSuite{}) - -type testClusterUtilSuite struct{} +func TestCheckStaleRegion(t *testing.T) { + re := require.New(t) -func (s *testClusterUtilSuite) TestCheckStaleRegion(c *C) { // (0, 0) v.s. (0, 0) region := core.NewTestRegionInfo([]byte{}, []byte{}) origin := core.NewTestRegionInfo([]byte{}, []byte{}) - c.Assert(checkStaleRegion(region.GetMeta(), origin.GetMeta()), IsNil) - c.Assert(checkStaleRegion(origin.GetMeta(), region.GetMeta()), IsNil) + re.NoError(checkStaleRegion(region.GetMeta(), origin.GetMeta())) + re.NoError(checkStaleRegion(origin.GetMeta(), region.GetMeta())) // (1, 0) v.s. (0, 0) region.GetRegionEpoch().Version++ - c.Assert(checkStaleRegion(origin.GetMeta(), region.GetMeta()), IsNil) - c.Assert(checkStaleRegion(region.GetMeta(), origin.GetMeta()), NotNil) + re.NoError(checkStaleRegion(origin.GetMeta(), region.GetMeta())) + re.Error(checkStaleRegion(region.GetMeta(), origin.GetMeta())) // (1, 1) v.s. (0, 0) region.GetRegionEpoch().ConfVer++ - c.Assert(checkStaleRegion(origin.GetMeta(), region.GetMeta()), IsNil) - c.Assert(checkStaleRegion(region.GetMeta(), origin.GetMeta()), NotNil) + re.NoError(checkStaleRegion(origin.GetMeta(), region.GetMeta())) + re.Error(checkStaleRegion(region.GetMeta(), origin.GetMeta())) // (0, 1) v.s. 
(0, 0) region.GetRegionEpoch().Version-- - c.Assert(checkStaleRegion(origin.GetMeta(), region.GetMeta()), IsNil) - c.Assert(checkStaleRegion(region.GetMeta(), origin.GetMeta()), NotNil) -} - -var _ = Suite(&testGetStoresSuite{}) - -type testGetStoresSuite struct { - ctx context.Context - cancel context.CancelFunc - cluster *RaftCluster -} - -func (s *testGetStoresSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testGetStoresSuite) SetUpSuite(c *C) { - _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - s.ctx, s.cancel = context.WithCancel(context.Background()) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - s.cluster = cluster - - stores := newTestStores(200, "2.0.0") - - for _, store := range stores { - c.Assert(s.cluster.putStoreLocked(store), IsNil) - } -} - -func (s *testGetStoresSuite) BenchmarkGetStores(c *C) { - for i := 0; i < c.N; i++ { - // Logic to benchmark - s.cluster.core.Stores.GetStores() - } + re.NoError(checkStaleRegion(origin.GetMeta(), region.GetMeta())) + re.Error(checkStaleRegion(region.GetMeta(), origin.GetMeta())) } type testCluster struct { @@ -1723,32 +1760,32 @@ func newTestRegionMeta(regionID uint64) *metapb.Region { } } -func checkRegion(c *C, a *core.RegionInfo, b *core.RegionInfo) { - c.Assert(a, DeepEquals, b) - c.Assert(a.GetMeta(), DeepEquals, b.GetMeta()) - c.Assert(a.GetLeader(), DeepEquals, b.GetLeader()) - c.Assert(a.GetPeers(), DeepEquals, b.GetPeers()) +func checkRegion(re *require.Assertions, a *core.RegionInfo, b *core.RegionInfo) { + re.Equal(b, a) + re.Equal(b.GetMeta(), a.GetMeta()) + re.Equal(b.GetLeader(), a.GetLeader()) + re.Equal(b.GetPeers(), a.GetPeers()) if len(a.GetDownPeers()) > 0 || len(b.GetDownPeers()) > 0 { - c.Assert(a.GetDownPeers(), DeepEquals, b.GetDownPeers()) + re.Equal(b.GetDownPeers(), a.GetDownPeers()) } if len(a.GetPendingPeers()) > 0 || len(b.GetPendingPeers()) > 0 { - c.Assert(a.GetPendingPeers(), DeepEquals, b.GetPendingPeers()) + re.Equal(b.GetPendingPeers(), a.GetPendingPeers()) } } -func checkRegionsKV(c *C, s storage.Storage, regions []*core.RegionInfo) { +func checkRegionsKV(re *require.Assertions, s storage.Storage, regions []*core.RegionInfo) { if s != nil { for _, region := range regions { var meta metapb.Region ok, err := s.LoadRegion(region.GetID(), &meta) - c.Assert(ok, IsTrue) - c.Assert(err, IsNil) - c.Assert(&meta, DeepEquals, region.GetMeta()) + re.True(ok) + re.NoError(err) + re.Equal(region.GetMeta(), &meta) } } } -func checkRegions(c *C, cache *core.RegionsInfo, regions []*core.RegionInfo) { +func checkRegions(re *require.Assertions, cache *core.RegionsInfo, regions []*core.RegionInfo) { regionCount := make(map[uint64]int) leaderCount := make(map[uint64]int) followerCount := make(map[uint64]int) @@ -1757,37 +1794,37 @@ func checkRegions(c *C, cache *core.RegionsInfo, regions []*core.RegionInfo) { regionCount[peer.StoreId]++ if peer.Id == region.GetLeader().Id { leaderCount[peer.StoreId]++ - checkRegion(c, cache.GetLeader(peer.StoreId, region), region) + checkRegion(re, cache.GetLeader(peer.StoreId, region), region) } else { followerCount[peer.StoreId]++ - checkRegion(c, cache.GetFollower(peer.StoreId, region), region) + checkRegion(re, cache.GetFollower(peer.StoreId, region), region) } } } - c.Assert(cache.GetRegionCount(), Equals, len(regions)) + re.Equal(len(regions), cache.GetRegionCount()) for id, count := range regionCount { - c.Assert(cache.GetStoreRegionCount(id), Equals, count) + 
re.Equal(count, cache.GetStoreRegionCount(id)) } for id, count := range leaderCount { - c.Assert(cache.GetStoreLeaderCount(id), Equals, count) + re.Equal(count, cache.GetStoreLeaderCount(id)) } for id, count := range followerCount { - c.Assert(cache.GetStoreFollowerCount(id), Equals, count) + re.Equal(count, cache.GetStoreFollowerCount(id)) } for _, region := range cache.GetRegions() { - checkRegion(c, region, regions[region.GetID()]) + checkRegion(re, region, regions[region.GetID()]) } for _, region := range cache.GetMetaRegions() { - c.Assert(region, DeepEquals, regions[region.GetId()].GetMeta()) + re.Equal(regions[region.GetId()].GetMeta(), region) } } -func checkPendingPeerCount(expect []int, cluster *RaftCluster, c *C) { +func checkPendingPeerCount(expect []int, cluster *RaftCluster, re *require.Assertions) { for i, e := range expect { s := cluster.core.Stores.GetStore(uint64(i + 1)) - c.Assert(s.GetPendingPeerCount(), Equals, e) + re.Equal(e, s.GetPendingPeerCount()) } } diff --git a/server/cluster/cluster_worker_test.go b/server/cluster/cluster_worker_test.go index b545ccd2d2c..8c747c6c847 100644 --- a/server/cluster/cluster_worker_test.go +++ b/server/cluster/cluster_worker_test.go @@ -16,47 +16,41 @@ package cluster import ( "context" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/server/core" _ "github.com/tikv/pd/server/schedulers" "github.com/tikv/pd/server/storage" ) -var _ = Suite(&testClusterWorkerSuite{}) +func TestReportSplit(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -type testClusterWorkerSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testClusterWorkerSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testClusterWorkerSuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testClusterWorkerSuite) TestReportSplit(c *C) { _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) left := &metapb.Region{Id: 1, StartKey: []byte("a"), EndKey: []byte("b")} right := &metapb.Region{Id: 2, StartKey: []byte("b"), EndKey: []byte("c")} _, err = cluster.HandleReportSplit(&pdpb.ReportSplitRequest{Left: left, Right: right}) - c.Assert(err, IsNil) + re.NoError(err) _, err = cluster.HandleReportSplit(&pdpb.ReportSplitRequest{Left: right, Right: left}) - c.Assert(err, NotNil) + re.Error(err) } -func (s *testClusterWorkerSuite) TestReportBatchSplit(c *C) { +func TestReportBatchSplit(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + re.NoError(err) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) regions := []*metapb.Region{ {Id: 1, StartKey: []byte(""), EndKey: []byte("a")}, {Id: 2, StartKey: []byte("a"), EndKey: []byte("b")}, @@ -64,5 +58,5 @@ func (s 
*testClusterWorkerSuite) TestReportBatchSplit(c *C) { {Id: 3, StartKey: []byte("c"), EndKey: []byte("")}, } _, err = cluster.HandleBatchReportSplit(&pdpb.ReportBatchSplitRequest{Regions: regions}) - c.Assert(err, IsNil) + re.NoError(err) } diff --git a/server/cluster/coordinator_test.go b/server/cluster/coordinator_test.go index b234374a765..a7d34ccb558 100644 --- a/server/cluster/coordinator_test.go +++ b/server/cluster/coordinator_test.go @@ -17,16 +17,16 @@ package cluster import ( "context" "encoding/json" + "github.com/pingcap/failpoint" "math/rand" "sync" "testing" "time" - . "github.com/pingcap/check" - "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/eraftpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockhbstream" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/pkg/typeutil" @@ -149,95 +149,83 @@ func (c *testCluster) LoadRegion(regionID uint64, followerStoreIDs ...uint64) er return c.putRegion(core.NewRegionInfo(region, nil)) } -var _ = Suite(&testCoordinatorSuite{}) - -type testCoordinatorSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testCoordinatorSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)"), IsNil) -} - -func (s *testCoordinatorSuite) TearDownSuite(c *C) { - s.cancel() -} +func TestBasic(t *testing.T) { + re := require.New(t) -func (s *testCoordinatorSuite) TestBasic(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) + tc, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() oc := co.opController - c.Assert(tc.addLeaderRegion(1, 1), IsNil) + re.NoError(tc.addLeaderRegion(1, 1)) op1 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpLeader) oc.AddWaitingOperator(op1) - c.Assert(oc.OperatorCount(operator.OpLeader), Equals, uint64(1)) - c.Assert(oc.GetOperator(1).RegionID(), Equals, op1.RegionID()) + re.Equal(uint64(1), oc.OperatorCount(operator.OpLeader)) + re.Equal(op1.RegionID(), oc.GetOperator(1).RegionID()) // Region 1 already has an operator, cannot add another one. op2 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpRegion) oc.AddWaitingOperator(op2) - c.Assert(oc.OperatorCount(operator.OpRegion), Equals, uint64(0)) + re.Equal(uint64(0), oc.OperatorCount(operator.OpRegion)) // Remove the operator manually, then we can add a new operator. - c.Assert(oc.RemoveOperator(op1), IsTrue) + re.True(oc.RemoveOperator(op1)) op3 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpRegion) oc.AddWaitingOperator(op3) - c.Assert(oc.OperatorCount(operator.OpRegion), Equals, uint64(1)) - c.Assert(oc.GetOperator(1).RegionID(), Equals, op3.RegionID()) + re.Equal(uint64(1), oc.OperatorCount(operator.OpRegion)) + re.Equal(op3.RegionID(), oc.GetOperator(1).RegionID()) } -func (s *testCoordinatorSuite) TestDispatch(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestDispatch(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() co.prepareChecker.prepared = true // Transfer peer from store 4 to store 1. 
- c.Assert(tc.addRegionStore(4, 40), IsNil) - c.Assert(tc.addRegionStore(3, 30), IsNil) - c.Assert(tc.addRegionStore(2, 20), IsNil) - c.Assert(tc.addRegionStore(1, 10), IsNil) - c.Assert(tc.addLeaderRegion(1, 2, 3, 4), IsNil) + re.NoError(tc.addRegionStore(4, 40)) + re.NoError(tc.addRegionStore(3, 30)) + re.NoError(tc.addRegionStore(2, 20)) + re.NoError(tc.addRegionStore(1, 10)) + re.NoError(tc.addLeaderRegion(1, 2, 3, 4)) // Transfer leader from store 4 to store 2. - c.Assert(tc.updateLeaderCount(4, 50), IsNil) - c.Assert(tc.updateLeaderCount(3, 50), IsNil) - c.Assert(tc.updateLeaderCount(2, 20), IsNil) - c.Assert(tc.updateLeaderCount(1, 10), IsNil) - c.Assert(tc.addLeaderRegion(2, 4, 3, 2), IsNil) + re.NoError(tc.updateLeaderCount(4, 50)) + re.NoError(tc.updateLeaderCount(3, 50)) + re.NoError(tc.updateLeaderCount(2, 20)) + re.NoError(tc.updateLeaderCount(1, 10)) + re.NoError(tc.addLeaderRegion(2, 4, 3, 2)) go co.runUntilStop() // Wait for schedule and turn off balance. - waitOperator(c, co, 1) - testutil.CheckTransferPeer(c, co.opController.GetOperator(1), operator.OpKind(0), 4, 1) - c.Assert(co.removeScheduler(schedulers.BalanceRegionName), IsNil) - waitOperator(c, co, 2) - testutil.CheckTransferLeader(c, co.opController.GetOperator(2), operator.OpKind(0), 4, 2) - c.Assert(co.removeScheduler(schedulers.BalanceLeaderName), IsNil) + waitOperator(re, co, 1) + testutil.CheckTransferPeerWithTestify(re, co.opController.GetOperator(1), operator.OpKind(0), 4, 1) + re.NoError(co.removeScheduler(schedulers.BalanceRegionName)) + waitOperator(re, co, 2) + testutil.CheckTransferLeaderWithTestify(re, co.opController.GetOperator(2), operator.OpKind(0), 4, 2) + re.NoError(co.removeScheduler(schedulers.BalanceLeaderName)) stream := mockhbstream.NewHeartbeatStream() // Transfer peer. region := tc.GetRegion(1).Clone() - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitAddLearner(c, stream, region, 1) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitPromoteLearner(c, stream, region, 1) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitRemovePeer(c, stream, region, 4) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitAddLearner(re, stream, region, 1) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitPromoteLearner(re, stream, region, 1) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitRemovePeer(re, stream, region, 4) + re.NoError(dispatchHeartbeat(co, region, stream)) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) // Transfer leader. 
region = tc.GetRegion(2).Clone() - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitTransferLeader(c, stream, region, 2) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitTransferLeader(re, stream, region, 2) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) } func dispatchHeartbeat(co *coordinator, region *core.RegionInfo, stream hbstream.HeartbeatStream) error { @@ -249,10 +237,12 @@ func dispatchHeartbeat(co *coordinator, region *core.RegionInfo, stream hbstream return nil } -func (s *testCoordinatorSuite) TestCollectMetrics(c *C) { +func TestCollectMetrics(t *testing.T) { + re := require.New(t) + tc, co, cleanup := prepare(nil, func(tc *testCluster) { tc.regionStats = statistics.NewRegionStatistics(tc.GetOpts(), nil, tc.storeConfigManager) - }, func(co *coordinator) { co.run() }, c) + }, func(co *coordinator) { co.run() }, re) defer cleanup() // Make sure there are no problem when concurrent write and read @@ -263,7 +253,7 @@ func (s *testCoordinatorSuite) TestCollectMetrics(c *C) { go func(i int) { defer wg.Done() for j := 0; j < 1000; j++ { - c.Assert(tc.addRegionStore(uint64(i%5), rand.Intn(200)), IsNil) + re.NoError(tc.addRegionStore(uint64(i%5), rand.Intn(200))) } }(i) } @@ -278,10 +268,10 @@ func (s *testCoordinatorSuite) TestCollectMetrics(c *C) { wg.Wait() } -func prepare(setCfg func(*config.ScheduleConfig), setTc func(*testCluster), run func(*coordinator), c *C) (*testCluster, *coordinator, func()) { +func prepare(setCfg func(*config.ScheduleConfig), setTc func(*testCluster), run func(*coordinator), re *require.Assertions) (*testCluster, *coordinator, func()) { ctx, cancel := context.WithCancel(context.Background()) cfg, opt, err := newTestScheduleConfig() - c.Assert(err, IsNil) + re.NoError(err) if setCfg != nil { setCfg(cfg) } @@ -302,28 +292,32 @@ func prepare(setCfg func(*config.ScheduleConfig), setTc func(*testCluster), run } } -func (s *testCoordinatorSuite) checkRegion(c *C, tc *testCluster, co *coordinator, regionID uint64, expectAddOperator int) { +func checkRegionAndOperator(re *require.Assertions, tc *testCluster, co *coordinator, regionID uint64, expectAddOperator int) { ops := co.checkers.CheckRegion(tc.GetRegion(regionID)) if ops == nil { - c.Assert(expectAddOperator, Equals, 0) + re.Equal(0, expectAddOperator) } else { - c.Assert(co.opController.AddWaitingOperator(ops...), Equals, expectAddOperator) + re.Equal(expectAddOperator, co.opController.AddWaitingOperator(ops...)) } } -func (s *testCoordinatorSuite) TestCheckRegion(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestCheckRegion(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tc, co, cleanup := prepare(nil, nil, nil, re) hbStreams, opt := co.hbStreams, tc.opt defer cleanup() - c.Assert(tc.addRegionStore(4, 4), IsNil) - c.Assert(tc.addRegionStore(3, 3), IsNil) - c.Assert(tc.addRegionStore(2, 2), IsNil) - c.Assert(tc.addRegionStore(1, 1), IsNil) - c.Assert(tc.addLeaderRegion(1, 2, 3), IsNil) - s.checkRegion(c, tc, co, 1, 1) - testutil.CheckAddPeer(c, co.opController.GetOperator(1), operator.OpReplica, 1) - s.checkRegion(c, tc, co, 1, 0) + re.NoError(tc.addRegionStore(4, 4)) + re.NoError(tc.addRegionStore(3, 3)) + re.NoError(tc.addRegionStore(2, 2)) + re.NoError(tc.addRegionStore(1, 1)) + re.NoError(tc.addLeaderRegion(1, 2, 3)) + checkRegionAndOperator(re, tc, co, 1, 1) + 
testutil.CheckAddPeerWithTestify(re, co.opController.GetOperator(1), operator.OpReplica, 1) + checkRegionAndOperator(re, tc, co, 1, 0) r := tc.GetRegion(1) p := &metapb.Peer{Id: 1, StoreId: 1, Role: metapb.PeerRole_Learner} @@ -331,39 +325,41 @@ func (s *testCoordinatorSuite) TestCheckRegion(c *C) { core.WithAddPeer(p), core.WithPendingPeers(append(r.GetPendingPeers(), p)), ) - c.Assert(tc.putRegion(r), IsNil) - s.checkRegion(c, tc, co, 1, 0) - - tc = newTestCluster(s.ctx, opt) - co = newCoordinator(s.ctx, tc.RaftCluster, hbStreams) - - c.Assert(tc.addRegionStore(4, 4), IsNil) - c.Assert(tc.addRegionStore(3, 3), IsNil) - c.Assert(tc.addRegionStore(2, 2), IsNil) - c.Assert(tc.addRegionStore(1, 1), IsNil) - c.Assert(tc.putRegion(r), IsNil) - s.checkRegion(c, tc, co, 1, 0) + re.NoError(tc.putRegion(r)) + checkRegionAndOperator(re, tc, co, 1, 0) + + tc = newTestCluster(ctx, opt) + co = newCoordinator(ctx, tc.RaftCluster, hbStreams) + + re.NoError(tc.addRegionStore(4, 4)) + re.NoError(tc.addRegionStore(3, 3)) + re.NoError(tc.addRegionStore(2, 2)) + re.NoError(tc.addRegionStore(1, 1)) + re.NoError(tc.putRegion(r)) + checkRegionAndOperator(re, tc, co, 1, 0) r = r.Clone(core.WithPendingPeers(nil)) - c.Assert(tc.putRegion(r), IsNil) - s.checkRegion(c, tc, co, 1, 1) + re.NoError(tc.putRegion(r)) + checkRegionAndOperator(re, tc, co, 1, 1) op := co.opController.GetOperator(1) - c.Assert(op.Len(), Equals, 1) - c.Assert(op.Step(0).(operator.PromoteLearner).ToStore, Equals, uint64(1)) - s.checkRegion(c, tc, co, 1, 0) + re.Equal(1, op.Len()) + re.Equal(uint64(1), op.Step(0).(operator.PromoteLearner).ToStore) + checkRegionAndOperator(re, tc, co, 1, 0) } -func (s *testCoordinatorSuite) TestCheckRegionWithScheduleDeny(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestCheckRegionWithScheduleDeny(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() - c.Assert(tc.addRegionStore(4, 4), IsNil) - c.Assert(tc.addRegionStore(3, 3), IsNil) - c.Assert(tc.addRegionStore(2, 2), IsNil) - c.Assert(tc.addRegionStore(1, 1), IsNil) - c.Assert(tc.addLeaderRegion(1, 2, 3), IsNil) + re.NoError(tc.addRegionStore(4, 4)) + re.NoError(tc.addRegionStore(3, 3)) + re.NoError(tc.addRegionStore(2, 2)) + re.NoError(tc.addRegionStore(1, 1)) + re.NoError(tc.addLeaderRegion(1, 2, 3)) region := tc.GetRegion(1) - c.Assert(region, NotNil) + re.NotNil(region) // test with label schedule=deny labelerManager := tc.GetRegionLabeler() labelerManager.SetLabelRule(&labeler.LabelRule{ @@ -373,21 +369,23 @@ func (s *testCoordinatorSuite) TestCheckRegionWithScheduleDeny(c *C) { Data: []interface{}{map[string]interface{}{"start_key": "", "end_key": ""}}, }) - c.Assert(labelerManager.ScheduleDisabled(region), IsTrue) - s.checkRegion(c, tc, co, 1, 0) + re.True(labelerManager.ScheduleDisabled(region)) + checkRegionAndOperator(re, tc, co, 1, 0) labelerManager.DeleteLabelRule("schedulelabel") - c.Assert(labelerManager.ScheduleDisabled(region), IsFalse) - s.checkRegion(c, tc, co, 1, 1) + re.False(labelerManager.ScheduleDisabled(region)) + checkRegionAndOperator(re, tc, co, 1, 1) } -func (s *testCoordinatorSuite) TestCheckerIsBusy(c *C) { +func TestCheckerIsBusy(t *testing.T) { + re := require.New(t) + tc, co, cleanup := prepare(func(cfg *config.ScheduleConfig) { cfg.ReplicaScheduleLimit = 0 // ensure replica checker is busy cfg.MergeScheduleLimit = 10 - }, nil, nil, c) + }, nil, nil, re) defer cleanup() - c.Assert(tc.addRegionStore(1, 0), IsNil) + re.NoError(tc.addRegionStore(1, 0)) num := 1 + 
typeutil.MaxUint64(tc.opt.GetReplicaScheduleLimit(), tc.opt.GetMergeScheduleLimit()) var operatorKinds = []operator.OpKind{ operator.OpReplica, operator.OpRegion | operator.OpMerge, @@ -395,50 +393,52 @@ func (s *testCoordinatorSuite) TestCheckerIsBusy(c *C) { for i, operatorKind := range operatorKinds { for j := uint64(0); j < num; j++ { regionID := j + uint64(i+1)*num - c.Assert(tc.addLeaderRegion(regionID, 1), IsNil) + re.NoError(tc.addLeaderRegion(regionID, 1)) switch operatorKind { case operator.OpReplica: op := newTestOperator(regionID, tc.GetRegion(regionID).GetRegionEpoch(), operatorKind) - c.Assert(co.opController.AddWaitingOperator(op), Equals, 1) + re.Equal(1, co.opController.AddWaitingOperator(op)) case operator.OpRegion | operator.OpMerge: if regionID%2 == 1 { ops, err := operator.CreateMergeRegionOperator("merge-region", co.cluster, tc.GetRegion(regionID), tc.GetRegion(regionID-1), operator.OpMerge) - c.Assert(err, IsNil) - c.Assert(co.opController.AddWaitingOperator(ops...), Equals, len(ops)) + re.NoError(err) + re.Len(ops, co.opController.AddWaitingOperator(ops...)) } } } } - s.checkRegion(c, tc, co, num, 0) + checkRegionAndOperator(re, tc, co, num, 0) } -func (s *testCoordinatorSuite) TestReplica(c *C) { +func TestReplica(t *testing.T) { + re := require.New(t) + tc, co, cleanup := prepare(func(cfg *config.ScheduleConfig) { // Turn off balance. cfg.LeaderScheduleLimit = 0 cfg.RegionScheduleLimit = 0 - }, nil, func(co *coordinator) { co.run() }, c) + }, nil, func(co *coordinator) { co.run() }, re) defer cleanup() - c.Assert(tc.addRegionStore(1, 1), IsNil) - c.Assert(tc.addRegionStore(2, 2), IsNil) - c.Assert(tc.addRegionStore(3, 3), IsNil) - c.Assert(tc.addRegionStore(4, 4), IsNil) + re.NoError(tc.addRegionStore(1, 1)) + re.NoError(tc.addRegionStore(2, 2)) + re.NoError(tc.addRegionStore(3, 3)) + re.NoError(tc.addRegionStore(4, 4)) stream := mockhbstream.NewHeartbeatStream() // Add peer to store 1. - c.Assert(tc.addLeaderRegion(1, 2, 3), IsNil) + re.NoError(tc.addLeaderRegion(1, 2, 3)) region := tc.GetRegion(1) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitAddLearner(c, stream, region, 1) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitPromoteLearner(c, stream, region, 1) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitAddLearner(re, stream, region, 1) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitPromoteLearner(re, stream, region, 1) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) // Peer in store 3 is down, remove peer in store 3 and add peer to store 4. 
- c.Assert(tc.setStoreDown(3), IsNil) + re.NoError(tc.setStoreDown(3)) downPeer := &pdpb.PeerStats{ Peer: region.GetStorePeer(3), DownSeconds: 24 * 60 * 60, @@ -446,50 +446,52 @@ func (s *testCoordinatorSuite) TestReplica(c *C) { region = region.Clone( core.WithDownPeers(append(region.GetDownPeers(), downPeer)), ) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitAddLearner(c, stream, region, 4) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitPromoteLearner(c, stream, region, 4) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitAddLearner(re, stream, region, 4) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitPromoteLearner(re, stream, region, 4) region = region.Clone(core.WithDownPeers(nil)) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) // Remove peer from store 4. - c.Assert(tc.addLeaderRegion(2, 1, 2, 3, 4), IsNil) + re.NoError(tc.addLeaderRegion(2, 1, 2, 3, 4)) region = tc.GetRegion(2) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitRemovePeer(c, stream, region, 4) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitRemovePeer(re, stream, region, 4) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) // Remove offline peer directly when it's pending. - c.Assert(tc.addLeaderRegion(3, 1, 2, 3), IsNil) - c.Assert(tc.setStoreOffline(3), IsNil) + re.NoError(tc.addLeaderRegion(3, 1, 2, 3)) + re.NoError(tc.setStoreOffline(3)) region = tc.GetRegion(3) region = region.Clone(core.WithPendingPeers([]*metapb.Peer{region.GetStorePeer(3)})) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) } -func (s *testCoordinatorSuite) TestCheckCache(c *C) { +func TestCheckCache(t *testing.T) { + re := require.New(t) + tc, co, cleanup := prepare(func(cfg *config.ScheduleConfig) { // Turn off replica scheduling. cfg.ReplicaScheduleLimit = 0 - }, nil, nil, c) + }, nil, nil, re) defer cleanup() - c.Assert(tc.addRegionStore(1, 0), IsNil) - c.Assert(tc.addRegionStore(2, 0), IsNil) - c.Assert(tc.addRegionStore(3, 0), IsNil) + re.NoError(tc.addRegionStore(1, 0)) + re.NoError(tc.addRegionStore(2, 0)) + re.NoError(tc.addRegionStore(3, 0)) // Add a peer with two replicas. 
- c.Assert(tc.addLeaderRegion(1, 2, 3), IsNil) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/cluster/break-patrol", `return`), IsNil) + re.NoError(tc.addLeaderRegion(1, 2, 3)) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/break-patrol", `return`)) // case 1: operator cannot be created due to replica-schedule-limit restriction co.wg.Add(1) co.patrolRegions() - c.Assert(co.checkers.GetWaitingRegions(), HasLen, 1) + re.Len(co.checkers.GetWaitingRegions(), 1) // cancel the replica-schedule-limit restriction opt := tc.GetOpts() @@ -499,88 +501,92 @@ func (s *testCoordinatorSuite) TestCheckCache(c *C) { co.wg.Add(1) co.patrolRegions() oc := co.opController - c.Assert(oc.GetOperators(), HasLen, 1) - c.Assert(co.checkers.GetWaitingRegions(), HasLen, 0) + re.Len(oc.GetOperators(), 1) + re.Len(co.checkers.GetWaitingRegions(), 0) // case 2: operator cannot be created due to store limit restriction oc.RemoveOperator(oc.GetOperator(1)) tc.SetStoreLimit(1, storelimit.AddPeer, 0) co.wg.Add(1) co.patrolRegions() - c.Assert(co.checkers.GetWaitingRegions(), HasLen, 1) + re.Len(co.checkers.GetWaitingRegions(), 1) // cancel the store limit restriction tc.SetStoreLimit(1, storelimit.AddPeer, 10) time.Sleep(1 * time.Second) co.wg.Add(1) co.patrolRegions() - c.Assert(oc.GetOperators(), HasLen, 1) - c.Assert(co.checkers.GetWaitingRegions(), HasLen, 0) + re.Len(oc.GetOperators(), 1) + re.Len(co.checkers.GetWaitingRegions(), 0) co.wg.Wait() - c.Assert(failpoint.Disable("github.com/tikv/pd/server/cluster/break-patrol"), IsNil) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/break-patrol")) } -func (s *testCoordinatorSuite) TestPeerState(c *C) { - tc, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, c) +func TestPeerState(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, re) defer cleanup() // Transfer peer from store 4 to store 1. - c.Assert(tc.addRegionStore(1, 10), IsNil) - c.Assert(tc.addRegionStore(2, 10), IsNil) - c.Assert(tc.addRegionStore(3, 10), IsNil) - c.Assert(tc.addRegionStore(4, 40), IsNil) - c.Assert(tc.addLeaderRegion(1, 2, 3, 4), IsNil) + re.NoError(tc.addRegionStore(1, 10)) + re.NoError(tc.addRegionStore(2, 10)) + re.NoError(tc.addRegionStore(3, 10)) + re.NoError(tc.addRegionStore(4, 40)) + re.NoError(tc.addLeaderRegion(1, 2, 3, 4)) stream := mockhbstream.NewHeartbeatStream() // Wait for schedule. - waitOperator(c, co, 1) - testutil.CheckTransferPeer(c, co.opController.GetOperator(1), operator.OpKind(0), 4, 1) + waitOperator(re, co, 1) + testutil.CheckTransferPeerWithTestify(re, co.opController.GetOperator(1), operator.OpKind(0), 4, 1) region := tc.GetRegion(1).Clone() // Add new peer. - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitAddLearner(c, stream, region, 1) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitPromoteLearner(c, stream, region, 1) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitAddLearner(re, stream, region, 1) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitPromoteLearner(re, stream, region, 1) // If the new peer is pending, the operator will not finish. 
region = region.Clone(core.WithPendingPeers(append(region.GetPendingPeers(), region.GetStorePeer(1)))) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) - c.Assert(co.opController.GetOperator(region.GetID()), NotNil) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) + re.NotNil(co.opController.GetOperator(region.GetID())) // The new peer is not pending now, the operator will finish. // And we will proceed to remove peer in store 4. region = region.Clone(core.WithPendingPeers(nil)) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitRemovePeer(c, stream, region, 4) - c.Assert(tc.addLeaderRegion(1, 1, 2, 3), IsNil) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitRemovePeer(re, stream, region, 4) + re.NoError(tc.addLeaderRegion(1, 1, 2, 3)) region = tc.GetRegion(1).Clone() - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitNoResponse(re, stream) } -func (s *testCoordinatorSuite) TestShouldRun(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestShouldRun(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, nil, re) tc.RaftCluster.coordinator = co defer cleanup() - c.Assert(tc.addLeaderStore(1, 5), IsNil) - c.Assert(tc.addLeaderStore(2, 2), IsNil) - c.Assert(tc.addLeaderStore(3, 0), IsNil) - c.Assert(tc.addLeaderStore(4, 0), IsNil) - c.Assert(tc.LoadRegion(1, 1, 2, 3), IsNil) - c.Assert(tc.LoadRegion(2, 1, 2, 3), IsNil) - c.Assert(tc.LoadRegion(3, 1, 2, 3), IsNil) - c.Assert(tc.LoadRegion(4, 1, 2, 3), IsNil) - c.Assert(tc.LoadRegion(5, 1, 2, 3), IsNil) - c.Assert(tc.LoadRegion(6, 2, 1, 4), IsNil) - c.Assert(tc.LoadRegion(7, 2, 1, 4), IsNil) - c.Assert(co.shouldRun(), IsFalse) - c.Assert(tc.core.Regions.GetStoreRegionCount(4), Equals, 2) + re.NoError(tc.addLeaderStore(1, 5)) + re.NoError(tc.addLeaderStore(2, 2)) + re.NoError(tc.addLeaderStore(3, 0)) + re.NoError(tc.addLeaderStore(4, 0)) + re.NoError(tc.LoadRegion(1, 1, 2, 3)) + re.NoError(tc.LoadRegion(2, 1, 2, 3)) + re.NoError(tc.LoadRegion(3, 1, 2, 3)) + re.NoError(tc.LoadRegion(4, 1, 2, 3)) + re.NoError(tc.LoadRegion(5, 1, 2, 3)) + re.NoError(tc.LoadRegion(6, 2, 1, 4)) + re.NoError(tc.LoadRegion(7, 2, 1, 4)) + re.False(co.shouldRun()) + re.Equal(2, tc.core.Regions.GetStoreRegionCount(4)) tbl := []struct { regionID uint64 @@ -599,28 +605,30 @@ func (s *testCoordinatorSuite) TestShouldRun(c *C) { for _, t := range tbl { r := tc.GetRegion(t.regionID) nr := r.Clone(core.WithLeader(r.GetPeers()[0])) - c.Assert(tc.processRegionHeartbeat(nr), IsNil) - c.Assert(co.shouldRun(), Equals, t.shouldRun) + re.NoError(tc.processRegionHeartbeat(nr)) + re.Equal(t.shouldRun, co.shouldRun()) } nr := &metapb.Region{Id: 6, Peers: []*metapb.Peer{}} newRegion := core.NewRegionInfo(nr, nil) - c.Assert(tc.processRegionHeartbeat(newRegion), NotNil) - c.Assert(co.prepareChecker.sum, Equals, 7) + re.Error(tc.processRegionHeartbeat(newRegion)) + re.Equal(7, co.prepareChecker.sum) } -func (s *testCoordinatorSuite) TestShouldRunWithNonLeaderRegions(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestShouldRunWithNonLeaderRegions(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, nil, re) tc.RaftCluster.coordinator = co defer cleanup() - c.Assert(tc.addLeaderStore(1, 10), IsNil) - c.Assert(tc.addLeaderStore(2, 0), IsNil) - c.Assert(tc.addLeaderStore(3, 0), IsNil) + re.NoError(tc.addLeaderStore(1, 10)) + 
re.NoError(tc.addLeaderStore(2, 0)) + re.NoError(tc.addLeaderStore(3, 0)) for i := 0; i < 10; i++ { - c.Assert(tc.LoadRegion(uint64(i+1), 1, 2, 3), IsNil) + re.NoError(tc.LoadRegion(uint64(i+1), 1, 2, 3)) } - c.Assert(co.shouldRun(), IsFalse) - c.Assert(tc.core.Regions.GetStoreRegionCount(1), Equals, 10) + re.False(co.shouldRun()) + re.Equal(10, tc.core.Regions.GetStoreRegionCount(1)) tbl := []struct { regionID uint64 @@ -640,289 +648,307 @@ func (s *testCoordinatorSuite) TestShouldRunWithNonLeaderRegions(c *C) { for _, t := range tbl { r := tc.GetRegion(t.regionID) nr := r.Clone(core.WithLeader(r.GetPeers()[0])) - c.Assert(tc.processRegionHeartbeat(nr), IsNil) - c.Assert(co.shouldRun(), Equals, t.shouldRun) + re.NoError(tc.processRegionHeartbeat(nr)) + re.Equal(t.shouldRun, co.shouldRun()) } nr := &metapb.Region{Id: 9, Peers: []*metapb.Peer{}} newRegion := core.NewRegionInfo(nr, nil) - c.Assert(tc.processRegionHeartbeat(newRegion), NotNil) - c.Assert(co.prepareChecker.sum, Equals, 9) + re.Error(tc.processRegionHeartbeat(newRegion)) + re.Equal(9, co.prepareChecker.sum) // Now, after server is prepared, there exist some regions with no leader. - c.Assert(tc.GetRegion(10).GetLeader().GetStoreId(), Equals, uint64(0)) + re.Equal(uint64(0), tc.GetRegion(10).GetLeader().GetStoreId()) } -func (s *testCoordinatorSuite) TestAddScheduler(c *C) { - tc, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, c) +func TestAddScheduler(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, re) defer cleanup() - c.Assert(co.schedulers, HasLen, len(config.DefaultSchedulers)) - c.Assert(co.removeScheduler(schedulers.BalanceLeaderName), IsNil) - c.Assert(co.removeScheduler(schedulers.BalanceRegionName), IsNil) - c.Assert(co.removeScheduler(schedulers.HotRegionName), IsNil) - c.Assert(co.removeScheduler(schedulers.SplitBucketName), IsNil) - c.Assert(co.schedulers, HasLen, 0) + re.Len(co.schedulers, len(config.DefaultSchedulers)) + re.NoError(co.removeScheduler(schedulers.BalanceLeaderName)) + re.NoError(co.removeScheduler(schedulers.BalanceRegionName)) + re.NoError(co.removeScheduler(schedulers.HotRegionName)) + re.NoError(co.removeScheduler(schedulers.SplitBucketName)) + re.Len(co.schedulers, 0) stream := mockhbstream.NewHeartbeatStream() // Add stores 1,2,3 - c.Assert(tc.addLeaderStore(1, 1), IsNil) - c.Assert(tc.addLeaderStore(2, 1), IsNil) - c.Assert(tc.addLeaderStore(3, 1), IsNil) + re.NoError(tc.addLeaderStore(1, 1)) + re.NoError(tc.addLeaderStore(2, 1)) + re.NoError(tc.addLeaderStore(3, 1)) // Add regions 1 with leader in store 1 and followers in stores 2,3 - c.Assert(tc.addLeaderRegion(1, 1, 2, 3), IsNil) + re.NoError(tc.addLeaderRegion(1, 1, 2, 3)) // Add regions 2 with leader in store 2 and followers in stores 1,3 - c.Assert(tc.addLeaderRegion(2, 2, 1, 3), IsNil) + re.NoError(tc.addLeaderRegion(2, 2, 1, 3)) // Add regions 3 with leader in store 3 and followers in stores 1,2 - c.Assert(tc.addLeaderRegion(3, 3, 1, 2), IsNil) + re.NoError(tc.addLeaderRegion(3, 3, 1, 2)) oc := co.opController // test ConfigJSONDecoder create bl, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, oc, storage.NewStorageWithMemoryBackend(), schedule.ConfigJSONDecoder([]byte("{}"))) - c.Assert(err, IsNil) + re.NoError(err) conf, err := bl.EncodeConfig() - c.Assert(err, IsNil) + re.NoError(err) data := make(map[string]interface{}) err = json.Unmarshal(conf, &data) - c.Assert(err, IsNil) + re.NoError(err) batch := data["batch"].(float64) - 
c.Assert(int(batch), Equals, 4) + re.Equal(4, int(batch)) gls, err := schedule.CreateScheduler(schedulers.GrantLeaderType, oc, storage.NewStorageWithMemoryBackend(), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"0"})) - c.Assert(err, IsNil) - c.Assert(co.addScheduler(gls), NotNil) - c.Assert(co.removeScheduler(gls.GetName()), NotNil) + re.NoError(err) + re.NotNil(co.addScheduler(gls)) + re.NotNil(co.removeScheduler(gls.GetName())) gls, err = schedule.CreateScheduler(schedulers.GrantLeaderType, oc, storage.NewStorageWithMemoryBackend(), schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"})) - c.Assert(err, IsNil) - c.Assert(co.addScheduler(gls), IsNil) + re.NoError(err) + re.NoError(co.addScheduler(gls)) // Transfer all leaders to store 1. - waitOperator(c, co, 2) + waitOperator(re, co, 2) region2 := tc.GetRegion(2) - c.Assert(dispatchHeartbeat(co, region2, stream), IsNil) - region2 = waitTransferLeader(c, stream, region2, 1) - c.Assert(dispatchHeartbeat(co, region2, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region2, stream)) + region2 = waitTransferLeader(re, stream, region2, 1) + re.NoError(dispatchHeartbeat(co, region2, stream)) + waitNoResponse(re, stream) - waitOperator(c, co, 3) + waitOperator(re, co, 3) region3 := tc.GetRegion(3) - c.Assert(dispatchHeartbeat(co, region3, stream), IsNil) - region3 = waitTransferLeader(c, stream, region3, 1) - c.Assert(dispatchHeartbeat(co, region3, stream), IsNil) - waitNoResponse(c, stream) + re.NoError(dispatchHeartbeat(co, region3, stream)) + region3 = waitTransferLeader(re, stream, region3, 1) + re.NoError(dispatchHeartbeat(co, region3, stream)) + waitNoResponse(re, stream) } -func (s *testCoordinatorSuite) TestPersistScheduler(c *C) { - tc, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, c) +func TestPersistScheduler(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tc, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, re) hbStreams := co.hbStreams defer cleanup() // Add stores 1,2 - c.Assert(tc.addLeaderStore(1, 1), IsNil) - c.Assert(tc.addLeaderStore(2, 1), IsNil) + re.NoError(tc.addLeaderStore(1, 1)) + re.NoError(tc.addLeaderStore(2, 1)) - c.Assert(co.schedulers, HasLen, 4) + re.Len(co.schedulers, 4) oc := co.opController storage := tc.RaftCluster.storage gls1, err := schedule.CreateScheduler(schedulers.GrantLeaderType, oc, storage, schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"})) - c.Assert(err, IsNil) - c.Assert(co.addScheduler(gls1, "1"), IsNil) + re.NoError(err) + re.NoError(co.addScheduler(gls1, "1")) evict, err := schedule.CreateScheduler(schedulers.EvictLeaderType, oc, storage, schedule.ConfigSliceDecoder(schedulers.EvictLeaderType, []string{"2"})) - c.Assert(err, IsNil) - c.Assert(co.addScheduler(evict, "2"), IsNil) - c.Assert(co.schedulers, HasLen, 6) + re.NoError(err) + re.NoError(co.addScheduler(evict, "2")) + re.Len(co.schedulers, 6) sches, _, err := storage.LoadAllScheduleConfig() - c.Assert(err, IsNil) - c.Assert(sches, HasLen, 6) - c.Assert(co.removeScheduler(schedulers.BalanceLeaderName), IsNil) - c.Assert(co.removeScheduler(schedulers.BalanceRegionName), IsNil) - c.Assert(co.removeScheduler(schedulers.HotRegionName), IsNil) - c.Assert(co.removeScheduler(schedulers.SplitBucketName), IsNil) - c.Assert(co.schedulers, HasLen, 2) - c.Assert(co.cluster.opt.Persist(storage), IsNil) + re.NoError(err) + re.Len(sches, 6) + 
re.NoError(co.removeScheduler(schedulers.BalanceLeaderName)) + re.NoError(co.removeScheduler(schedulers.BalanceRegionName)) + re.NoError(co.removeScheduler(schedulers.HotRegionName)) + re.NoError(co.removeScheduler(schedulers.SplitBucketName)) + re.Len(co.schedulers, 2) + re.NoError(co.cluster.opt.Persist(storage)) co.stop() co.wg.Wait() // make a new coordinator for testing // whether the schedulers added or removed in dynamic way are recorded in opt _, newOpt, err := newTestScheduleConfig() - c.Assert(err, IsNil) + re.NoError(err) _, err = schedule.CreateScheduler(schedulers.ShuffleRegionType, oc, storage, schedule.ConfigJSONDecoder([]byte("null"))) - c.Assert(err, IsNil) + re.NoError(err) // suppose we add a new default enable scheduler config.DefaultSchedulers = append(config.DefaultSchedulers, config.SchedulerConfig{Type: "shuffle-region"}) defer func() { config.DefaultSchedulers = config.DefaultSchedulers[:len(config.DefaultSchedulers)-1] }() - c.Assert(newOpt.GetSchedulers(), HasLen, 4) - c.Assert(newOpt.Reload(storage), IsNil) + re.Len(newOpt.GetSchedulers(), 4) + re.NoError(newOpt.Reload(storage)) // only remains 3 items with independent config. sches, _, err = storage.LoadAllScheduleConfig() - c.Assert(err, IsNil) - c.Assert(sches, HasLen, 3) + re.NoError(err) + re.Len(sches, 3) // option have 6 items because the default scheduler do not remove. - c.Assert(newOpt.GetSchedulers(), HasLen, 7) - c.Assert(newOpt.Persist(storage), IsNil) + re.Len(newOpt.GetSchedulers(), 7) + re.NoError(newOpt.Persist(storage)) tc.RaftCluster.opt = newOpt - co = newCoordinator(s.ctx, tc.RaftCluster, hbStreams) + co = newCoordinator(ctx, tc.RaftCluster, hbStreams) co.run() - c.Assert(co.schedulers, HasLen, 3) + re.Len(co.schedulers, 3) co.stop() co.wg.Wait() // suppose restart PD again _, newOpt, err = newTestScheduleConfig() - c.Assert(err, IsNil) - c.Assert(newOpt.Reload(storage), IsNil) + re.NoError(err) + re.NoError(newOpt.Reload(storage)) tc.RaftCluster.opt = newOpt - co = newCoordinator(s.ctx, tc.RaftCluster, hbStreams) + co = newCoordinator(ctx, tc.RaftCluster, hbStreams) co.run() - c.Assert(co.schedulers, HasLen, 3) + re.Len(co.schedulers, 3) bls, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, oc, storage, schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) - c.Assert(err, IsNil) - c.Assert(co.addScheduler(bls), IsNil) + re.NoError(err) + re.NoError(co.addScheduler(bls)) brs, err := schedule.CreateScheduler(schedulers.BalanceRegionType, oc, storage, schedule.ConfigSliceDecoder(schedulers.BalanceRegionType, []string{"", ""})) - c.Assert(err, IsNil) - c.Assert(co.addScheduler(brs), IsNil) - c.Assert(co.schedulers, HasLen, 5) + re.NoError(err) + re.NoError(co.addScheduler(brs)) + re.Len(co.schedulers, 5) // the scheduler option should contain 6 items // the `hot scheduler` are disabled - c.Assert(co.cluster.opt.GetSchedulers(), HasLen, 7) - c.Assert(co.removeScheduler(schedulers.GrantLeaderName), IsNil) + re.Len(co.cluster.opt.GetSchedulers(), 7) + re.NoError(co.removeScheduler(schedulers.GrantLeaderName)) // the scheduler that is not enable by default will be completely deleted - c.Assert(co.cluster.opt.GetSchedulers(), HasLen, 6) - c.Assert(co.schedulers, HasLen, 4) - c.Assert(co.cluster.opt.Persist(co.cluster.storage), IsNil) + re.Len(co.cluster.opt.GetSchedulers(), 6) + re.Len(co.schedulers, 4) + re.NoError(co.cluster.opt.Persist(co.cluster.storage)) co.stop() co.wg.Wait() _, newOpt, err = newTestScheduleConfig() - c.Assert(err, IsNil) - 
c.Assert(newOpt.Reload(co.cluster.storage), IsNil) + re.NoError(err) + re.NoError(newOpt.Reload(co.cluster.storage)) tc.RaftCluster.opt = newOpt - co = newCoordinator(s.ctx, tc.RaftCluster, hbStreams) + co = newCoordinator(ctx, tc.RaftCluster, hbStreams) co.run() - c.Assert(co.schedulers, HasLen, 4) - c.Assert(co.removeScheduler(schedulers.EvictLeaderName), IsNil) - c.Assert(co.schedulers, HasLen, 3) + re.Len(co.schedulers, 4) + re.NoError(co.removeScheduler(schedulers.EvictLeaderName)) + re.Len(co.schedulers, 3) } -func (s *testCoordinatorSuite) TestRemoveScheduler(c *C) { +func TestRemoveScheduler(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, co, cleanup := prepare(func(cfg *config.ScheduleConfig) { cfg.ReplicaScheduleLimit = 0 - }, nil, func(co *coordinator) { co.run() }, c) + }, nil, func(co *coordinator) { co.run() }, re) hbStreams := co.hbStreams defer cleanup() // Add stores 1,2 - c.Assert(tc.addLeaderStore(1, 1), IsNil) - c.Assert(tc.addLeaderStore(2, 1), IsNil) + re.NoError(tc.addLeaderStore(1, 1)) + re.NoError(tc.addLeaderStore(2, 1)) - c.Assert(co.schedulers, HasLen, 4) + re.Len(co.schedulers, 4) oc := co.opController storage := tc.RaftCluster.storage gls1, err := schedule.CreateScheduler(schedulers.GrantLeaderType, oc, storage, schedule.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"})) - c.Assert(err, IsNil) - c.Assert(co.addScheduler(gls1, "1"), IsNil) - c.Assert(co.schedulers, HasLen, 5) + re.NoError(err) + re.NoError(co.addScheduler(gls1, "1")) + re.Len(co.schedulers, 5) sches, _, err := storage.LoadAllScheduleConfig() - c.Assert(err, IsNil) - c.Assert(sches, HasLen, 5) + re.NoError(err) + re.Len(sches, 5) // remove all schedulers - c.Assert(co.removeScheduler(schedulers.BalanceLeaderName), IsNil) - c.Assert(co.removeScheduler(schedulers.BalanceRegionName), IsNil) - c.Assert(co.removeScheduler(schedulers.HotRegionName), IsNil) - c.Assert(co.removeScheduler(schedulers.GrantLeaderName), IsNil) - c.Assert(co.removeScheduler(schedulers.SplitBucketName), IsNil) + re.NoError(co.removeScheduler(schedulers.BalanceLeaderName)) + re.NoError(co.removeScheduler(schedulers.BalanceRegionName)) + re.NoError(co.removeScheduler(schedulers.HotRegionName)) + re.NoError(co.removeScheduler(schedulers.GrantLeaderName)) + re.NoError(co.removeScheduler(schedulers.SplitBucketName)) // all removed sches, _, err = storage.LoadAllScheduleConfig() - c.Assert(err, IsNil) - c.Assert(sches, HasLen, 0) - c.Assert(co.schedulers, HasLen, 0) - c.Assert(co.cluster.opt.Persist(co.cluster.storage), IsNil) + re.NoError(err) + re.Len(sches, 0) + re.Len(co.schedulers, 0) + re.NoError(co.cluster.opt.Persist(co.cluster.storage)) co.stop() co.wg.Wait() // suppose restart PD again _, newOpt, err := newTestScheduleConfig() - c.Assert(err, IsNil) - c.Assert(newOpt.Reload(tc.storage), IsNil) + re.NoError(err) + re.NoError(newOpt.Reload(tc.storage)) tc.RaftCluster.opt = newOpt - co = newCoordinator(s.ctx, tc.RaftCluster, hbStreams) + co = newCoordinator(ctx, tc.RaftCluster, hbStreams) co.run() - c.Assert(co.schedulers, HasLen, 0) + re.Len(co.schedulers, 0) // the option remains default scheduler - c.Assert(co.cluster.opt.GetSchedulers(), HasLen, 4) + re.Len(co.cluster.opt.GetSchedulers(), 4) co.stop() co.wg.Wait() } -func (s *testCoordinatorSuite) TestRestart(c *C) { +func TestRestart(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tc, co, cleanup := prepare(func(cfg 
*config.ScheduleConfig) { // Turn off balance, we test add replica only. cfg.LeaderScheduleLimit = 0 cfg.RegionScheduleLimit = 0 - }, nil, func(co *coordinator) { co.run() }, c) + }, nil, func(co *coordinator) { co.run() }, re) hbStreams := co.hbStreams defer cleanup() // Add 3 stores (1, 2, 3) and a region with 1 replica on store 1. - c.Assert(tc.addRegionStore(1, 1), IsNil) - c.Assert(tc.addRegionStore(2, 2), IsNil) - c.Assert(tc.addRegionStore(3, 3), IsNil) - c.Assert(tc.addLeaderRegion(1, 1), IsNil) + re.NoError(tc.addRegionStore(1, 1)) + re.NoError(tc.addRegionStore(2, 2)) + re.NoError(tc.addRegionStore(3, 3)) + re.NoError(tc.addLeaderRegion(1, 1)) region := tc.GetRegion(1) co.prepareChecker.collect(region) // Add 1 replica on store 2. stream := mockhbstream.NewHeartbeatStream() - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitAddLearner(c, stream, region, 2) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitPromoteLearner(c, stream, region, 2) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitAddLearner(re, stream, region, 2) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitPromoteLearner(re, stream, region, 2) co.stop() co.wg.Wait() // Recreate coordinator then add another replica on store 3. - co = newCoordinator(s.ctx, tc.RaftCluster, hbStreams) + co = newCoordinator(ctx, tc.RaftCluster, hbStreams) co.prepareChecker.collect(region) co.run() - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - region = waitAddLearner(c, stream, region, 3) - c.Assert(dispatchHeartbeat(co, region, stream), IsNil) - waitPromoteLearner(c, stream, region, 3) + re.NoError(dispatchHeartbeat(co, region, stream)) + region = waitAddLearner(re, stream, region, 3) + re.NoError(dispatchHeartbeat(co, region, stream)) + waitPromoteLearner(re, stream, region, 3) } -func (s *testCoordinatorSuite) TestPauseScheduler(c *C) { - _, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, c) +func TestPauseScheduler(t *testing.T) { + re := require.New(t) + + _, co, cleanup := prepare(nil, nil, func(co *coordinator) { co.run() }, re) defer cleanup() _, err := co.isSchedulerAllowed("test") - c.Assert(err, NotNil) + re.Error(err) co.pauseOrResumeScheduler(schedulers.BalanceLeaderName, 60) paused, _ := co.isSchedulerPaused(schedulers.BalanceLeaderName) - c.Assert(paused, Equals, true) + re.True(paused) pausedAt, err := co.getPausedSchedulerDelayAt(schedulers.BalanceLeaderName) - c.Assert(err, IsNil) + re.NoError(err) resumeAt, err := co.getPausedSchedulerDelayUntil(schedulers.BalanceLeaderName) - c.Assert(err, IsNil) - c.Assert(resumeAt-pausedAt, Equals, int64(60)) + re.NoError(err) + re.Equal(int64(60), resumeAt-pausedAt) allowed, _ := co.isSchedulerAllowed(schedulers.BalanceLeaderName) - c.Assert(allowed, Equals, false) + re.False(allowed) } func BenchmarkPatrolRegion(b *testing.B) { + re := require.New(b) + mergeLimit := uint64(4100) regionNum := 10000 tc, co, cleanup := prepare(func(cfg *config.ScheduleConfig) { cfg.MergeScheduleLimit = mergeLimit - }, nil, nil, &C{}) + }, nil, nil, re) defer cleanup() tc.opt.SetSplitMergeInterval(time.Duration(0)) @@ -955,83 +981,71 @@ func BenchmarkPatrolRegion(b *testing.B) { co.patrolRegions() } -func waitOperator(c *C, co *coordinator, regionID uint64) { - testutil.WaitUntil(c, func() bool { +func waitOperator(re *require.Assertions, co *coordinator, regionID uint64) { + testutil.Eventually(re, func() bool { return co.opController.GetOperator(regionID) != nil }) } -var _ = 
Suite(&testOperatorControllerSuite{}) - -type testOperatorControllerSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testOperatorControllerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)"), IsNil) -} - -func (s *testOperatorControllerSuite) TearDownSuite(c *C) { - s.cancel() -} +func TestOperatorCount(t *testing.T) { + re := require.New(t) -func (s *testOperatorControllerSuite) TestOperatorCount(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) + tc, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() oc := co.opController - c.Assert(oc.OperatorCount(operator.OpLeader), Equals, uint64(0)) - c.Assert(oc.OperatorCount(operator.OpRegion), Equals, uint64(0)) + re.Equal(uint64(0), oc.OperatorCount(operator.OpLeader)) + re.Equal(uint64(0), oc.OperatorCount(operator.OpRegion)) - c.Assert(tc.addLeaderRegion(1, 1), IsNil) - c.Assert(tc.addLeaderRegion(2, 2), IsNil) + re.NoError(tc.addLeaderRegion(1, 1)) + re.NoError(tc.addLeaderRegion(2, 2)) { op1 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpLeader) oc.AddWaitingOperator(op1) - c.Assert(oc.OperatorCount(operator.OpLeader), Equals, uint64(1)) // 1:leader + re.Equal(uint64(1), oc.OperatorCount(operator.OpLeader)) // 1:leader op2 := newTestOperator(2, tc.GetRegion(2).GetRegionEpoch(), operator.OpLeader) oc.AddWaitingOperator(op2) - c.Assert(oc.OperatorCount(operator.OpLeader), Equals, uint64(2)) // 1:leader, 2:leader - c.Assert(oc.RemoveOperator(op1), IsTrue) - c.Assert(oc.OperatorCount(operator.OpLeader), Equals, uint64(1)) // 2:leader + re.Equal(uint64(2), oc.OperatorCount(operator.OpLeader)) // 1:leader, 2:leader + re.True(oc.RemoveOperator(op1)) + re.Equal(uint64(1), oc.OperatorCount(operator.OpLeader)) // 2:leader } { op1 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpRegion) oc.AddWaitingOperator(op1) - c.Assert(oc.OperatorCount(operator.OpRegion), Equals, uint64(1)) // 1:region 2:leader - c.Assert(oc.OperatorCount(operator.OpLeader), Equals, uint64(1)) + re.Equal(uint64(1), oc.OperatorCount(operator.OpRegion)) // 1:region 2:leader + re.Equal(uint64(1), oc.OperatorCount(operator.OpLeader)) op2 := newTestOperator(2, tc.GetRegion(2).GetRegionEpoch(), operator.OpRegion) op2.SetPriorityLevel(core.HighPriority) oc.AddWaitingOperator(op2) - c.Assert(oc.OperatorCount(operator.OpRegion), Equals, uint64(2)) // 1:region 2:region - c.Assert(oc.OperatorCount(operator.OpLeader), Equals, uint64(0)) + re.Equal(uint64(2), oc.OperatorCount(operator.OpRegion)) // 1:region 2:region + re.Equal(uint64(0), oc.OperatorCount(operator.OpLeader)) } } -func (s *testOperatorControllerSuite) TestStoreOverloaded(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestStoreOverloaded(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() oc := co.opController lb, err := schedule.CreateScheduler(schedulers.BalanceRegionType, oc, tc.storage, schedule.ConfigSliceDecoder(schedulers.BalanceRegionType, []string{"", ""})) - c.Assert(err, IsNil) + re.NoError(err) opt := tc.GetOpts() - c.Assert(tc.addRegionStore(4, 100), IsNil) - c.Assert(tc.addRegionStore(3, 100), IsNil) - c.Assert(tc.addRegionStore(2, 100), IsNil) - c.Assert(tc.addRegionStore(1, 10), IsNil) - c.Assert(tc.addLeaderRegion(1, 2, 3, 4), IsNil) + re.NoError(tc.addRegionStore(4, 100)) + re.NoError(tc.addRegionStore(3, 100)) + 
re.NoError(tc.addRegionStore(2, 100)) + re.NoError(tc.addRegionStore(1, 10)) + re.NoError(tc.addLeaderRegion(1, 2, 3, 4)) region := tc.GetRegion(1).Clone(core.SetApproximateSize(60)) tc.putRegion(region) start := time.Now() { ops := lb.Schedule(tc) - c.Assert(ops, HasLen, 1) + re.Len(ops, 1) op1 := ops[0] - c.Assert(op1, NotNil) - c.Assert(oc.AddOperator(op1), IsTrue) - c.Assert(oc.RemoveOperator(op1), IsTrue) + re.NotNil(op1) + re.True(oc.AddOperator(op1)) + re.True(oc.RemoveOperator(op1)) } for { time.Sleep(time.Millisecond * 10) @@ -1039,7 +1053,7 @@ func (s *testOperatorControllerSuite) TestStoreOverloaded(c *C) { if time.Since(start) > time.Second { break } - c.Assert(ops, HasLen, 0) + re.Len(ops, 0) } // reset all stores' limit @@ -1049,50 +1063,54 @@ func (s *testOperatorControllerSuite) TestStoreOverloaded(c *C) { time.Sleep(time.Second) for i := 0; i < 10; i++ { ops := lb.Schedule(tc) - c.Assert(ops, HasLen, 1) + re.Len(ops, 1) op := ops[0] - c.Assert(oc.AddOperator(op), IsTrue) - c.Assert(oc.RemoveOperator(op), IsTrue) + re.True(oc.AddOperator(op)) + re.True(oc.RemoveOperator(op)) } // sleep 1 seconds to make sure that the token is filled up time.Sleep(time.Second) for i := 0; i < 100; i++ { - c.Assert(len(lb.Schedule(tc)), Greater, 0) + re.Greater(len(lb.Schedule(tc)), 0) } } -func (s *testOperatorControllerSuite) TestStoreOverloadedWithReplace(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestStoreOverloadedWithReplace(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() oc := co.opController lb, err := schedule.CreateScheduler(schedulers.BalanceRegionType, oc, tc.storage, schedule.ConfigSliceDecoder(schedulers.BalanceRegionType, []string{"", ""})) - c.Assert(err, IsNil) - - c.Assert(tc.addRegionStore(4, 100), IsNil) - c.Assert(tc.addRegionStore(3, 100), IsNil) - c.Assert(tc.addRegionStore(2, 100), IsNil) - c.Assert(tc.addRegionStore(1, 10), IsNil) - c.Assert(tc.addLeaderRegion(1, 2, 3, 4), IsNil) - c.Assert(tc.addLeaderRegion(2, 1, 3, 4), IsNil) + re.NoError(err) + + re.NoError(tc.addRegionStore(4, 100)) + re.NoError(tc.addRegionStore(3, 100)) + re.NoError(tc.addRegionStore(2, 100)) + re.NoError(tc.addRegionStore(1, 10)) + re.NoError(tc.addLeaderRegion(1, 2, 3, 4)) + re.NoError(tc.addLeaderRegion(2, 1, 3, 4)) region := tc.GetRegion(1).Clone(core.SetApproximateSize(60)) tc.putRegion(region) region = tc.GetRegion(2).Clone(core.SetApproximateSize(60)) tc.putRegion(region) op1 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpRegion, operator.AddPeer{ToStore: 1, PeerID: 1}) - c.Assert(oc.AddOperator(op1), IsTrue) + re.True(oc.AddOperator(op1)) op2 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpRegion, operator.AddPeer{ToStore: 2, PeerID: 2}) op2.SetPriorityLevel(core.HighPriority) - c.Assert(oc.AddOperator(op2), IsTrue) + re.True(oc.AddOperator(op2)) op3 := newTestOperator(1, tc.GetRegion(2).GetRegionEpoch(), operator.OpRegion, operator.AddPeer{ToStore: 1, PeerID: 3}) - c.Assert(oc.AddOperator(op3), IsFalse) - c.Assert(lb.Schedule(tc), HasLen, 0) + re.False(oc.AddOperator(op3)) + re.Len(lb.Schedule(tc), 0) // sleep 2 seconds to make sure that token is filled up time.Sleep(2 * time.Second) - c.Assert(len(lb.Schedule(tc)), Greater, 0) + re.Greater(len(lb.Schedule(tc)), 0) } -func (s *testOperatorControllerSuite) TestDownStoreLimit(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestDownStoreLimit(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := 
prepare(nil, nil, nil, re) defer cleanup() oc := co.opController rc := co.checkers.GetRuleChecker() @@ -1116,8 +1134,8 @@ func (s *testOperatorControllerSuite) TestDownStoreLimit(c *C) { for i := uint64(1); i < 20; i++ { tc.addRegionStore(i+3, 100) op := rc.Check(region) - c.Assert(op, NotNil) - c.Assert(oc.AddOperator(op), IsTrue) + re.NotNil(op) + re.True(oc.AddOperator(op)) oc.RemoveOperator(op) } @@ -1126,28 +1144,12 @@ func (s *testOperatorControllerSuite) TestDownStoreLimit(c *C) { for i := uint64(20); i < 25; i++ { tc.addRegionStore(i+3, 100) op := rc.Check(region) - c.Assert(op, NotNil) - c.Assert(oc.AddOperator(op), IsTrue) + re.NotNil(op) + re.True(oc.AddOperator(op)) oc.RemoveOperator(op) } } -var _ = Suite(&testScheduleControllerSuite{}) - -type testScheduleControllerSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testScheduleControllerSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)"), IsNil) -} - -func (s *testScheduleControllerSuite) TearDownSuite(c *C) { - s.cancel() -} - // FIXME: remove after move into schedulers package type mockLimitScheduler struct { schedule.Scheduler @@ -1160,15 +1162,17 @@ func (s *mockLimitScheduler) IsScheduleAllowed(cluster schedule.Cluster) bool { return s.counter.OperatorCount(s.kind) < s.limit } -func (s *testScheduleControllerSuite) TestController(c *C) { - tc, co, cleanup := prepare(nil, nil, nil, c) +func TestController(t *testing.T) { + re := require.New(t) + + tc, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() oc := co.opController - c.Assert(tc.addLeaderRegion(1, 1), IsNil) - c.Assert(tc.addLeaderRegion(2, 2), IsNil) + re.NoError(tc.addLeaderRegion(1, 1)) + re.NoError(tc.addLeaderRegion(2, 2)) scheduler, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, oc, storage.NewStorageWithMemoryBackend(), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) - c.Assert(err, IsNil) + re.NoError(err) lb := &mockLimitScheduler{ Scheduler: scheduler, counter: oc, @@ -1178,25 +1182,25 @@ func (s *testScheduleControllerSuite) TestController(c *C) { sc := newScheduleController(co, lb) for i := schedulers.MinScheduleInterval; sc.GetInterval() != schedulers.MaxScheduleInterval; i = sc.GetNextInterval(i) { - c.Assert(sc.GetInterval(), Equals, i) - c.Assert(sc.Schedule(), HasLen, 0) + re.Equal(i, sc.GetInterval()) + re.Len(sc.Schedule(), 0) } // limit = 2 lb.limit = 2 // count = 0 { - c.Assert(sc.AllowSchedule(), IsTrue) + re.True(sc.AllowSchedule()) op1 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpLeader) - c.Assert(oc.AddWaitingOperator(op1), Equals, 1) + re.Equal(1, oc.AddWaitingOperator(op1)) // count = 1 - c.Assert(sc.AllowSchedule(), IsTrue) + re.True(sc.AllowSchedule()) op2 := newTestOperator(2, tc.GetRegion(2).GetRegionEpoch(), operator.OpLeader) - c.Assert(oc.AddWaitingOperator(op2), Equals, 1) + re.Equal(1, oc.AddWaitingOperator(op2)) // count = 2 - c.Assert(sc.AllowSchedule(), IsFalse) - c.Assert(oc.RemoveOperator(op1), IsTrue) + re.False(sc.AllowSchedule()) + re.True(oc.RemoveOperator(op1)) // count = 1 - c.Assert(sc.AllowSchedule(), IsTrue) + re.True(sc.AllowSchedule()) } op11 := newTestOperator(1, tc.GetRegion(1).GetRegionEpoch(), operator.OpLeader) @@ -1204,55 +1208,57 @@ func (s *testScheduleControllerSuite) TestController(c *C) { { op3 := newTestOperator(2, tc.GetRegion(2).GetRegionEpoch(), 
operator.OpHotRegion) op3.SetPriorityLevel(core.HighPriority) - c.Assert(oc.AddWaitingOperator(op11), Equals, 1) - c.Assert(sc.AllowSchedule(), IsFalse) - c.Assert(oc.AddWaitingOperator(op3), Equals, 1) - c.Assert(sc.AllowSchedule(), IsTrue) - c.Assert(oc.RemoveOperator(op3), IsTrue) + re.Equal(1, oc.AddWaitingOperator(op11)) + re.False(sc.AllowSchedule()) + re.Equal(1, oc.AddWaitingOperator(op3)) + re.True(sc.AllowSchedule()) + re.True(oc.RemoveOperator(op3)) } // add a admin operator will remove old operator { op2 := newTestOperator(2, tc.GetRegion(2).GetRegionEpoch(), operator.OpLeader) - c.Assert(oc.AddWaitingOperator(op2), Equals, 1) - c.Assert(sc.AllowSchedule(), IsFalse) + re.Equal(1, oc.AddWaitingOperator(op2)) + re.False(sc.AllowSchedule()) op4 := newTestOperator(2, tc.GetRegion(2).GetRegionEpoch(), operator.OpAdmin) op4.SetPriorityLevel(core.HighPriority) - c.Assert(oc.AddWaitingOperator(op4), Equals, 1) - c.Assert(sc.AllowSchedule(), IsTrue) - c.Assert(oc.RemoveOperator(op4), IsTrue) + re.Equal(1, oc.AddWaitingOperator(op4)) + re.True(sc.AllowSchedule()) + re.True(oc.RemoveOperator(op4)) } // test wrong region id. { op5 := newTestOperator(3, &metapb.RegionEpoch{}, operator.OpHotRegion) - c.Assert(oc.AddWaitingOperator(op5), Equals, 0) + re.Equal(0, oc.AddWaitingOperator(op5)) } // test wrong region epoch. - c.Assert(oc.RemoveOperator(op11), IsTrue) + re.True(oc.RemoveOperator(op11)) epoch := &metapb.RegionEpoch{ Version: tc.GetRegion(1).GetRegionEpoch().GetVersion() + 1, ConfVer: tc.GetRegion(1).GetRegionEpoch().GetConfVer(), } { op6 := newTestOperator(1, epoch, operator.OpLeader) - c.Assert(oc.AddWaitingOperator(op6), Equals, 0) + re.Equal(0, oc.AddWaitingOperator(op6)) } epoch.Version-- { op6 := newTestOperator(1, epoch, operator.OpLeader) - c.Assert(oc.AddWaitingOperator(op6), Equals, 1) - c.Assert(oc.RemoveOperator(op6), IsTrue) + re.Equal(1, oc.AddWaitingOperator(op6)) + re.True(oc.RemoveOperator(op6)) } } -func (s *testScheduleControllerSuite) TestInterval(c *C) { - _, co, cleanup := prepare(nil, nil, nil, c) +func TestInterval(t *testing.T) { + re := require.New(t) + + _, co, cleanup := prepare(nil, nil, nil, re) defer cleanup() lb, err := schedule.CreateScheduler(schedulers.BalanceLeaderType, co.opController, storage.NewStorageWithMemoryBackend(), schedule.ConfigSliceDecoder(schedulers.BalanceLeaderType, []string{"", ""})) - c.Assert(err, IsNil) + re.NoError(err) sc := newScheduleController(co, lb) // If no operator for x seconds, the next check should be in x/2 seconds. 
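For reference, a minimal self-contained sketch of the gocheck-to-testify pattern these hunks apply throughout: a plain *testing.T entry point, a single require.Assertions value named re, and Eventually-style polling in place of suite-based c.Assert and WaitUntil helpers. It uses only stretchr/testify (added to go.mod by this series); the test name and values are illustrative only, not taken from this patch.

package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestMigrationPattern(t *testing.T) {
	// require.New(t) replaces the per-suite `c *check.C` receiver.
	re := require.New(t)

	// re.NoError / re.Equal replace c.Assert(err, IsNil) / c.Assert(v, Equals, ...),
	// with the expected value given first.
	v, err := time.ParseDuration("50ms")
	re.NoError(err)
	re.Equal(50*time.Millisecond, v)

	// Eventually replaces WaitUntil-style polling: retry the condition every tick
	// until it holds or the wait budget is exhausted.
	deadline := time.Now().Add(50 * time.Millisecond)
	re.Eventually(func() bool {
		return time.Now().After(deadline)
	}, time.Second, 10*time.Millisecond)
}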
@@ -1260,15 +1266,15 @@ func (s *testScheduleControllerSuite) TestInterval(c *C) { for _, n := range idleSeconds { sc.nextInterval = schedulers.MinScheduleInterval for totalSleep := time.Duration(0); totalSleep <= time.Second*time.Duration(n); totalSleep += sc.GetInterval() { - c.Assert(sc.Schedule(), HasLen, 0) + re.Len(sc.Schedule(), 0) } - c.Assert(sc.GetInterval(), Less, time.Second*time.Duration(n/2)) + re.Less(sc.GetInterval(), time.Second*time.Duration(n/2)) } } -func waitAddLearner(c *C, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { +func waitAddLearner(re *require.Assertions, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { var res *pdpb.RegionHeartbeatResponse - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { if res = stream.Recv(); res != nil { return res.GetRegionId() == region.GetID() && res.GetChangePeer().GetChangeType() == eraftpb.ConfChangeType_AddLearnerNode && @@ -1282,9 +1288,9 @@ func waitAddLearner(c *C, stream mockhbstream.HeartbeatStream, region *core.Regi ) } -func waitPromoteLearner(c *C, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { +func waitPromoteLearner(re *require.Assertions, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { var res *pdpb.RegionHeartbeatResponse - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { if res = stream.Recv(); res != nil { return res.GetRegionId() == region.GetID() && res.GetChangePeer().GetChangeType() == eraftpb.ConfChangeType_AddNode && @@ -1299,9 +1305,9 @@ func waitPromoteLearner(c *C, stream mockhbstream.HeartbeatStream, region *core. ) } -func waitRemovePeer(c *C, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { +func waitRemovePeer(re *require.Assertions, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { var res *pdpb.RegionHeartbeatResponse - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { if res = stream.Recv(); res != nil { return res.GetRegionId() == region.GetID() && res.GetChangePeer().GetChangeType() == eraftpb.ConfChangeType_RemoveNode && @@ -1315,9 +1321,9 @@ func waitRemovePeer(c *C, stream mockhbstream.HeartbeatStream, region *core.Regi ) } -func waitTransferLeader(c *C, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { +func waitTransferLeader(re *require.Assertions, stream mockhbstream.HeartbeatStream, region *core.RegionInfo, storeID uint64) *core.RegionInfo { var res *pdpb.RegionHeartbeatResponse - testutil.WaitUntil(c, func() bool { + testutil.Eventually(re, func() bool { if res = stream.Recv(); res != nil { if res.GetRegionId() == region.GetID() { for _, peer := range append(res.GetTransferLeader().GetPeers(), res.GetTransferLeader().GetPeer()) { @@ -1334,8 +1340,8 @@ func waitTransferLeader(c *C, stream mockhbstream.HeartbeatStream, region *core. 
) } -func waitNoResponse(c *C, stream mockhbstream.HeartbeatStream) { - testutil.WaitUntil(c, func() bool { +func waitNoResponse(re *require.Assertions, stream mockhbstream.HeartbeatStream) { + testutil.Eventually(re, func() bool { res := stream.Recv() return res == nil }) diff --git a/server/cluster/store_limiter_test.go b/server/cluster/store_limiter_test.go index d23bdb06536..7a1dcab9fad 100644 --- a/server/cluster/store_limiter_test.go +++ b/server/cluster/store_limiter_test.go @@ -15,43 +15,40 @@ package cluster import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core/storelimit" ) -var _ = Suite(&testStoreLimiterSuite{}) +func TestCollect(t *testing.T) { + re := require.New(t) -type testStoreLimiterSuite struct { - opt *config.PersistOptions -} + limiter := NewStoreLimiter(config.NewTestOptions()) -func (s *testStoreLimiterSuite) SetUpSuite(c *C) { - // Create a server for testing - s.opt = config.NewTestOptions() + limiter.Collect(&pdpb.StoreStats{}) + re.Equal(int64(1), limiter.state.cst.total) } -func (s *testStoreLimiterSuite) TestCollect(c *C) { - limiter := NewStoreLimiter(s.opt) +func TestStoreLimitScene(t *testing.T) { + re := require.New(t) - limiter.Collect(&pdpb.StoreStats{}) - c.Assert(limiter.state.cst.total, Equals, int64(1)) + limiter := NewStoreLimiter(config.NewTestOptions()) + re.Equal(storelimit.DefaultScene(storelimit.AddPeer), limiter.scene[storelimit.AddPeer]) + re.Equal(storelimit.DefaultScene(storelimit.RemovePeer), limiter.scene[storelimit.RemovePeer]) } -func (s *testStoreLimiterSuite) TestStoreLimitScene(c *C) { - limiter := NewStoreLimiter(s.opt) - c.Assert(limiter.scene[storelimit.AddPeer], DeepEquals, storelimit.DefaultScene(storelimit.AddPeer)) - c.Assert(limiter.scene[storelimit.RemovePeer], DeepEquals, storelimit.DefaultScene(storelimit.RemovePeer)) -} +func TestReplaceStoreLimitScene(t *testing.T) { + re := require.New(t) -func (s *testStoreLimiterSuite) TestReplaceStoreLimitScene(c *C) { - limiter := NewStoreLimiter(s.opt) + limiter := NewStoreLimiter(config.NewTestOptions()) sceneAddPeer := &storelimit.Scene{Idle: 4, Low: 3, Normal: 2, High: 1} limiter.ReplaceStoreLimitScene(sceneAddPeer, storelimit.AddPeer) - c.Assert(limiter.scene[storelimit.AddPeer], DeepEquals, sceneAddPeer) + re.Equal(sceneAddPeer, limiter.scene[storelimit.AddPeer]) sceneRemovePeer := &storelimit.Scene{Idle: 5, Low: 4, Normal: 3, High: 2} limiter.ReplaceStoreLimitScene(sceneRemovePeer, storelimit.RemovePeer) diff --git a/server/cluster/unsafe_recovery_controller_test.go b/server/cluster/unsafe_recovery_controller_test.go index edd6bf9c187..2b3717dabd6 100644 --- a/server/cluster/unsafe_recovery_controller_test.go +++ b/server/cluster/unsafe_recovery_controller_test.go @@ -16,13 +16,14 @@ package cluster import ( "context" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/eraftpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/kvproto/pkg/raft_serverpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/codec" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/server/core" @@ -30,21 +31,6 @@ import ( "github.com/tikv/pd/server/storage" ) -var _ = Suite(&testUnsafeRecoverySuite{}) - -type testUnsafeRecoverySuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testUnsafeRecoverySuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testUnsafeRecoverySuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - func newStoreHeartbeat(storeID uint64, report *pdpb.StoreReport) *pdpb.StoreHeartbeatRequest { return &pdpb.StoreHeartbeatRequest{ Stats: &pdpb.StoreStats{ @@ -54,7 +40,7 @@ func newStoreHeartbeat(storeID uint64, report *pdpb.StoreReport) *pdpb.StoreHear } } -func applyRecoveryPlan(c *C, storeID uint64, storeReports map[uint64]*pdpb.StoreReport, resp *pdpb.StoreHeartbeatResponse) { +func applyRecoveryPlan(re *require.Assertions, storeID uint64, storeReports map[uint64]*pdpb.StoreReport, resp *pdpb.StoreHeartbeatResponse) { plan := resp.GetRecoveryPlan() if plan == nil { return @@ -122,7 +108,7 @@ func applyRecoveryPlan(c *C, storeID uint64, storeReports map[uint64]*pdpb.Store } region.RegionEpoch.ConfVer += 1 if store == storeID { - c.Assert(report.IsForceLeader, IsTrue) + re.True(report.IsForceLeader) } break } @@ -135,7 +121,7 @@ func applyRecoveryPlan(c *C, storeID uint64, storeReports map[uint64]*pdpb.Store } } -func advanceUntilFinished(c *C, recoveryController *unsafeRecoveryController, reports map[uint64]*pdpb.StoreReport) { +func advanceUntilFinished(re *require.Assertions, recoveryController *unsafeRecoveryController, reports map[uint64]*pdpb.StoreReport) { retry := 0 for { @@ -144,7 +130,7 @@ func advanceUntilFinished(c *C, recoveryController *unsafeRecoveryController, re req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - applyRecoveryPlan(c, storeID, reports, resp) + applyRecoveryPlan(re, storeID, reports, resp) } if recoveryController.GetStage() == finished { break @@ -157,19 +143,23 @@ func advanceUntilFinished(c *C, recoveryController *unsafeRecoveryController, re } } -func (s *testUnsafeRecoverySuite) TestFinished(c *C) { +func TestFinished(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - 
}, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -183,18 +173,18 @@ func (s *testUnsafeRecoverySuite) TestFinished(c *C) { {Id: 11, StoreId: 1}, {Id: 21, StoreId: 2}, {Id: 31, StoreId: 3}}}}}, }}, } - c.Assert(recoveryController.GetStage(), Equals, collectReport) + re.Equal(collectReport, recoveryController.GetStage()) for storeID := range reports { req := newStoreHeartbeat(storeID, nil) resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) // require peer report by empty plan - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(len(resp.RecoveryPlan.Creates), Equals, 0) - c.Assert(len(resp.RecoveryPlan.Demotes), Equals, 0) - c.Assert(resp.RecoveryPlan.ForceLeader, IsNil) - c.Assert(resp.RecoveryPlan.Step, Equals, uint64(1)) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.Empty(len(resp.RecoveryPlan.Creates)) + re.Empty(len(resp.RecoveryPlan.Demotes)) + re.Nil(resp.RecoveryPlan.ForceLeader) + re.Equal(uint64(1), resp.RecoveryPlan.Step) + applyRecoveryPlan(re, storeID, reports, resp) } // receive all reports and dispatch plan @@ -203,49 +193,53 @@ func (s *testUnsafeRecoverySuite) TestFinished(c *C) { req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(resp.RecoveryPlan.ForceLeader, NotNil) - c.Assert(len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders), Equals, 1) - c.Assert(resp.RecoveryPlan.ForceLeader.FailedStores, NotNil) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.NotNil(resp.RecoveryPlan.ForceLeader) + re.Equal(1, len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders)) + re.NotNil(resp.RecoveryPlan.ForceLeader.FailedStores) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, forceLeader) + re.Equal(forceLeader, recoveryController.GetStage()) for storeID, report := range reports { req := newStoreHeartbeat(storeID, report) req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(len(resp.RecoveryPlan.Demotes), Equals, 1) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.Equal(1, len(resp.RecoveryPlan.Demotes)) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, demoteFailedVoter) + re.Equal(demoteFailedVoter, recoveryController.GetStage()) for storeID, report := range reports { req := newStoreHeartbeat(storeID, report) req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, IsNil) + re.Nil(resp.RecoveryPlan) // remove the two failed peers - applyRecoveryPlan(c, storeID, reports, resp) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, finished) + re.Equal(finished, recoveryController.GetStage()) } -func (s *testUnsafeRecoverySuite) TestFailed(c *C) { +func TestFailed(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, 
cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -259,17 +253,17 @@ func (s *testUnsafeRecoverySuite) TestFailed(c *C) { {Id: 11, StoreId: 1}, {Id: 21, StoreId: 2}, {Id: 31, StoreId: 3}}}}}, }}, } - c.Assert(recoveryController.GetStage(), Equals, collectReport) + re.Equal(collectReport, recoveryController.GetStage()) // require peer report for storeID := range reports { req := newStoreHeartbeat(storeID, nil) resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(len(resp.RecoveryPlan.Creates), Equals, 0) - c.Assert(len(resp.RecoveryPlan.Demotes), Equals, 0) - c.Assert(resp.RecoveryPlan.ForceLeader, IsNil) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.Empty(len(resp.RecoveryPlan.Creates)) + re.Empty(len(resp.RecoveryPlan.Demotes)) + re.Nil(resp.RecoveryPlan.ForceLeader) + applyRecoveryPlan(re, storeID, reports, resp) } // receive all reports and dispatch plan @@ -278,39 +272,39 @@ func (s *testUnsafeRecoverySuite) TestFailed(c *C) { req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(resp.RecoveryPlan.ForceLeader, NotNil) - c.Assert(len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders), Equals, 1) - c.Assert(resp.RecoveryPlan.ForceLeader.FailedStores, NotNil) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.NotNil(resp.RecoveryPlan.ForceLeader) + re.Equal(1, len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders)) + re.NotNil(resp.RecoveryPlan.ForceLeader.FailedStores) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, forceLeader) + re.Equal(forceLeader, recoveryController.GetStage()) for storeID, report := range reports { req := newStoreHeartbeat(storeID, report) req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(len(resp.RecoveryPlan.Demotes), Equals, 1) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.Equal(1, len(resp.RecoveryPlan.Demotes)) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, demoteFailedVoter) + re.Equal(demoteFailedVoter, recoveryController.GetStage()) // received heartbeat from failed store, abort req := newStoreHeartbeat(2, nil) resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, IsNil) - c.Assert(recoveryController.GetStage(), Equals, exitForceLeader) + re.Nil(resp.RecoveryPlan) + re.Equal(exitForceLeader, recoveryController.GetStage()) for storeID, report := range 
reports { req := newStoreHeartbeat(storeID, report) req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, NotNil) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + applyRecoveryPlan(re, storeID, reports, resp) } for storeID, report := range reports { @@ -318,24 +312,28 @@ func (s *testUnsafeRecoverySuite) TestFailed(c *C) { req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - applyRecoveryPlan(c, storeID, reports, resp) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, failed) + re.Equal(failed, recoveryController.GetStage()) } -func (s *testUnsafeRecoverySuite) TestForceLeaderFail(c *C) { +func TestForceLeaderFail(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(4, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 3: {}, 4: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: { @@ -376,42 +374,46 @@ func (s *testUnsafeRecoverySuite) TestForceLeaderFail(c *C) { resp2 := &pdpb.StoreHeartbeatResponse{} req2.StoreReport.Step = 1 recoveryController.HandleStoreHeartbeat(req2, resp2) - c.Assert(recoveryController.GetStage(), Equals, forceLeader) + re.Equal(forceLeader, recoveryController.GetStage()) recoveryController.HandleStoreHeartbeat(req1, resp1) // force leader on store 1 succeed - applyRecoveryPlan(c, 1, reports, resp1) - applyRecoveryPlan(c, 2, reports, resp2) + applyRecoveryPlan(re, 1, reports, resp1) + applyRecoveryPlan(re, 2, reports, resp2) // force leader on store 2 doesn't succeed reports[2].PeerReports[0].IsForceLeader = false // force leader should retry on store 2 recoveryController.HandleStoreHeartbeat(req1, resp1) recoveryController.HandleStoreHeartbeat(req2, resp2) - c.Assert(recoveryController.GetStage(), Equals, forceLeader) + re.Equal(forceLeader, recoveryController.GetStage()) recoveryController.HandleStoreHeartbeat(req1, resp1) // force leader succeed this time - applyRecoveryPlan(c, 1, reports, resp1) - applyRecoveryPlan(c, 2, reports, resp2) + applyRecoveryPlan(re, 1, reports, resp1) + applyRecoveryPlan(re, 2, reports, resp2) recoveryController.HandleStoreHeartbeat(req1, resp1) recoveryController.HandleStoreHeartbeat(req2, resp2) - c.Assert(recoveryController.GetStage(), Equals, demoteFailedVoter) + re.Equal(demoteFailedVoter, recoveryController.GetStage()) } -func (s *testUnsafeRecoverySuite) TestAffectedTableID(c *C) { +func TestAffectedTableID(t *testing.T) { + re := 
require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: { @@ -429,26 +431,30 @@ func (s *testUnsafeRecoverySuite) TestAffectedTableID(c *C) { }, } - advanceUntilFinished(c, recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) - c.Assert(len(recoveryController.affectedTableIDs), Equals, 1) + re.Equal(1, len(recoveryController.affectedTableIDs)) _, exists := recoveryController.affectedTableIDs[6] - c.Assert(exists, IsTrue) + re.True(exists) } -func (s *testUnsafeRecoverySuite) TestForceLeaderForCommitMerge(c *C) { +func TestForceLeaderForCommitMerge(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: { @@ -483,44 +489,48 @@ func (s *testUnsafeRecoverySuite) TestForceLeaderForCommitMerge(c *C) { resp := &pdpb.StoreHeartbeatResponse{} req.StoreReport.Step = 1 recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, forceLeaderForCommitMerge) + re.Equal(forceLeaderForCommitMerge, recoveryController.GetStage()) // force leader on regions of commit merge first - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(resp.RecoveryPlan.ForceLeader, NotNil) - c.Assert(len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders), Equals, 1) - c.Assert(resp.RecoveryPlan.ForceLeader.EnterForceLeaders[0], Equals, uint64(1002)) - c.Assert(resp.RecoveryPlan.ForceLeader.FailedStores, NotNil) - applyRecoveryPlan(c, 1, reports, resp) + 
re.NotNil(resp.RecoveryPlan) + re.NotNil(resp.RecoveryPlan.ForceLeader) + re.Equal(1, len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders)) + re.Equal(uint64(1002), resp.RecoveryPlan.ForceLeader.EnterForceLeaders[0]) + re.NotNil(resp.RecoveryPlan.ForceLeader.FailedStores) + applyRecoveryPlan(re, 1, reports, resp) recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, forceLeader) + re.Equal(forceLeader, recoveryController.GetStage()) // force leader on the rest regions - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(resp.RecoveryPlan.ForceLeader, NotNil) - c.Assert(len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders), Equals, 1) - c.Assert(resp.RecoveryPlan.ForceLeader.EnterForceLeaders[0], Equals, uint64(1001)) - c.Assert(resp.RecoveryPlan.ForceLeader.FailedStores, NotNil) - applyRecoveryPlan(c, 1, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.NotNil(resp.RecoveryPlan.ForceLeader) + re.Equal(1, len(resp.RecoveryPlan.ForceLeader.EnterForceLeaders)) + re.Equal(uint64(1001), resp.RecoveryPlan.ForceLeader.EnterForceLeaders[0]) + re.NotNil(resp.RecoveryPlan.ForceLeader.FailedStores) + applyRecoveryPlan(re, 1, reports, resp) recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, demoteFailedVoter) + re.Equal(demoteFailedVoter, recoveryController.GetStage()) } -func (s *testUnsafeRecoverySuite) TestOneLearner(c *C) { +func TestOneLearner(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -535,7 +545,7 @@ func (s *testUnsafeRecoverySuite) TestOneLearner(c *C) { }}, } - advanceUntilFinished(c, recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) expects := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -552,29 +562,33 @@ func (s *testUnsafeRecoverySuite) TestOneLearner(c *C) { for storeID, report := range reports { if result, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, result.PeerReports) + re.Equal(result.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(len(report.PeerReports)) } } } -func (s *testUnsafeRecoverySuite) TestTiflashLearnerPeer(c *C) { +func TestTiflashLearnerPeer(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := 
newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(5, "6.0.0") { if store.GetID() == 3 { store.GetMeta().Labels = []*metapb.StoreLabel{{Key: "engine", Value: "tiflash"}} } - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 4: {}, 5: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -650,7 +664,7 @@ func (s *testUnsafeRecoverySuite) TestTiflashLearnerPeer(c *C) { }}, } - advanceUntilFinished(c, recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) expects := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -704,7 +718,7 @@ func (s *testUnsafeRecoverySuite) TestTiflashLearnerPeer(c *C) { for i, p := range report.PeerReports { // As the store of newly created region is not fixed, check it separately if p.RegionState.Region.GetId() == 1 { - c.Assert(p, DeepEquals, &pdpb.PeerReport{ + re.Equal(&pdpb.PeerReport{ RaftState: &raft_serverpb.RaftLocalState{LastIndex: 10, HardState: &eraftpb.HardState{Term: 1, Commit: 10}}, RegionState: &raft_serverpb.RegionLocalState{ Region: &metapb.Region{ @@ -717,32 +731,36 @@ func (s *testUnsafeRecoverySuite) TestTiflashLearnerPeer(c *C) { }, }, }, - }) + }, p) report.PeerReports = append(report.PeerReports[:i], report.PeerReports[i+1:]...) 
break } } if result, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, result.PeerReports) + re.Equal(result.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(len(report.PeerReports)) } } } -func (s *testUnsafeRecoverySuite) TestUninitializedPeer(c *C) { +func TestUninitializedPeer(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -756,7 +774,7 @@ func (s *testUnsafeRecoverySuite) TestUninitializedPeer(c *C) { }}, } - advanceUntilFinished(c, recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) expects := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -775,26 +793,30 @@ func (s *testUnsafeRecoverySuite) TestUninitializedPeer(c *C) { for storeID, report := range reports { if result, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, result.PeerReports) + re.Equal(result.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(report.PeerReports) } } } -func (s *testUnsafeRecoverySuite) TestJointState(c *C) { +func TestJointState(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(5, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 4: {}, 5: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -878,7 +900,7 @@ func (s *testUnsafeRecoverySuite) TestJointState(c *C) { }}, } - advanceUntilFinished(c, 
recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) expects := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -964,50 +986,58 @@ func (s *testUnsafeRecoverySuite) TestJointState(c *C) { for storeID, report := range reports { if result, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, result.PeerReports) + re.Equal(result.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(len(report.PeerReports)) } } } -func (s *testUnsafeRecoverySuite) TestTimeout(c *C) { +func TestTimeout(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 1), IsNil) + }, 1)) time.Sleep(time.Second) req := newStoreHeartbeat(1, nil) resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, exitForceLeader) + re.Equal(exitForceLeader, recoveryController.GetStage()) req.StoreReport = &pdpb.StoreReport{Step: 2} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, failed) + re.Equal(failed, recoveryController.GetStage()) } -func (s *testUnsafeRecoverySuite) TestExitForceLeader(c *C) { +func TestExitForceLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: { @@ -1032,18 +1062,18 @@ func (s *testUnsafeRecoverySuite) TestExitForceLeader(c *C) { req.StoreReport = report resp := 
&pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - applyRecoveryPlan(c, storeID, reports, resp) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, exitForceLeader) + re.Equal(exitForceLeader, recoveryController.GetStage()) for storeID, report := range reports { req := newStoreHeartbeat(storeID, report) req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - applyRecoveryPlan(c, storeID, reports, resp) + applyRecoveryPlan(re, storeID, reports, resp) } - c.Assert(recoveryController.GetStage(), Equals, finished) + re.Equal(finished, recoveryController.GetStage()) expects := map[uint64]*pdpb.StoreReport{ 1: { @@ -1062,26 +1092,30 @@ func (s *testUnsafeRecoverySuite) TestExitForceLeader(c *C) { for storeID, report := range reports { if result, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, result.PeerReports) + re.Equal(result.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(len(report.PeerReports)) } } } -func (s *testUnsafeRecoverySuite) TestStep(c *C) { +func TestStep(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: { @@ -1102,37 +1136,41 @@ func (s *testUnsafeRecoverySuite) TestStep(c *C) { resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) // step is not set, ignore - c.Assert(recoveryController.GetStage(), Equals, collectReport) + re.Equal(collectReport, recoveryController.GetStage()) // valid store report req.StoreReport.Step = 1 recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, forceLeader) + re.Equal(forceLeader, recoveryController.GetStage()) // duplicate report with same step, ignore recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, forceLeader) - applyRecoveryPlan(c, 1, reports, resp) + re.Equal(forceLeader, recoveryController.GetStage()) + applyRecoveryPlan(re, 1, reports, resp) recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, demoteFailedVoter) - applyRecoveryPlan(c, 1, reports, resp) + re.Equal(demoteFailedVoter, recoveryController.GetStage()) + applyRecoveryPlan(re, 1, reports, resp) recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(recoveryController.GetStage(), Equals, finished) + 
re.Equal(finished, recoveryController.GetStage()) } -func (s *testUnsafeRecoverySuite) TestOnHealthyRegions(c *C) { +func TestOnHealthyRegions(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(5, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 4: {}, 5: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -1166,17 +1204,17 @@ func (s *testUnsafeRecoverySuite) TestOnHealthyRegions(c *C) { {Id: 11, StoreId: 1}, {Id: 21, StoreId: 2}, {Id: 31, StoreId: 3}}}}}, }}, } - c.Assert(recoveryController.GetStage(), Equals, collectReport) + re.Equal(collectReport, recoveryController.GetStage()) // require peer report for storeID := range reports { req := newStoreHeartbeat(storeID, nil) resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, NotNil) - c.Assert(len(resp.RecoveryPlan.Creates), Equals, 0) - c.Assert(len(resp.RecoveryPlan.Demotes), Equals, 0) - c.Assert(resp.RecoveryPlan.ForceLeader, IsNil) - applyRecoveryPlan(c, storeID, reports, resp) + re.NotNil(resp.RecoveryPlan) + re.Empty(len(resp.RecoveryPlan.Creates)) + re.Empty(len(resp.RecoveryPlan.Demotes)) + re.Nil(resp.RecoveryPlan.ForceLeader) + applyRecoveryPlan(re, storeID, reports, resp) } // receive all reports and dispatch no plan @@ -1185,26 +1223,30 @@ func (s *testUnsafeRecoverySuite) TestOnHealthyRegions(c *C) { req.StoreReport = report resp := &pdpb.StoreHeartbeatResponse{} recoveryController.HandleStoreHeartbeat(req, resp) - c.Assert(resp.RecoveryPlan, IsNil) - applyRecoveryPlan(c, storeID, reports, resp) + re.Nil(resp.RecoveryPlan) + applyRecoveryPlan(re, storeID, reports, resp) } // nothing to do, finish directly - c.Assert(recoveryController.GetStage(), Equals, finished) + re.Equal(finished, recoveryController.GetStage()) } -func (s *testUnsafeRecoverySuite) TestCreateEmptyRegion(c *C) { +func TestCreateEmptyRegion(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, 
true)) cluster.coordinator.run() for _, store := range newTestStores(3, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -1231,7 +1273,7 @@ func (s *testUnsafeRecoverySuite) TestCreateEmptyRegion(c *C) { }}, } - advanceUntilFinished(c, recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) expects := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -1278,9 +1320,9 @@ func (s *testUnsafeRecoverySuite) TestCreateEmptyRegion(c *C) { for storeID, report := range reports { if expect, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, expect.PeerReports) + re.Equal(expect.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(len(report.PeerReports)) } } } @@ -1297,19 +1339,23 @@ func (s *testUnsafeRecoverySuite) TestCreateEmptyRegion(c *C) { // | Store 4, 5 and 6 fail | A=[a,m), B=[m,z) | A=[a,z) | C=[a,g) | fail | fail | fail | // +──────────────────────────────────+───────────────────+───────────────────+───────────────────+───────────────────+──────────+──────────+ -func (s *testUnsafeRecoverySuite) TestRangeOverlap1(c *C) { +func TestRangeOverlap1(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(5, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 4: {}, 5: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -1350,7 +1396,7 @@ func (s *testUnsafeRecoverySuite) TestRangeOverlap1(c *C) { }}, } - advanceUntilFinished(c, recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) expects := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -1381,26 +1427,30 @@ func (s *testUnsafeRecoverySuite) TestRangeOverlap1(c *C) { for storeID, report := range reports { if result, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, result.PeerReports) + re.Equal(result.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(len(report.PeerReports)) } } } -func (s *testUnsafeRecoverySuite) TestRangeOverlap2(c *C) { +func TestRangeOverlap2(t *testing.T) { + re := require.New(t) + ctx, cancel := 
context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() for _, store := range newTestStores(5, "6.0.0") { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 4: {}, 5: {}, - }, 60), IsNil) + }, 60)) reports := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -1441,7 +1491,7 @@ func (s *testUnsafeRecoverySuite) TestRangeOverlap2(c *C) { }}, } - advanceUntilFinished(c, recoveryController, reports) + advanceUntilFinished(re, recoveryController, reports) expects := map[uint64]*pdpb.StoreReport{ 1: {PeerReports: []*pdpb.PeerReport{ @@ -1471,72 +1521,80 @@ func (s *testUnsafeRecoverySuite) TestRangeOverlap2(c *C) { for storeID, report := range reports { if result, ok := expects[storeID]; ok { - c.Assert(report.PeerReports, DeepEquals, result.PeerReports) + re.Equal(result.PeerReports, report.PeerReports) } else { - c.Assert(len(report.PeerReports), Equals, 0) + re.Empty(report.PeerReports) } } } -func (s *testUnsafeRecoverySuite) TestRemoveFailedStores(c *C) { +func TestRemoveFailedStores(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() stores := newTestStores(2, "5.3.0") stores[1] = stores[1].Clone(core.SetLastHeartbeatTS(time.Now())) for _, store := range stores { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } recoveryController := newUnsafeRecoveryController(cluster) // Store 3 doesn't exist, reject to remove. 
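// (Only two stores are created above, and store 2 was just given a fresh heartbeat, so store 3 is unknown to the cluster.)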
- c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.Error(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 1: {}, 3: {}, - }, 60), NotNil) + }, 60)) - c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 1: {}, - }, 60), IsNil) - c.Assert(cluster.GetStore(uint64(1)).IsRemoved(), IsTrue) + }, 60)) + re.True(cluster.GetStore(uint64(1)).IsRemoved()) for _, s := range cluster.GetSchedulers() { paused, err := cluster.IsSchedulerAllowed(s) if s != "split-bucket-scheduler" { - c.Assert(err, IsNil) - c.Assert(paused, IsTrue) + re.NoError(err) + re.True(paused) } } // Store 2's last heartbeat is recent, and is not allowed to be removed. - c.Assert(recoveryController.RemoveFailedStores( + re.Error(recoveryController.RemoveFailedStores( map[uint64]struct{}{ 2: {}, - }, 60), NotNil) + }, 60)) } -func (s *testUnsafeRecoverySuite) TestSplitPaused(c *C) { +func TestSplitPaused(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, opt, _ := newTestScheduleConfig() - cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) + cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) recoveryController := newUnsafeRecoveryController(cluster) cluster.Lock() cluster.unsafeRecoveryController = recoveryController - cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) + cluster.coordinator = newCoordinator(ctx, cluster, hbstream.NewTestHeartbeatStreams(ctx, cluster.meta.GetId(), cluster, true)) cluster.Unlock() cluster.coordinator.run() stores := newTestStores(2, "5.3.0") stores[1] = stores[1].Clone(core.SetLastHeartbeatTS(time.Now())) for _, store := range stores { - c.Assert(cluster.PutStore(store.GetMeta()), IsNil) + re.NoError(cluster.PutStore(store.GetMeta())) } failedStores := map[uint64]struct{}{ 1: {}, } - c.Assert(recoveryController.RemoveFailedStores(failedStores, 60), IsNil) + re.NoError(recoveryController.RemoveFailedStores(failedStores, 60)) askSplitReq := &pdpb.AskSplitRequest{} _, err := cluster.HandleAskSplit(askSplitReq) - c.Assert(err.Error(), Equals, "[PD:unsaferecovery:ErrUnsafeRecoveryIsRunning]unsafe recovery is running") + re.Equal("[PD:unsaferecovery:ErrUnsafeRecoveryIsRunning]unsafe recovery is running", err.Error()) askBatchSplitReq := &pdpb.AskBatchSplitRequest{} _, err = cluster.HandleAskBatchSplit(askBatchSplitReq) - c.Assert(err.Error(), Equals, "[PD:unsaferecovery:ErrUnsafeRecoveryIsRunning]unsafe recovery is running") + re.Equal("[PD:unsaferecovery:ErrUnsafeRecoveryIsRunning]unsafe recovery is running", err.Error()) } From 1be84d652654989c9bcf22092c24557148c97058 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Thu, 23 Jun 2022 20:00:37 +0800 Subject: [PATCH 75/82] *: remove duplicated functions (#5224) ref tikv/pd#4813 Signed-off-by: Ryan Leung --- pkg/assertutil/assertutil.go | 16 +++++++++++++++- server/api/server_test.go | 3 ++- server/api/version_test.go | 13 +------------ server/join/join_test.go | 13 +------------ server/server_test.go | 23 ++++++----------------- tests/client/client_test.go | 13 +------------ tests/pdctl/global_test.go | 3 ++- tests/pdctl/helper.go | 12 ------------ tests/server/member/member_test.go | 13 +------------ 9 files changed, 29 insertions(+), 80 
deletions(-) diff --git a/pkg/assertutil/assertutil.go b/pkg/assertutil/assertutil.go index 5da16155674..9eb2719b220 100644 --- a/pkg/assertutil/assertutil.go +++ b/pkg/assertutil/assertutil.go @@ -14,6 +14,8 @@ package assertutil +import "github.com/stretchr/testify/require" + // Checker accepts the injection of check functions and context from test files. // Any check function should be set before usage unless the test will fail. type Checker struct { @@ -21,11 +23,23 @@ type Checker struct { FailNow func() } -// NewChecker creates Checker with FailNow function. +// NewChecker creates Checker. func NewChecker() *Checker { return &Checker{} } +// CheckerWithNilAssert creates Checker with nil assert function. +func CheckerWithNilAssert(re *require.Assertions) *Checker { + checker := NewChecker() + checker.FailNow = func() { + re.FailNow("should be nil") + } + checker.IsNil = func(obtained interface{}) { + re.Nil(obtained) + } + return checker +} + // AssertNil calls the injected IsNil assertion. func (c *Checker) AssertNil(obtained interface{}) { if c.IsNil == nil { diff --git a/server/api/server_test.go b/server/api/server_test.go index 273f62cab54..b82dfc5ea21 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" @@ -78,7 +79,7 @@ var zapLogOnce sync.Once func mustNewCluster(re *require.Assertions, num int, opts ...func(cfg *config.Config)) ([]*config.Config, []*server.Server, cleanUpFunc) { ctx, cancel := context.WithCancel(context.Background()) svrs := make([]*server.Server, 0, num) - cfgs := server.NewTestMultiConfig(checkerWithNilAssert(re), num) + cfgs := server.NewTestMultiConfig(assertutil.CheckerWithNilAssert(re), num) ch := make(chan *server.Server, num) for _, cfg := range cfgs { diff --git a/server/api/version_test.go b/server/api/version_test.go index 41254649c34..9973a871b05 100644 --- a/server/api/version_test.go +++ b/server/api/version_test.go @@ -30,17 +30,6 @@ import ( "github.com/tikv/pd/server/config" ) -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} - func TestGetVersion(t *testing.T) { // TODO: enable it. t.Skip("Temporary disable. See issue: https://github.com/tikv/pd/issues/1893") @@ -51,7 +40,7 @@ func TestGetVersion(t *testing.T) { temp, _ := os.Create(fname) os.Stdout = temp - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) reqCh := make(chan struct{}) go func() { <-reqCh diff --git a/server/join/join_test.go b/server/join/join_test.go index b8f001b5398..1dbdd7d374f 100644 --- a/server/join/join_test.go +++ b/server/join/join_test.go @@ -23,21 +23,10 @@ import ( "github.com/tikv/pd/server" ) -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} - // A PD joins itself. 
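// The config's Join URL is set to its own advertise URL below, so PrepareJoinCluster is expected to return an error.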
func TestPDJoinsItself(t *testing.T) { re := require.New(t) - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) defer testutil.CleanServer(cfg.DataDir) cfg.Join = cfg.AdvertiseClientUrls re.Error(PrepareJoinCluster(cfg)) diff --git a/server/server_test.go b/server/server_test.go index 58f572fb2df..f520314a5b1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -63,22 +63,11 @@ func (suite *leaderServerTestSuite) mustWaitLeader(svrs []*Server) *Server { return leader } -func (suite *leaderServerTestSuite) checkerWithNilAssert() *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - suite.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - suite.Nil(obtained) - } - return checker -} - func (suite *leaderServerTestSuite) SetupSuite() { suite.ctx, suite.cancel = context.WithCancel(context.Background()) suite.svrs = make(map[string]*Server) - cfgs := NewTestMultiConfig(suite.checkerWithNilAssert(), 3) + cfgs := NewTestMultiConfig(assertutil.CheckerWithNilAssert(suite.Require()), 3) ch := make(chan *Server, 3) for i := 0; i < 3; i++ { @@ -153,7 +142,7 @@ func (suite *leaderServerTestSuite) newTestServersWithCfgs(ctx context.Context, func (suite *leaderServerTestSuite) TestCheckClusterID() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cfgs := NewTestMultiConfig(suite.checkerWithNilAssert(), 2) + cfgs := NewTestMultiConfig(assertutil.CheckerWithNilAssert(suite.Require()), 2) for i, cfg := range cfgs { cfg.DataDir = fmt.Sprintf("/tmp/test_pd_check_clusterID_%d", i) // Clean up before testing. @@ -209,7 +198,7 @@ func (suite *leaderServerTestSuite) TestRegisterServerHandler() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) @@ -248,7 +237,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderForwarded() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) @@ -291,7 +280,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderXReal() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) @@ -334,7 +323,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 003b4f73c32..8e67cfb4949 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -690,7 +690,7 @@ func TestClientTestSuite(t *testing.T) { func (suite *clientTestSuite) SetupSuite() { var err error re := suite.Require() - suite.srv, suite.cleanup, err = 
server.NewTestServer(suite.checkerWithNilAssert()) + suite.srv, suite.cleanup, err = server.NewTestServer(assertutil.CheckerWithNilAssert(re)) suite.NoError(err) suite.grpcPDClient = testutil.MustNewGrpcClient(re, suite.srv.GetAddr()) suite.grpcSvr = &server.GrpcServer{Server: suite.srv} @@ -728,17 +728,6 @@ func (suite *clientTestSuite) TearDownSuite() { suite.cleanup() } -func (suite *clientTestSuite) checkerWithNilAssert() *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - suite.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - suite.Nil(obtained) - } - return checker -} - func (suite *clientTestSuite) mustWaitLeader(svrs map[string]*server.Server) *server.Server { for i := 0; i < 500; i++ { for _, s := range svrs { diff --git a/tests/pdctl/global_test.go b/tests/pdctl/global_test.go index c182c739403..a13fee11441 100644 --- a/tests/pdctl/global_test.go +++ b/tests/pdctl/global_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/log" "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" cmd "github.com/tikv/pd/tools/pd-ctl/pdctl" @@ -47,7 +48,7 @@ func TestSendAndGetComponent(t *testing.T) { } return mux, info, nil } - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) ctx, cancel := context.WithCancel(context.Background()) svr, err := server.CreateServer(ctx, cfg, handler) re.NoError(err) diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go index 9a6adf566a3..5691dde66ca 100644 --- a/tests/pdctl/helper.go +++ b/tests/pdctl/helper.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/spf13/cobra" "github.com/stretchr/testify/require" - "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" @@ -122,14 +121,3 @@ func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID, re.NoError(err) return r } - -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 552a2d0c221..229b4756045 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -42,17 +42,6 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) 
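// goleak.VerifyTestMain runs the package's tests and then fails the run if goroutines are still alive afterwards (testutil.LeakOptions presumably supplies the ignore list).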
} -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} - func TestMemberDelete(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) @@ -314,7 +303,7 @@ func TestGetLeader(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) wg := &sync.WaitGroup{} wg.Add(1) done := make(chan bool) From 37cb7e405979fb28f46f0ba696cba283d55ffa67 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 24 Jun 2022 10:28:37 +0800 Subject: [PATCH 76/82] schedule: migrate test framework to testify (#5196) ref tikv/pd#4813 Signed-off-by: lhy1024 Co-authored-by: Ti Chi Robot --- server/api/label_test.go | 34 +- server/schedule/healthy_test.go | 41 +- server/schedule/operator_controller_test.go | 395 ++++++++++---------- server/schedule/region_scatterer_test.go | 133 +++---- server/schedule/region_splitter_test.go | 65 ++-- server/schedule/waiting_operator_test.go | 27 +- tests/server/config/config_test.go | 2 +- 7 files changed, 350 insertions(+), 347 deletions(-) diff --git a/server/api/label_test.go b/server/api/label_test.go index b9503871a5a..6729abe45ea 100644 --- a/server/api/label_test.go +++ b/server/api/label_test.go @@ -215,7 +215,7 @@ func (suite *strictlyLabelsStoreTestSuite) SetupSuite() { } func (suite *strictlyLabelsStoreTestSuite) TestStoreMatch() { - cases := []struct { + testCases := []struct { store *metapb.Store valid bool expectError string @@ -276,21 +276,21 @@ func (suite *strictlyLabelsStoreTestSuite) TestStoreMatch() { }, } - for _, t := range cases { + for _, testCase := range testCases { _, err := suite.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ Header: &pdpb.RequestHeader{ClusterId: suite.svr.ClusterID()}, Store: &metapb.Store{ - Id: t.store.Id, - Address: fmt.Sprintf("tikv%d", t.store.Id), - State: t.store.State, - Labels: t.store.Labels, - Version: t.store.Version, + Id: testCase.store.Id, + Address: fmt.Sprintf("tikv%d", testCase.store.Id), + State: testCase.store.State, + Labels: testCase.store.Labels, + Version: testCase.store.Version, }, }) - if t.valid { + if testCase.valid { suite.NoError(err) } else { - suite.Contains(err.Error(), t.expectError) + suite.Contains(err.Error(), testCase.expectError) } } @@ -300,21 +300,21 @@ func (suite *strictlyLabelsStoreTestSuite) TestStoreMatch() { fmt.Sprintf("%s/config", suite.urlPrefix), []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(suite.Require()))) - for _, t := range cases { + for _, testCase := range testCases { _, err := suite.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ Header: &pdpb.RequestHeader{ClusterId: suite.svr.ClusterID()}, Store: &metapb.Store{ - Id: t.store.Id, - Address: fmt.Sprintf("tikv%d", t.store.Id), - State: t.store.State, - Labels: t.store.Labels, - Version: t.store.Version, + Id: testCase.store.Id, + Address: fmt.Sprintf("tikv%d", testCase.store.Id), + State: testCase.store.State, + Labels: testCase.store.Labels, + Version: testCase.store.Version, }, }) - if t.valid { + if testCase.valid { suite.NoError(err) } else { - suite.Contains(err.Error(), t.expectError) + suite.Contains(err.Error(), testCase.expectError) } } } diff --git 
a/server/schedule/healthy_test.go b/server/schedule/healthy_test.go index 8adb69d9ede..74b3d572fd5 100644 --- a/server/schedule/healthy_test.go +++ b/server/schedule/healthy_test.go @@ -16,31 +16,20 @@ package schedule import ( "context" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" ) -var _ = Suite(&testRegionHealthySuite{}) - -type testRegionHealthySuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testRegionHealthySuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testRegionHealthySuite) TearDownSuite(c *C) { - s.cancel() -} - -func (s *testRegionHealthySuite) TestIsRegionHealthy(c *C) { +func TestIsRegionHealthy(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() peers := func(ids ...uint64) []*metapb.Peer { var peers []*metapb.Peer for _, id := range ids { @@ -70,7 +59,7 @@ func (s *testRegionHealthySuite) TestIsRegionHealthy(c *C) { } // healthy only check down peer and pending peer - cases := []testCase{ + testCases := []testCase{ {region(peers(1, 2, 3)), true, true, true, true, true, true}, {region(peers(1, 2, 3), core.WithPendingPeers(peers(1))), false, true, true, false, true, true}, {region(peers(1, 2, 3), core.WithLearners(peers(1))), true, true, false, true, true, false}, @@ -80,19 +69,19 @@ func (s *testRegionHealthySuite) TestIsRegionHealthy(c *C) { } opt := config.NewTestOptions() - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(ctx, opt) tc.AddRegionStore(1, 1) tc.AddRegionStore(2, 1) tc.AddRegionStore(3, 1) tc.AddRegionStore(4, 1) - for _, t := range cases { + for _, testCase := range testCases { tc.SetEnablePlacementRules(false) - c.Assert(IsRegionHealthy(t.region), Equals, t.healthy1) - c.Assert(IsRegionHealthyAllowPending(t.region), Equals, t.healthyAllowPending1) - c.Assert(IsRegionReplicated(tc, t.region), Equals, t.replicated1) + re.Equal(testCase.healthy1, IsRegionHealthy(testCase.region)) + re.Equal(testCase.healthyAllowPending1, IsRegionHealthyAllowPending(testCase.region)) + re.Equal(testCase.replicated1, IsRegionReplicated(tc, testCase.region)) tc.SetEnablePlacementRules(true) - c.Assert(IsRegionHealthy(t.region), Equals, t.healthy2) - c.Assert(IsRegionHealthyAllowPending(t.region), Equals, t.healthyAllowPending2) - c.Assert(IsRegionReplicated(tc, t.region), Equals, t.replicated2) + re.Equal(testCase.healthy2, IsRegionHealthy(testCase.region)) + re.Equal(testCase.healthyAllowPending2, IsRegionHealthyAllowPending(testCase.region)) + re.Equal(testCase.replicated2, IsRegionReplicated(tc, testCase.region)) } } diff --git a/server/schedule/operator_controller_test.go b/server/schedule/operator_controller_test.go index acb3553fa76..cb8f8470622 100644 --- a/server/schedule/operator_controller_test.go +++ b/server/schedule/operator_controller_test.go @@ -23,10 +23,10 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" @@ -36,31 +36,31 @@ import ( "github.com/tikv/pd/server/schedule/operator" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testOperatorControllerSuite{}) +type operatorControllerTestSuite struct { + suite.Suite -type testOperatorControllerSuite struct { ctx context.Context cancel context.CancelFunc } -func (t *testOperatorControllerSuite) SetUpSuite(c *C) { - t.ctx, t.cancel = context.WithCancel(context.Background()) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)"), IsNil) +func TestOperatorControllerTestSuite(t *testing.T) { + suite.Run(t, new(operatorControllerTestSuite)) } -func (t *testOperatorControllerSuite) TearDownSuite(c *C) { - t.cancel() +func (suite *operatorControllerTestSuite) SetupSuite() { + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)")) +} + +func (suite *operatorControllerTestSuite) TearDownSuite() { + suite.cancel() } // issue #1338 -func (t *testOperatorControllerSuite) TestGetOpInfluence(c *C) { +func (suite *operatorControllerTestSuite) TestGetOpInfluence() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - oc := NewOperatorController(t.ctx, tc, nil) + tc := mockcluster.NewCluster(suite.ctx, opt) + oc := NewOperatorController(suite.ctx, tc, nil) tc.AddLeaderStore(2, 1) tc.AddLeaderRegion(1, 1, 2) tc.AddLeaderRegion(2, 1, 2) @@ -69,21 +69,22 @@ func (t *testOperatorControllerSuite) TestGetOpInfluence(c *C) { } op1 := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, steps...) op2 := operator.NewTestOperator(2, &metapb.RegionEpoch{}, operator.OpRegion, steps...) 
- c.Assert(op1.Start(), IsTrue) + suite.True(op1.Start()) oc.SetOperator(op1) - c.Assert(op2.Start(), IsTrue) + suite.True(op2.Start()) oc.SetOperator(op2) + re := suite.Require() go func(ctx context.Context) { - checkRemoveOperatorSuccess(c, oc, op1) + suite.checkRemoveOperatorSuccess(oc, op1) for { select { case <-ctx.Done(): return default: - c.Assert(oc.RemoveOperator(op1), IsFalse) + re.False(oc.RemoveOperator(op1)) } } - }(t.ctx) + }(suite.ctx) go func(ctx context.Context) { for { select { @@ -93,16 +94,16 @@ func (t *testOperatorControllerSuite) TestGetOpInfluence(c *C) { oc.GetOpInfluence(tc) } } - }(t.ctx) + }(suite.ctx) time.Sleep(1 * time.Second) - c.Assert(oc.GetOperator(2), NotNil) + suite.NotNil(oc.GetOperator(2)) } -func (t *testOperatorControllerSuite) TestOperatorStatus(c *C) { +func (suite *operatorControllerTestSuite) TestOperatorStatus() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, tc.ID, tc, false /* no need to run */) - oc := NewOperatorController(t.ctx, tc, stream) + tc := mockcluster.NewCluster(suite.ctx, opt) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := NewOperatorController(suite.ctx, tc, stream) tc.AddLeaderStore(1, 2) tc.AddLeaderStore(2, 0) tc.AddLeaderRegion(1, 1, 2) @@ -115,29 +116,29 @@ func (t *testOperatorControllerSuite) TestOperatorStatus(c *C) { region2 := tc.GetRegion(2) op1 := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, steps...) op2 := operator.NewTestOperator(2, &metapb.RegionEpoch{}, operator.OpRegion, steps...) - c.Assert(op1.Start(), IsTrue) + suite.True(op1.Start()) oc.SetOperator(op1) - c.Assert(op2.Start(), IsTrue) + suite.True(op2.Start()) oc.SetOperator(op2) - c.Assert(oc.GetOperatorStatus(1).Status, Equals, pdpb.OperatorStatus_RUNNING) - c.Assert(oc.GetOperatorStatus(2).Status, Equals, pdpb.OperatorStatus_RUNNING) + suite.Equal(pdpb.OperatorStatus_RUNNING, oc.GetOperatorStatus(1).Status) + suite.Equal(pdpb.OperatorStatus_RUNNING, oc.GetOperatorStatus(2).Status) operator.SetOperatorStatusReachTime(op1, operator.STARTED, time.Now().Add(-10*time.Minute)) region2 = ApplyOperatorStep(region2, op2) tc.PutRegion(region2) oc.Dispatch(region1, "test") oc.Dispatch(region2, "test") - c.Assert(oc.GetOperatorStatus(1).Status, Equals, pdpb.OperatorStatus_TIMEOUT) - c.Assert(oc.GetOperatorStatus(2).Status, Equals, pdpb.OperatorStatus_RUNNING) + suite.Equal(pdpb.OperatorStatus_TIMEOUT, oc.GetOperatorStatus(1).Status) + suite.Equal(pdpb.OperatorStatus_RUNNING, oc.GetOperatorStatus(2).Status) ApplyOperator(tc, op2) oc.Dispatch(region2, "test") - c.Assert(oc.GetOperatorStatus(2).Status, Equals, pdpb.OperatorStatus_SUCCESS) + suite.Equal(pdpb.OperatorStatus_SUCCESS, oc.GetOperatorStatus(2).Status) } -func (t *testOperatorControllerSuite) TestFastFailOperator(c *C) { +func (suite *operatorControllerTestSuite) TestFastFailOperator() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, tc.ID, tc, false /* no need to run */) - oc := NewOperatorController(t.ctx, tc, stream) + tc := mockcluster.NewCluster(suite.ctx, opt) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := NewOperatorController(suite.ctx, tc, stream) tc.AddLeaderStore(1, 2) tc.AddLeaderStore(2, 0) tc.AddLeaderStore(3, 0) @@ -148,30 +149,30 @@ func (t *testOperatorControllerSuite) TestFastFailOperator(c *C) { } 
region := tc.GetRegion(1) op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, steps...) - c.Assert(op.Start(), IsTrue) + suite.True(op.Start()) oc.SetOperator(op) oc.Dispatch(region, "test") - c.Assert(oc.GetOperatorStatus(1).Status, Equals, pdpb.OperatorStatus_RUNNING) + suite.Equal(pdpb.OperatorStatus_RUNNING, oc.GetOperatorStatus(1).Status) // change the leader region = region.Clone(core.WithLeader(region.GetPeer(2))) oc.Dispatch(region, DispatchFromHeartBeat) - c.Assert(op.Status(), Equals, operator.CANCELED) - c.Assert(oc.GetOperator(region.GetID()), IsNil) + suite.Equal(operator.CANCELED, op.Status()) + suite.Nil(oc.GetOperator(region.GetID())) // transfer leader to an illegal store. op = operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.TransferLeader{ToStore: 5}) oc.SetOperator(op) oc.Dispatch(region, DispatchFromHeartBeat) - c.Assert(op.Status(), Equals, operator.CANCELED) - c.Assert(oc.GetOperator(region.GetID()), IsNil) + suite.Equal(operator.CANCELED, op.Status()) + suite.Nil(oc.GetOperator(region.GetID())) } // Issue 3353 -func (t *testOperatorControllerSuite) TestFastFailWithUnhealthyStore(c *C) { +func (suite *operatorControllerTestSuite) TestFastFailWithUnhealthyStore() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, tc.ID, tc, false /* no need to run */) - oc := NewOperatorController(t.ctx, tc, stream) + tc := mockcluster.NewCluster(suite.ctx, opt) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := NewOperatorController(suite.ctx, tc, stream) tc.AddLeaderStore(1, 2) tc.AddLeaderStore(2, 0) tc.AddLeaderStore(3, 0) @@ -180,17 +181,18 @@ func (t *testOperatorControllerSuite) TestFastFailWithUnhealthyStore(c *C) { steps := []operator.OpStep{operator.TransferLeader{ToStore: 2}} op := operator.NewTestOperator(1, region.GetRegionEpoch(), operator.OpLeader, steps...) 
oc.SetOperator(op) - c.Assert(oc.checkStaleOperator(op, steps[0], region), IsFalse) + suite.False(oc.checkStaleOperator(op, steps[0], region)) tc.SetStoreDown(2) - c.Assert(oc.checkStaleOperator(op, steps[0], region), IsTrue) + suite.True(oc.checkStaleOperator(op, steps[0], region)) } -func (t *testOperatorControllerSuite) TestCheckAddUnexpectedStatus(c *C) { - c.Assert(failpoint.Disable("github.com/tikv/pd/server/schedule/unexpectedOperator"), IsNil) +func (suite *operatorControllerTestSuite) TestCheckAddUnexpectedStatus() { + suite.NoError(failpoint.Disable("github.com/tikv/pd/server/schedule/unexpectedOperator")) + opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, tc.ID, tc, false /* no need to run */) - oc := NewOperatorController(t.ctx, tc, stream) + tc := mockcluster.NewCluster(suite.ctx, opt) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := NewOperatorController(suite.ctx, tc, stream) tc.AddLeaderStore(1, 0) tc.AddLeaderStore(2, 1) tc.AddLeaderRegion(1, 2, 1) @@ -203,58 +205,59 @@ func (t *testOperatorControllerSuite) TestCheckAddUnexpectedStatus(c *C) { { // finished op op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.TransferLeader{ToStore: 2}) - c.Assert(oc.checkAddOperator(false, op), IsTrue) + suite.True(oc.checkAddOperator(false, op)) op.Start() - c.Assert(oc.checkAddOperator(false, op), IsFalse) // started - c.Assert(op.Check(region1), IsNil) - c.Assert(op.Status(), Equals, operator.SUCCESS) - c.Assert(oc.checkAddOperator(false, op), IsFalse) // success + suite.False(oc.checkAddOperator(false, op)) // started + suite.Nil(op.Check(region1)) + + suite.Equal(operator.SUCCESS, op.Status()) + suite.False(oc.checkAddOperator(false, op)) // success } { // finished op canceled op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.TransferLeader{ToStore: 2}) - c.Assert(oc.checkAddOperator(false, op), IsTrue) - c.Assert(op.Cancel(), IsTrue) - c.Assert(oc.checkAddOperator(false, op), IsFalse) + suite.True(oc.checkAddOperator(false, op)) + suite.True(op.Cancel()) + suite.False(oc.checkAddOperator(false, op)) } { // finished op replaced op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.TransferLeader{ToStore: 2}) - c.Assert(oc.checkAddOperator(false, op), IsTrue) - c.Assert(op.Start(), IsTrue) - c.Assert(op.Replace(), IsTrue) - c.Assert(oc.checkAddOperator(false, op), IsFalse) + suite.True(oc.checkAddOperator(false, op)) + suite.True(op.Start()) + suite.True(op.Replace()) + suite.False(oc.checkAddOperator(false, op)) } { // finished op expired op1 := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.TransferLeader{ToStore: 2}) op2 := operator.NewTestOperator(2, &metapb.RegionEpoch{}, operator.OpRegion, operator.TransferLeader{ToStore: 1}) - c.Assert(oc.checkAddOperator(false, op1, op2), IsTrue) + suite.True(oc.checkAddOperator(false, op1, op2)) operator.SetOperatorStatusReachTime(op1, operator.CREATED, time.Now().Add(-operator.OperatorExpireTime)) operator.SetOperatorStatusReachTime(op2, operator.CREATED, time.Now().Add(-operator.OperatorExpireTime)) - c.Assert(oc.checkAddOperator(false, op1, op2), IsFalse) - c.Assert(op1.Status(), Equals, operator.EXPIRED) - c.Assert(op2.Status(), Equals, operator.EXPIRED) + suite.False(oc.checkAddOperator(false, op1, op2)) + suite.Equal(operator.EXPIRED, op1.Status()) + 
suite.Equal(operator.EXPIRED, op2.Status()) } // finished op never timeout { // unfinished op timeout op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, steps...) - c.Assert(oc.checkAddOperator(false, op), IsTrue) + suite.True(oc.checkAddOperator(false, op)) op.Start() operator.SetOperatorStatusReachTime(op, operator.STARTED, time.Now().Add(-operator.SlowOperatorWaitTime)) - c.Assert(op.CheckTimeout(), IsTrue) - c.Assert(oc.checkAddOperator(false, op), IsFalse) + suite.True(op.CheckTimeout()) + suite.False(oc.checkAddOperator(false, op)) } } // issue #1716 -func (t *testOperatorControllerSuite) TestConcurrentRemoveOperator(c *C) { +func (suite *operatorControllerTestSuite) TestConcurrentRemoveOperator() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, tc.ID, tc, false /* no need to run */) - oc := NewOperatorController(t.ctx, tc, stream) + tc := mockcluster.NewCluster(suite.ctx, opt) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := NewOperatorController(suite.ctx, tc, stream) tc.AddLeaderStore(1, 0) tc.AddLeaderStore(2, 1) tc.AddLeaderRegion(1, 2, 1) @@ -268,10 +271,10 @@ func (t *testOperatorControllerSuite) TestConcurrentRemoveOperator(c *C) { // unfinished op with high priority op2 := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion|operator.OpAdmin, steps...) - c.Assert(op1.Start(), IsTrue) + suite.True(op1.Start()) oc.SetOperator(op1) - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/concurrentRemoveOperator", "return(true)"), IsNil) + suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/concurrentRemoveOperator", "return(true)")) var wg sync.WaitGroup wg.Add(2) @@ -283,19 +286,19 @@ func (t *testOperatorControllerSuite) TestConcurrentRemoveOperator(c *C) { time.Sleep(50 * time.Millisecond) success := oc.AddOperator(op2) // If the assert failed before wg.Done, the test will be blocked. - defer c.Assert(success, IsTrue) + defer suite.True(success) wg.Done() }() wg.Wait() - c.Assert(oc.GetOperator(1), Equals, op2) + suite.Equal(op2, oc.GetOperator(1)) } -func (t *testOperatorControllerSuite) TestPollDispatchRegion(c *C) { +func (suite *operatorControllerTestSuite) TestPollDispatchRegion() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, tc.ID, tc, false /* no need to run */) - oc := NewOperatorController(t.ctx, tc, stream) + tc := mockcluster.NewCluster(suite.ctx, opt) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := NewOperatorController(suite.ctx, tc, stream) tc.AddLeaderStore(1, 2) tc.AddLeaderStore(2, 1) tc.AddLeaderRegion(1, 1, 2) @@ -314,13 +317,13 @@ func (t *testOperatorControllerSuite) TestPollDispatchRegion(c *C) { region4 := tc.GetRegion(4) // Adds operator and pushes to the notifier queue. 
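// op3 has no corresponding region and op4 will already be finished, so both are expected to be dropped when the queue is polled below.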
{ - c.Assert(op1.Start(), IsTrue) + suite.True(op1.Start()) oc.SetOperator(op1) - c.Assert(op3.Start(), IsTrue) + suite.True(op3.Start()) oc.SetOperator(op3) - c.Assert(op4.Start(), IsTrue) + suite.True(op4.Start()) oc.SetOperator(op4) - c.Assert(op2.Start(), IsTrue) + suite.True(op2.Start()) oc.SetOperator(op2) heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op1, time: time.Now().Add(100 * time.Millisecond)}) heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op3, time: time.Now().Add(300 * time.Millisecond)}) @@ -329,45 +332,46 @@ func (t *testOperatorControllerSuite) TestPollDispatchRegion(c *C) { } // first poll got nil r, next := oc.pollNeedDispatchRegion() - c.Assert(r, IsNil) - c.Assert(next, IsFalse) + suite.Nil(r) + suite.False(next) // after wait 100 millisecond, the region1 need to dispatch, but not region2. time.Sleep(100 * time.Millisecond) r, next = oc.pollNeedDispatchRegion() - c.Assert(r, NotNil) - c.Assert(next, IsTrue) - c.Assert(r.GetID(), Equals, region1.GetID()) + suite.NotNil(r) + suite.True(next) + suite.Equal(region1.GetID(), r.GetID()) // find op3 with nil region, remove it - c.Assert(oc.GetOperator(3), NotNil) + suite.NotNil(oc.GetOperator(3)) + r, next = oc.pollNeedDispatchRegion() - c.Assert(r, IsNil) - c.Assert(next, IsTrue) - c.Assert(oc.GetOperator(3), IsNil) + suite.Nil(r) + suite.True(next) + suite.Nil(oc.GetOperator(3)) // find op4 finished r, next = oc.pollNeedDispatchRegion() - c.Assert(r, NotNil) - c.Assert(next, IsTrue) - c.Assert(r.GetID(), Equals, region4.GetID()) + suite.NotNil(r) + suite.True(next) + suite.Equal(region4.GetID(), r.GetID()) // after waiting 500 milliseconds, the region2 need to dispatch time.Sleep(400 * time.Millisecond) r, next = oc.pollNeedDispatchRegion() - c.Assert(r, NotNil) - c.Assert(next, IsTrue) - c.Assert(r.GetID(), Equals, region2.GetID()) + suite.NotNil(r) + suite.True(next) + suite.Equal(region2.GetID(), r.GetID()) r, next = oc.pollNeedDispatchRegion() - c.Assert(r, IsNil) - c.Assert(next, IsFalse) + suite.Nil(r) + suite.False(next) } -func (t *testOperatorControllerSuite) TestStoreLimit(c *C) { +func (suite *operatorControllerTestSuite) TestStoreLimit() { opt := config.NewTestOptions() - tc := mockcluster.NewCluster(t.ctx, opt) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, tc.ID, tc, false /* no need to run */) - oc := NewOperatorController(t.ctx, tc, stream) + tc := mockcluster.NewCluster(suite.ctx, opt) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, tc.ID, tc, false /* no need to run */) + oc := NewOperatorController(suite.ctx, tc, stream) tc.AddLeaderStore(1, 0) tc.UpdateLeaderCount(1, 1000) tc.AddLeaderStore(2, 0) @@ -380,61 +384,61 @@ func (t *testOperatorControllerSuite) TestStoreLimit(c *C) { tc.SetStoreLimit(2, storelimit.AddPeer, 60) for i := uint64(1); i <= 5; i++ { op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.AddPeer{ToStore: 2, PeerID: i}) - c.Assert(oc.AddOperator(op), IsTrue) - checkRemoveOperatorSuccess(c, oc, op) + suite.True(oc.AddOperator(op)) + suite.checkRemoveOperatorSuccess(oc, op) } op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.AddPeer{ToStore: 2, PeerID: 1}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + suite.False(oc.AddOperator(op)) + suite.False(oc.RemoveOperator(op)) tc.SetStoreLimit(2, storelimit.AddPeer, 120) for i := uint64(1); i <= 10; i++ { op = operator.NewTestOperator(i, &metapb.RegionEpoch{}, operator.OpRegion, operator.AddPeer{ToStore: 2, 
PeerID: i}) - c.Assert(oc.AddOperator(op), IsTrue) - checkRemoveOperatorSuccess(c, oc, op) + suite.True(oc.AddOperator(op)) + suite.checkRemoveOperatorSuccess(oc, op) } tc.SetAllStoresLimit(storelimit.AddPeer, 60) for i := uint64(1); i <= 5; i++ { op = operator.NewTestOperator(i, &metapb.RegionEpoch{}, operator.OpRegion, operator.AddPeer{ToStore: 2, PeerID: i}) - c.Assert(oc.AddOperator(op), IsTrue) - checkRemoveOperatorSuccess(c, oc, op) + suite.True(oc.AddOperator(op)) + suite.checkRemoveOperatorSuccess(oc, op) } op = operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.AddPeer{ToStore: 2, PeerID: 1}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + suite.False(oc.AddOperator(op)) + suite.False(oc.RemoveOperator(op)) tc.SetStoreLimit(2, storelimit.RemovePeer, 60) for i := uint64(1); i <= 5; i++ { op := operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsTrue) - checkRemoveOperatorSuccess(c, oc, op) + suite.True(oc.AddOperator(op)) + suite.checkRemoveOperatorSuccess(oc, op) } op = operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + suite.False(oc.AddOperator(op)) + suite.False(oc.RemoveOperator(op)) tc.SetStoreLimit(2, storelimit.RemovePeer, 120) for i := uint64(1); i <= 10; i++ { op = operator.NewTestOperator(i, &metapb.RegionEpoch{}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsTrue) - checkRemoveOperatorSuccess(c, oc, op) + suite.True(oc.AddOperator(op)) + suite.checkRemoveOperatorSuccess(oc, op) } tc.SetAllStoresLimit(storelimit.RemovePeer, 60) for i := uint64(1); i <= 5; i++ { op = operator.NewTestOperator(i, &metapb.RegionEpoch{}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsTrue) - checkRemoveOperatorSuccess(c, oc, op) + suite.True(oc.AddOperator(op)) + suite.checkRemoveOperatorSuccess(oc, op) } op = operator.NewTestOperator(1, &metapb.RegionEpoch{}, operator.OpRegion, operator.RemovePeer{FromStore: 2}) - c.Assert(oc.AddOperator(op), IsFalse) - c.Assert(oc.RemoveOperator(op), IsFalse) + suite.False(oc.AddOperator(op)) + suite.False(oc.RemoveOperator(op)) } // #1652 -func (t *testOperatorControllerSuite) TestDispatchOutdatedRegion(c *C) { - cluster := mockcluster.NewCluster(t.ctx, config.NewTestOptions()) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, cluster.ID, cluster, false /* no need to run */) - controller := NewOperatorController(t.ctx, cluster, stream) +func (suite *operatorControllerTestSuite) TestDispatchOutdatedRegion() { + cluster := mockcluster.NewCluster(suite.ctx, config.NewTestOptions()) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, cluster.ID, cluster, false /* no need to run */) + controller := NewOperatorController(suite.ctx, cluster, stream) cluster.AddLeaderStore(1, 2) cluster.AddLeaderStore(2, 0) @@ -446,45 +450,45 @@ func (t *testOperatorControllerSuite) TestDispatchOutdatedRegion(c *C) { } op := operator.NewTestOperator(1, &metapb.RegionEpoch{ConfVer: 0, Version: 0}, operator.OpRegion, steps...) 
- c.Assert(controller.AddOperator(op), IsTrue) - c.Assert(stream.MsgLength(), Equals, 1) + suite.True(controller.AddOperator(op)) + suite.Equal(1, stream.MsgLength()) // report the result of transferring leader region := cluster.MockRegionInfo(1, 2, []uint64{1, 2}, []uint64{}, &metapb.RegionEpoch{ConfVer: 0, Version: 0}) controller.Dispatch(region, DispatchFromHeartBeat) - c.Assert(op.ConfVerChanged(region), Equals, uint64(0)) - c.Assert(stream.MsgLength(), Equals, 2) + suite.Equal(uint64(0), op.ConfVerChanged(region)) + suite.Equal(2, stream.MsgLength()) // report the result of removing peer region = cluster.MockRegionInfo(1, 2, []uint64{2}, []uint64{}, &metapb.RegionEpoch{ConfVer: 0, Version: 0}) controller.Dispatch(region, DispatchFromHeartBeat) - c.Assert(op.ConfVerChanged(region), Equals, uint64(1)) - c.Assert(stream.MsgLength(), Equals, 2) + suite.Equal(uint64(1), op.ConfVerChanged(region)) + suite.Equal(2, stream.MsgLength()) // add and dispatch op again, the op should be stale op = operator.NewTestOperator(1, &metapb.RegionEpoch{ConfVer: 0, Version: 0}, operator.OpRegion, steps...) - c.Assert(controller.AddOperator(op), IsTrue) - c.Assert(op.ConfVerChanged(region), Equals, uint64(0)) - c.Assert(stream.MsgLength(), Equals, 3) + suite.True(controller.AddOperator(op)) + suite.Equal(uint64(0), op.ConfVerChanged(region)) + suite.Equal(3, stream.MsgLength()) // report region with an abnormal confver region = cluster.MockRegionInfo(1, 1, []uint64{1, 2}, []uint64{}, &metapb.RegionEpoch{ConfVer: 1, Version: 0}) controller.Dispatch(region, DispatchFromHeartBeat) - c.Assert(op.ConfVerChanged(region), Equals, uint64(0)) + suite.Equal(uint64(0), op.ConfVerChanged(region)) // no new step - c.Assert(stream.MsgLength(), Equals, 3) + suite.Equal(3, stream.MsgLength()) } -func (t *testOperatorControllerSuite) TestDispatchUnfinishedStep(c *C) { - cluster := mockcluster.NewCluster(t.ctx, config.NewTestOptions()) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, cluster.ID, cluster, false /* no need to run */) - controller := NewOperatorController(t.ctx, cluster, stream) +func (suite *operatorControllerTestSuite) TestDispatchUnfinishedStep() { + cluster := mockcluster.NewCluster(suite.ctx, config.NewTestOptions()) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, cluster.ID, cluster, false /* no need to run */) + controller := NewOperatorController(suite.ctx, cluster, stream) // Create a new region with epoch(0, 0) // the region has two peers with its peer id allocated incrementally. @@ -517,8 +521,8 @@ func (t *testOperatorControllerSuite) TestDispatchUnfinishedStep(c *C) { for _, steps := range testSteps { // Create an operator op := operator.NewTestOperator(1, epoch, operator.OpRegion, steps...) - c.Assert(controller.AddOperator(op), IsTrue) - c.Assert(stream.MsgLength(), Equals, 1) + suite.True(controller.AddOperator(op)) + suite.Equal(1, stream.MsgLength()) // Create region2 which is cloned from the original region. 
// region2 has peer 2 in pending state, so the AddPeer step @@ -530,62 +534,64 @@ func (t *testOperatorControllerSuite) TestDispatchUnfinishedStep(c *C) { }), core.WithIncConfVer(), ) - c.Assert(region2.GetPendingPeers(), NotNil) - c.Assert(steps[0].IsFinish(region2), IsFalse) + suite.NotNil(region2.GetPendingPeers()) + + suite.False(steps[0].IsFinish(region2)) controller.Dispatch(region2, DispatchFromHeartBeat) // In this case, the conf version has been changed, but the // peer added is in pending state, the operator should not be // removed by the stale checker - c.Assert(op.ConfVerChanged(region2), Equals, uint64(1)) - c.Assert(controller.GetOperator(1), NotNil) + suite.Equal(uint64(1), op.ConfVerChanged(region2)) + suite.NotNil(controller.GetOperator(1)) + // The operator is valid yet, but the step should not be sent // again, because it is in pending state, so the message channel // should not be increased - c.Assert(stream.MsgLength(), Equals, 1) + suite.Equal(1, stream.MsgLength()) // Finish the step by clearing the pending state region3 := region.Clone( core.WithAddPeer(&metapb.Peer{Id: 3, StoreId: 3, Role: metapb.PeerRole_Learner}), core.WithIncConfVer(), ) - c.Assert(steps[0].IsFinish(region3), IsTrue) + suite.True(steps[0].IsFinish(region3)) controller.Dispatch(region3, DispatchFromHeartBeat) - c.Assert(op.ConfVerChanged(region3), Equals, uint64(1)) - c.Assert(stream.MsgLength(), Equals, 2) + suite.Equal(uint64(1), op.ConfVerChanged(region3)) + suite.Equal(2, stream.MsgLength()) region4 := region3.Clone( core.WithPromoteLearner(3), core.WithIncConfVer(), ) - c.Assert(steps[1].IsFinish(region4), IsTrue) + suite.True(steps[1].IsFinish(region4)) controller.Dispatch(region4, DispatchFromHeartBeat) - c.Assert(op.ConfVerChanged(region4), Equals, uint64(2)) - c.Assert(stream.MsgLength(), Equals, 3) + suite.Equal(uint64(2), op.ConfVerChanged(region4)) + suite.Equal(3, stream.MsgLength()) // Transfer leader region5 := region4.Clone( core.WithLeader(region4.GetStorePeer(3)), ) - c.Assert(steps[2].IsFinish(region5), IsTrue) + suite.True(steps[2].IsFinish(region5)) controller.Dispatch(region5, DispatchFromHeartBeat) - c.Assert(op.ConfVerChanged(region5), Equals, uint64(2)) - c.Assert(stream.MsgLength(), Equals, 4) + suite.Equal(uint64(2), op.ConfVerChanged(region5)) + suite.Equal(4, stream.MsgLength()) // Remove peer region6 := region5.Clone( core.WithRemoveStorePeer(1), core.WithIncConfVer(), ) - c.Assert(steps[3].IsFinish(region6), IsTrue) + suite.True(steps[3].IsFinish(region6)) controller.Dispatch(region6, DispatchFromHeartBeat) - c.Assert(op.ConfVerChanged(region6), Equals, uint64(3)) + suite.Equal(uint64(3), op.ConfVerChanged(region6)) // The Operator has finished, so no message should be sent - c.Assert(stream.MsgLength(), Equals, 4) - c.Assert(controller.GetOperator(1), IsNil) + suite.Equal(4, stream.MsgLength()) + suite.Nil(controller.GetOperator(1)) e := stream.Drain(4) - c.Assert(e, IsNil) + suite.NoError(e) } } @@ -609,17 +615,17 @@ func newRegionInfo(id uint64, startKey, endKey string, size, keys int64, leader ) } -func checkRemoveOperatorSuccess(c *C, oc *OperatorController, op *operator.Operator) { - c.Assert(oc.RemoveOperator(op), IsTrue) - c.Assert(op.IsEnd(), IsTrue) - c.Assert(oc.GetOperatorStatus(op.RegionID()).Operator, DeepEquals, op) +func (suite *operatorControllerTestSuite) checkRemoveOperatorSuccess(oc *OperatorController, op *operator.Operator) { + suite.True(oc.RemoveOperator(op)) + suite.True(op.IsEnd()) + suite.Equal(op, 
oc.GetOperatorStatus(op.RegionID()).Operator) } -func (t *testOperatorControllerSuite) TestAddWaitingOperator(c *C) { +func (suite *operatorControllerTestSuite) TestAddWaitingOperator() { opts := config.NewTestOptions() - cluster := mockcluster.NewCluster(t.ctx, opts) - stream := hbstream.NewTestHeartbeatStreams(t.ctx, cluster.ID, cluster, false /* no need to run */) - controller := NewOperatorController(t.ctx, cluster, stream) + cluster := mockcluster.NewCluster(suite.ctx, opts) + stream := hbstream.NewTestHeartbeatStreams(suite.ctx, cluster.ID, cluster, false /* no need to run */) + controller := NewOperatorController(suite.ctx, cluster, stream) cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) @@ -632,8 +638,9 @@ func (t *testOperatorControllerSuite) TestAddWaitingOperator(c *C) { StoreId: 2, } op, err := operator.CreateAddPeerOperator("add-peer", cluster, region, peer, operator.OpKind(0)) - c.Assert(err, IsNil) - c.Assert(op, NotNil) + suite.NoError(err) + suite.NotNil(op) + return op } @@ -643,21 +650,21 @@ func (t *testOperatorControllerSuite) TestAddWaitingOperator(c *C) { batch = append(batch, addPeerOp(i)) } added := controller.AddWaitingOperator(batch...) - c.Assert(added, Equals, int(cluster.GetSchedulerMaxWaitingOperator())) + suite.Equal(int(cluster.GetSchedulerMaxWaitingOperator()), added) // test adding a batch of operators when some operators will get false in check // and remain operators can be added normally batch = append(batch, addPeerOp(cluster.GetSchedulerMaxWaitingOperator())) added = controller.AddWaitingOperator(batch...) - c.Assert(added, Equals, 1) + suite.Equal(1, added) scheduleCfg := opts.GetScheduleConfig().Clone() scheduleCfg.SchedulerMaxWaitingOperator = 1 opts.SetScheduleConfig(scheduleCfg) batch = append(batch, addPeerOp(100)) added = controller.AddWaitingOperator(batch...) - c.Assert(added, Equals, 1) - c.Assert(controller.operators[uint64(100)], NotNil) + suite.Equal(1, added) + suite.NotNil(controller.operators[uint64(100)]) source := newRegionInfo(101, "1a", "1b", 1, 1, []uint64{101, 1}, []uint64{101, 1}) cluster.PutRegion(source) @@ -665,8 +672,8 @@ func (t *testOperatorControllerSuite) TestAddWaitingOperator(c *C) { cluster.PutRegion(target) ops, err := operator.CreateMergeRegionOperator("merge-region", cluster, source, target, operator.OpMerge) - c.Assert(err, IsNil) - c.Assert(ops, HasLen, 2) + suite.NoError(err) + suite.Len(ops, 2) // test with label schedule=deny labelerManager := cluster.GetRegionLabeler() @@ -677,22 +684,22 @@ func (t *testOperatorControllerSuite) TestAddWaitingOperator(c *C) { Data: []interface{}{map[string]interface{}{"start_key": "1a", "end_key": "1b"}}, }) - c.Assert(labelerManager.ScheduleDisabled(source), IsTrue) + suite.True(labelerManager.ScheduleDisabled(source)) // add operator should be failed since it is labeled with `schedule=deny`. 
- c.Assert(controller.AddWaitingOperator(ops...), Equals, 0) + suite.Equal(0, controller.AddWaitingOperator(ops...)) // add operator should be success without `schedule=deny` labelerManager.DeleteLabelRule("schedulelabel") labelerManager.ScheduleDisabled(source) - c.Assert(labelerManager.ScheduleDisabled(source), IsFalse) + suite.False(labelerManager.ScheduleDisabled(source)) // now there is one operator being allowed to add, if it is a merge operator // both of the pair are allowed ops, err = operator.CreateMergeRegionOperator("merge-region", cluster, source, target, operator.OpMerge) - c.Assert(err, IsNil) - c.Assert(ops, HasLen, 2) - c.Assert(controller.AddWaitingOperator(ops...), Equals, 2) - c.Assert(controller.AddWaitingOperator(ops...), Equals, 0) + suite.NoError(err) + suite.Len(ops, 2) + suite.Equal(2, controller.AddWaitingOperator(ops...)) + suite.Equal(0, controller.AddWaitingOperator(ops...)) // no space left, new operator can not be added. - c.Assert(controller.AddWaitingOperator(addPeerOp(0)), Equals, 0) + suite.Equal(0, controller.AddWaitingOperator(addPeerOp(0))) } diff --git a/server/schedule/region_scatterer_test.go b/server/schedule/region_scatterer_test.go index bfbd99f1e4e..f9034ee103c 100644 --- a/server/schedule/region_scatterer_test.go +++ b/server/schedule/region_scatterer_test.go @@ -19,11 +19,12 @@ import ( "fmt" "math" "math/rand" + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" @@ -59,33 +60,30 @@ func (s *sequencer) next() uint64 { return s.curID } -var _ = Suite(&testScatterRegionSuite{}) - -type testScatterRegionSuite struct{} - -func (s *testScatterRegionSuite) TestScatterRegions(c *C) { - s.scatter(c, 5, 50, true) - s.scatter(c, 5, 500, true) - s.scatter(c, 6, 50, true) - s.scatter(c, 5, 50, false) - s.scatterSpecial(c, 3, 6, 50) - s.scatterSpecial(c, 5, 5, 50) +func TestScatterRegions(t *testing.T) { + re := require.New(t) + scatter(re, 5, 50, true) + scatter(re, 5, 500, true) + scatter(re, 6, 50, true) + scatter(re, 5, 50, false) + scatterSpecial(re, 3, 6, 50) + scatterSpecial(re, 5, 5, 50) } -func (s *testScatterRegionSuite) checkOperator(op *operator.Operator, c *C) { +func checkOperator(re *require.Assertions, op *operator.Operator) { for i := 0; i < op.Len(); i++ { if rp, ok := op.Step(i).(operator.RemovePeer); ok { for j := i + 1; j < op.Len(); j++ { if tr, ok := op.Step(j).(operator.TransferLeader); ok { - c.Assert(rp.FromStore, Not(Equals), tr.FromStore) - c.Assert(rp.FromStore, Not(Equals), tr.ToStore) + re.NotEqual(tr.FromStore, rp.FromStore) + re.NotEqual(tr.ToStore, rp.FromStore) } } } } } -func (s *testScatterRegionSuite) scatter(c *C, numStores, numRegions uint64, useRules bool) { +func scatter(re *require.Assertions, numStores, numRegions uint64, useRules bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -107,7 +105,7 @@ func (s *testScatterRegionSuite) scatter(c *C, numStores, numRegions uint64, use for i := uint64(1); i <= numRegions; i++ { region := tc.GetRegion(i) if op, _ := scatterer.Scatter(region, ""); op != nil { - s.checkOperator(op, c) + checkOperator(re, op) ApplyOperator(tc, op) } } @@ -127,20 +125,20 @@ func (s *testScatterRegionSuite) scatter(c *C, numStores, numRegions uint64, use // Each store should have the same number of peers. 
for _, count := range countPeers { - c.Assert(float64(count), LessEqual, 1.1*float64(numRegions*3)/float64(numStores)) - c.Assert(float64(count), GreaterEqual, 0.9*float64(numRegions*3)/float64(numStores)) + re.LessOrEqual(float64(count), 1.1*float64(numRegions*3)/float64(numStores)) + re.GreaterOrEqual(float64(count), 0.9*float64(numRegions*3)/float64(numStores)) } // Each store should have the same number of leaders. - c.Assert(countPeers, HasLen, int(numStores)) - c.Assert(countLeader, HasLen, int(numStores)) + re.Len(countPeers, int(numStores)) + re.Len(countLeader, int(numStores)) for _, count := range countLeader { - c.Assert(float64(count), LessEqual, 1.1*float64(numRegions)/float64(numStores)) - c.Assert(float64(count), GreaterEqual, 0.9*float64(numRegions)/float64(numStores)) + re.LessOrEqual(float64(count), 1.1*float64(numRegions)/float64(numStores)) + re.GreaterOrEqual(float64(count), 0.9*float64(numRegions)/float64(numStores)) } } -func (s *testScatterRegionSuite) scatterSpecial(c *C, numOrdinaryStores, numSpecialStores, numRegions uint64) { +func scatterSpecial(re *require.Assertions, numOrdinaryStores, numSpecialStores, numRegions uint64) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -156,9 +154,9 @@ func (s *testScatterRegionSuite) scatterSpecial(c *C, numOrdinaryStores, numSpec tc.AddLabelsStore(numOrdinaryStores+i, 0, map[string]string{"engine": "tiflash"}) } tc.SetEnablePlacementRules(true) - c.Assert(tc.RuleManager.SetRule(&placement.Rule{ + re.NoError(tc.RuleManager.SetRule(&placement.Rule{ GroupID: "pd", ID: "learner", Role: placement.Learner, Count: 3, - LabelConstraints: []placement.LabelConstraint{{Key: "engine", Op: placement.In, Values: []string{"tiflash"}}}}), IsNil) + LabelConstraints: []placement.LabelConstraint{{Key: "engine", Op: placement.In, Values: []string{"tiflash"}}}})) // Region 1 has the same distribution with the Region 2, which is used to test selectPeerToReplace. tc.AddRegionWithLearner(1, 1, []uint64{2, 3}, []uint64{numOrdinaryStores + 1, numOrdinaryStores + 2, numOrdinaryStores + 3}) @@ -175,7 +173,7 @@ func (s *testScatterRegionSuite) scatterSpecial(c *C, numOrdinaryStores, numSpec for i := uint64(1); i <= numRegions; i++ { region := tc.GetRegion(i) if op, _ := scatterer.Scatter(region, ""); op != nil { - s.checkOperator(op, c) + checkOperator(re, op) ApplyOperator(tc, op) } } @@ -202,20 +200,21 @@ func (s *testScatterRegionSuite) scatterSpecial(c *C, numOrdinaryStores, numSpec // Each store should have the same number of peers. 
for _, count := range countOrdinaryPeers { - c.Assert(float64(count), LessEqual, 1.1*float64(numRegions*3)/float64(numOrdinaryStores)) - c.Assert(float64(count), GreaterEqual, 0.9*float64(numRegions*3)/float64(numOrdinaryStores)) + re.LessOrEqual(float64(count), 1.1*float64(numRegions*3)/float64(numOrdinaryStores)) + re.GreaterOrEqual(float64(count), 0.9*float64(numRegions*3)/float64(numOrdinaryStores)) } for _, count := range countSpecialPeers { - c.Assert(float64(count), LessEqual, 1.1*float64(numRegions*3)/float64(numSpecialStores)) - c.Assert(float64(count), GreaterEqual, 0.9*float64(numRegions*3)/float64(numSpecialStores)) + re.LessOrEqual(float64(count), 1.1*float64(numRegions*3)/float64(numSpecialStores)) + re.GreaterOrEqual(float64(count), 0.9*float64(numRegions*3)/float64(numSpecialStores)) } for _, count := range countOrdinaryLeaders { - c.Assert(float64(count), LessEqual, 1.1*float64(numRegions)/float64(numOrdinaryStores)) - c.Assert(float64(count), GreaterEqual, 0.9*float64(numRegions)/float64(numOrdinaryStores)) + re.LessOrEqual(float64(count), 1.1*float64(numRegions)/float64(numOrdinaryStores)) + re.GreaterOrEqual(float64(count), 0.9*float64(numRegions)/float64(numOrdinaryStores)) } } -func (s *testScatterRegionSuite) TestStoreLimit(c *C) { +func TestStoreLimit(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -241,12 +240,13 @@ func (s *testScatterRegionSuite) TestStoreLimit(c *C) { for i := uint64(1); i <= 5; i++ { region := tc.GetRegion(i) if op, _ := scatterer.Scatter(region, ""); op != nil { - c.Assert(oc.AddWaitingOperator(op), Equals, 1) + re.Equal(1, oc.AddWaitingOperator(op)) } } } -func (s *testScatterRegionSuite) TestScatterCheck(c *C) { +func TestScatterCheck(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -277,21 +277,21 @@ func (s *testScatterRegionSuite) TestScatterCheck(c *C) { }, } for _, testCase := range testCases { - c.Logf(testCase.name) scatterer := NewRegionScatterer(ctx, tc) _, err := scatterer.Scatter(testCase.checkRegion, "") if testCase.needFix { - c.Assert(err, NotNil) - c.Assert(tc.CheckRegionUnderSuspect(1), IsTrue) + re.Error(err) + re.True(tc.CheckRegionUnderSuspect(1)) } else { - c.Assert(err, IsNil) - c.Assert(tc.CheckRegionUnderSuspect(1), IsFalse) + re.NoError(err) + re.False(tc.CheckRegionUnderSuspect(1)) } tc.ResetSuspectRegions() } } -func (s *testScatterRegionSuite) TestScatterGroupInConcurrency(c *C) { +func TestScatterGroupInConcurrency(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -323,7 +323,6 @@ func (s *testScatterRegionSuite) TestScatterGroupInConcurrency(c *C) { // We send scatter interweave request for each group to simulate scattering multiple region groups in concurrency. 
for _, testCase := range testCases { - c.Logf(testCase.name) scatterer := NewRegionScatterer(ctx, tc) regionID := 1 for i := 0; i < 100; i++ { @@ -349,8 +348,8 @@ func (s *testScatterRegionSuite) TestScatterGroupInConcurrency(c *C) { min = count } } - c.Assert(math.Abs(float64(max)-float64(expected)), LessEqual, delta) - c.Assert(math.Abs(float64(min)-float64(expected)), LessEqual, delta) + re.LessOrEqual(math.Abs(float64(max)-float64(expected)), delta) + re.LessOrEqual(math.Abs(float64(min)-float64(expected)), delta) } } // For leader, we expect each store have about 20 leader for each group @@ -360,7 +359,8 @@ func (s *testScatterRegionSuite) TestScatterGroupInConcurrency(c *C) { } } -func (s *testScatterRegionSuite) TestScattersGroup(c *C) { +func TestScattersGroup(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -389,17 +389,16 @@ func (s *testScatterRegionSuite) TestScattersGroup(c *C) { for i := 1; i <= 100; i++ { regions[uint64(i)] = tc.AddLeaderRegion(uint64(i), 1, 2, 3) } - c.Log(testCase.name) failures := map[uint64]error{} if testCase.failure { - c.Assert(failpoint.Enable("github.com/tikv/pd/server/schedule/scatterFail", `return(true)`), IsNil) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/scatterFail", `return(true)`)) } scatterer.scatterRegions(regions, failures, group, 3) max := uint64(0) min := uint64(math.MaxUint64) groupDistribution, exist := scatterer.ordinaryEngine.selectedLeader.GetGroupDistribution(group) - c.Assert(exist, IsTrue) + re.True(exist) for _, count := range groupDistribution { if count > max { max = count @@ -409,22 +408,22 @@ func (s *testScatterRegionSuite) TestScattersGroup(c *C) { } } // 100 regions divided 5 stores, each store expected to have about 20 regions. - c.Assert(min, LessEqual, uint64(20)) - c.Assert(max, GreaterEqual, uint64(20)) - c.Assert(max-min, LessEqual, uint64(3)) + re.LessOrEqual(min, uint64(20)) + re.GreaterOrEqual(max, uint64(20)) + re.LessOrEqual(max-min, uint64(3)) if testCase.failure { - c.Assert(failures, HasLen, 1) + re.Len(failures, 1) _, ok := failures[1] - c.Assert(ok, IsTrue) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/schedule/scatterFail"), IsNil) + re.True(ok) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/schedule/scatterFail")) } else { - c.Assert(failures, HasLen, 0) + re.Empty(failures) } } } -func (s *testScatterRegionSuite) TestSelectedStoreGC(c *C) { - // use a shorter gcTTL and gcInterval during the test +func TestSelectedStoreGC(t *testing.T) { + re := require.New(t) gcInterval = time.Second gcTTL = time.Second * 3 ctx, cancel := context.WithCancel(context.Background()) @@ -432,19 +431,20 @@ func (s *testScatterRegionSuite) TestSelectedStoreGC(c *C) { stores := newSelectedStores(ctx) stores.Put(1, "testgroup") _, ok := stores.GetGroupDistribution("testgroup") - c.Assert(ok, IsTrue) + re.True(ok) _, ok = stores.GetGroupDistribution("testgroup") - c.Assert(ok, IsTrue) + re.True(ok) time.Sleep(gcTTL) _, ok = stores.GetGroupDistribution("testgroup") - c.Assert(ok, IsFalse) + re.False(ok) _, ok = stores.GetGroupDistribution("testgroup") - c.Assert(ok, IsFalse) + re.False(ok) } // TestRegionFromDifferentGroups test the multi regions. each region have its own group. // After scatter, the distribution for the whole cluster should be well. 
-func (s *testScatterRegionSuite) TestRegionFromDifferentGroups(c *C) { +func TestRegionFromDifferentGroups(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -472,14 +472,15 @@ func (s *testScatterRegionSuite) TestRegionFromDifferentGroups(c *C) { min = count } } - c.Assert(max-min, LessEqual, uint64(2)) + re.LessOrEqual(max-min, uint64(2)) } check(scatterer.ordinaryEngine.selectedPeer) } // TestSelectedStores tests if the peer count has changed due to the picking strategy. // Ref https://github.com/tikv/pd/issues/4565 -func (s *testScatterRegionSuite) TestSelectedStores(c *C) { +func TestSelectedStores(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() opt := config.NewTestOptions() @@ -507,7 +508,7 @@ func (s *testScatterRegionSuite) TestSelectedStores(c *C) { for i := uint64(1); i < 20; i++ { region := tc.AddLeaderRegion(i+200, i%3+2, (i+1)%3+2, (i+2)%3+2) op := scatterer.scatterRegion(region, group) - c.Assert(isPeerCountChanged(op), IsFalse) + re.False(isPeerCountChanged(op)) } } diff --git a/server/schedule/region_splitter_test.go b/server/schedule/region_splitter_test.go index 94abcfe0ccf..eff6a67c15b 100644 --- a/server/schedule/region_splitter_test.go +++ b/server/schedule/region_splitter_test.go @@ -17,8 +17,9 @@ package schedule import ( "bytes" "context" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/core" @@ -60,52 +61,56 @@ func (m *mockSplitRegionsHandler) ScanRegionsByKeyRange(groupKeys *regionGroupKe groupKeys.finished = true } -var _ = Suite(&testRegionSplitterSuite{}) +type regionSplitterTestSuite struct { + suite.Suite -type testRegionSplitterSuite struct { ctx context.Context cancel context.CancelFunc } -func (s *testRegionSplitterSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) +func TestRegionSplitterTestSuite(t *testing.T) { + suite.Run(t, new(regionSplitterTestSuite)) } -func (s *testRegionSplitterSuite) TearDownTest(c *C) { - s.cancel() +func (suite *regionSplitterTestSuite) SetupSuite() { + suite.ctx, suite.cancel = context.WithCancel(context.Background()) } -func (s *testRegionSplitterSuite) TestRegionSplitter(c *C) { +func (suite *regionSplitterTestSuite) TearDownTest() { + suite.cancel() +} + +func (suite *regionSplitterTestSuite) TestRegionSplitter() { opt := config.NewTestOptions() opt.SetPlacementRuleEnabled(false) - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) handler := newMockSplitRegionsHandler() tc.AddLeaderRegionWithRange(1, "eee", "hhh", 2, 3, 4) splitter := NewRegionSplitter(tc, handler) newRegions := map[uint64]struct{}{} // assert success - failureKeys := splitter.splitRegionsByKeys(s.ctx, [][]byte{[]byte("fff"), []byte("ggg")}, newRegions) - c.Assert(failureKeys, HasLen, 0) - c.Assert(newRegions, HasLen, 2) + failureKeys := splitter.splitRegionsByKeys(suite.ctx, [][]byte{[]byte("fff"), []byte("ggg")}, newRegions) + suite.Len(failureKeys, 0) + suite.Len(newRegions, 2) - percentage, newRegionsID := splitter.SplitRegions(s.ctx, [][]byte{[]byte("fff"), []byte("ggg")}, 1) - c.Assert(percentage, Equals, 100) - c.Assert(newRegionsID, HasLen, 2) + percentage, newRegionsID := splitter.SplitRegions(suite.ctx, [][]byte{[]byte("fff"), []byte("ggg")}, 1) + suite.Equal(100, percentage) + 
suite.Len(newRegionsID, 2) // assert out of range newRegions = map[uint64]struct{}{} - failureKeys = splitter.splitRegionsByKeys(s.ctx, [][]byte{[]byte("aaa"), []byte("bbb")}, newRegions) - c.Assert(failureKeys, HasLen, 2) - c.Assert(newRegions, HasLen, 0) + failureKeys = splitter.splitRegionsByKeys(suite.ctx, [][]byte{[]byte("aaa"), []byte("bbb")}, newRegions) + suite.Len(failureKeys, 2) + suite.Len(newRegions, 0) - percentage, newRegionsID = splitter.SplitRegions(s.ctx, [][]byte{[]byte("aaa"), []byte("bbb")}, 1) - c.Assert(percentage, Equals, 0) - c.Assert(newRegionsID, HasLen, 0) + percentage, newRegionsID = splitter.SplitRegions(suite.ctx, [][]byte{[]byte("aaa"), []byte("bbb")}, 1) + suite.Equal(0, percentage) + suite.Len(newRegionsID, 0) } -func (s *testRegionSplitterSuite) TestGroupKeysByRegion(c *C) { +func (suite *regionSplitterTestSuite) TestGroupKeysByRegion() { opt := config.NewTestOptions() opt.SetPlacementRuleEnabled(false) - tc := mockcluster.NewCluster(s.ctx, opt) + tc := mockcluster.NewCluster(suite.ctx, opt) handler := newMockSplitRegionsHandler() tc.AddLeaderRegionWithRange(1, "aaa", "ccc", 2, 3, 4) tc.AddLeaderRegionWithRange(2, "ccc", "eee", 2, 3, 4) @@ -117,18 +122,18 @@ func (s *testRegionSplitterSuite) TestGroupKeysByRegion(c *C) { []byte("fff"), []byte("zzz"), }) - c.Assert(groupKeys, HasLen, 3) + suite.Len(groupKeys, 3) for k, v := range groupKeys { switch k { case uint64(1): - c.Assert(v.keys, HasLen, 1) - c.Assert(v.keys[0], DeepEquals, []byte("bbb")) + suite.Len(v.keys, 1) + suite.Equal([]byte("bbb"), v.keys[0]) case uint64(2): - c.Assert(v.keys, HasLen, 1) - c.Assert(v.keys[0], DeepEquals, []byte("ddd")) + suite.Len(v.keys, 1) + suite.Equal([]byte("ddd"), v.keys[0]) case uint64(3): - c.Assert(v.keys, HasLen, 1) - c.Assert(v.keys[0], DeepEquals, []byte("fff")) + suite.Len(v.keys, 1) + suite.Equal([]byte("fff"), v.keys[0]) } } } diff --git a/server/schedule/waiting_operator_test.go b/server/schedule/waiting_operator_test.go index 9237505acc3..bc7c37936db 100644 --- a/server/schedule/waiting_operator_test.go +++ b/server/schedule/waiting_operator_test.go @@ -15,24 +15,23 @@ package schedule import ( - . 
"github.com/pingcap/check" + "testing" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule/operator" ) -var _ = Suite(&testWaitingOperatorSuite{}) - -type testWaitingOperatorSuite struct{} - -func (s *testWaitingOperatorSuite) TestRandBuckets(c *C) { +func TestRandBuckets(t *testing.T) { + re := require.New(t) rb := NewRandBuckets() addOperators(rb) for i := 0; i < 3; i++ { op := rb.GetOperator() - c.Assert(op, NotNil) + re.NotNil(op) } - c.Assert(rb.GetOperator(), IsNil) + re.Nil(rb.GetOperator()) } func addOperators(wop WaitingOperator) { @@ -52,13 +51,15 @@ func addOperators(wop WaitingOperator) { wop.PutOperator(op) } -func (s *testWaitingOperatorSuite) TestListOperator(c *C) { +func TestListOperator(t *testing.T) { + re := require.New(t) rb := NewRandBuckets() addOperators(rb) - c.Assert(rb.ListOperator(), HasLen, 3) + re.Len(rb.ListOperator(), 3) } -func (s *testWaitingOperatorSuite) TestRandomBucketsWithMergeRegion(c *C) { +func TestRandomBucketsWithMergeRegion(t *testing.T) { + re := require.New(t) rb := NewRandBuckets() descs := []string{"merge-region", "admin-merge-region", "random-merge"} for j := 0; j < 100; j++ { @@ -105,8 +106,8 @@ func (s *testWaitingOperatorSuite) TestRandomBucketsWithMergeRegion(c *C) { for i := 0; i < 2; i++ { op := rb.GetOperator() - c.Assert(op, NotNil) + re.NotNil(op) } - c.Assert(rb.GetOperator(), IsNil) + re.Nil(rb.GetOperator()) } } diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go index 3f333c1722b..9f46f5a6676 100644 --- a/tests/server/config/config_test.go +++ b/tests/server/config/config_test.go @@ -46,7 +46,7 @@ func TestRateLimitConfigReload(t *testing.T) { re.NotEmpty(cluster.WaitLeader()) leader := cluster.GetServer(cluster.GetLeader()) - re.Len(leader.GetServer().GetServiceMiddlewareConfig().RateLimitConfig.LimiterConfig, 0) + re.Empty(leader.GetServer().GetServiceMiddlewareConfig().RateLimitConfig.LimiterConfig) limitCfg := make(map[string]ratelimit.DimensionConfig) limitCfg["GetRegions"] = ratelimit.DimensionConfig{QPS: 1} From e0084c5d7ac17aeac78917d788e8e6155f4b4a8c Mon Sep 17 00:00:00 2001 From: Shirly Date: Fri, 24 Jun 2022 10:46:37 +0800 Subject: [PATCH 77/82] schedulers/balance_leader: use binary search instead of store in indexMap (#5188) close tikv/pd#5187 Signed-off-by: shirly Co-authored-by: Ti Chi Robot --- pkg/typeutil/comparison.go | 10 +- pkg/typeutil/comparison_test.go | 9 ++ server/schedulers/balance_leader.go | 128 ++++++++++++++--------- server/schedulers/balance_leader_test.go | 39 +++++++ server/schedulers/balance_test.go | 78 ++++++++------ 5 files changed, 178 insertions(+), 86 deletions(-) diff --git a/pkg/typeutil/comparison.go b/pkg/typeutil/comparison.go index 8c0362c29ef..2c94ad94182 100644 --- a/pkg/typeutil/comparison.go +++ b/pkg/typeutil/comparison.go @@ -14,7 +14,10 @@ package typeutil -import "time" +import ( + "math" + "time" +) // MinUint64 returns the min value between two variables whose type are uint64. func MinUint64(a, b uint64) uint64 { @@ -52,3 +55,8 @@ func StringsEqual(a, b []string) bool { } return true } + +// Float64Equal checks if two float64 are equal. 
+func Float64Equal(a, b float64) bool { + return math.Abs(a-b) <= 1e-6 +} diff --git a/pkg/typeutil/comparison_test.go b/pkg/typeutil/comparison_test.go index 2a4774091be..9f6f832b1e8 100644 --- a/pkg/typeutil/comparison_test.go +++ b/pkg/typeutil/comparison_test.go @@ -15,6 +15,7 @@ package typeutil import ( + "math/rand" "testing" "time" @@ -44,3 +45,11 @@ func TestMinDuration(t *testing.T) { re.Equal(time.Second, MinDuration(time.Second, time.Minute)) re.Equal(time.Second, MinDuration(time.Second, time.Second)) } + +func TestEqualFloat(t *testing.T) { + t.Parallel() + re := require.New(t) + f1 := rand.Float64() + re.True(Float64Equal(f1, f1*1.000)) + re.True(Float64Equal(f1, f1/1.000)) +} diff --git a/server/schedulers/balance_leader.go b/server/schedulers/balance_leader.go index f5c9c667264..3000afdd7c9 100644 --- a/server/schedulers/balance_leader.go +++ b/server/schedulers/balance_leader.go @@ -28,6 +28,7 @@ import ( "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/reflectutil" "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule" "github.com/tikv/pd/server/schedule/filter" @@ -249,20 +250,41 @@ func (l *balanceLeaderScheduler) IsScheduleAllowed(cluster schedule.Cluster) boo return allowed } +// candidateStores for balance_leader, order by `getStore` `asc` type candidateStores struct { - stores []*core.StoreInfo - storeIndexMap map[uint64]int - index int - compareOption func([]*core.StoreInfo) func(int, int) bool + stores []*core.StoreInfo + getScore func(*core.StoreInfo) float64 + index int + asc bool } -func newCandidateStores(stores []*core.StoreInfo, compareOption func([]*core.StoreInfo) func(int, int) bool) *candidateStores { - cs := &candidateStores{stores: stores, compareOption: compareOption} - cs.storeIndexMap = map[uint64]int{} - cs.initSort() +func newCandidateStores(stores []*core.StoreInfo, asc bool, getScore func(*core.StoreInfo) float64) *candidateStores { + cs := &candidateStores{stores: stores, getScore: getScore, asc: asc} + sort.Slice(cs.stores, cs.sortFunc()) return cs } +func (cs *candidateStores) sortFunc() (less func(int, int) bool) { + less = func(i, j int) bool { + scorei := cs.getScore(cs.stores[i]) + scorej := cs.getScore(cs.stores[j]) + return cs.less(cs.stores[i].GetID(), scorei, cs.stores[j].GetID(), scorej) + } + return less +} + +func (cs *candidateStores) less(iID uint64, scorei float64, jID uint64, scorej float64) bool { + if typeutil.Float64Equal(scorei, scorej) { + // when the stores share the same score, returns the one with the bigger ID, + // Since we assume that the bigger storeID, the newer store(which would be scheduled as soon as possible). + return iID > jID + } + if cs.asc { + return scorei < scorej + } + return scorei > scorej +} + // hasStore returns returns true when there are leftover stores. 
func (cs *candidateStores) hasStore() bool { return cs.index < len(cs.stores) @@ -276,23 +298,47 @@ func (cs *candidateStores) next() { cs.index++ } -func (cs *candidateStores) initSort() { - sort.Slice(cs.stores, cs.compareOption(cs.stores)) - for i := 0; i < len(cs.stores); i++ { - cs.storeIndexMap[cs.stores[i].GetID()] = i +func (cs *candidateStores) binarySearch(store *core.StoreInfo) (index int) { + score := cs.getScore(store) + searchFunc := func(i int) bool { + curScore := cs.getScore(cs.stores[i]) + return !cs.less(cs.stores[i].GetID(), curScore, store.GetID(), score) } + return sort.Search(len(cs.stores)-1, searchFunc) } -func (cs *candidateStores) reSort(stores ...*core.StoreInfo) { +// return the slice of index for the searched stores. +func (cs *candidateStores) binarySearchStores(stores ...*core.StoreInfo) (offsets []int) { if !cs.hasStore() { return } for _, store := range stores { - index, ok := cs.storeIndexMap[store.GetID()] - if !ok { - continue + index := cs.binarySearch(store) + offsets = append(offsets, index) + } + return offsets +} + +// resortStoreWithPos is used to sort stores again after creating an operator. +// It will repeatedly swap the specific store and next store if they are in wrong order. +// In general, it has very few swaps. In the worst case, the time complexity is O(n). +func (cs *candidateStores) resortStoreWithPos(pos int) { + swapper := func(i, j int) { cs.stores[i], cs.stores[j] = cs.stores[j], cs.stores[i] } + score := cs.getScore(cs.stores[pos]) + storeID := cs.stores[pos].GetID() + for ; pos+1 < len(cs.stores); pos++ { + curScore := cs.getScore(cs.stores[pos+1]) + if cs.less(storeID, score, cs.stores[pos+1].GetID(), curScore) { + break } - resortStores(cs.stores, cs.storeIndexMap, index, cs.compareOption(cs.stores)) + swapper(pos, pos+1) + } + for ; pos > 1; pos-- { + curScore := cs.getScore(cs.stores[pos-1]) + if !cs.less(storeID, score, cs.stores[pos-1].GetID(), curScore) { + break + } + swapper(pos, pos-1) } } @@ -308,24 +354,11 @@ func (l *balanceLeaderScheduler) Schedule(cluster schedule.Cluster) []*operator. 
plan := newBalancePlan(kind, cluster, opInfluence) stores := cluster.GetStores() - greaterOption := func(stores []*core.StoreInfo) func(int, int) bool { - return func(i, j int) bool { - iOp := plan.GetOpInfluence(stores[i].GetID()) - jOp := plan.GetOpInfluence(stores[j].GetID()) - return stores[i].LeaderScore(plan.kind.Policy, iOp) > - stores[j].LeaderScore(plan.kind.Policy, jOp) - } - } - lessOption := func(stores []*core.StoreInfo) func(int, int) bool { - return func(i, j int) bool { - iOp := plan.GetOpInfluence(stores[i].GetID()) - jOp := plan.GetOpInfluence(stores[j].GetID()) - return stores[i].LeaderScore(plan.kind.Policy, iOp) < - stores[j].LeaderScore(plan.kind.Policy, jOp) - } + scoreFunc := func(store *core.StoreInfo) float64 { + return store.LeaderScore(plan.kind.Policy, plan.GetOpInfluence(store.GetID())) } - sourceCandidate := newCandidateStores(filter.SelectSourceStores(stores, l.filters, cluster.GetOpts()), greaterOption) - targetCandidate := newCandidateStores(filter.SelectTargetStores(stores, l.filters, cluster.GetOpts()), lessOption) + sourceCandidate := newCandidateStores(filter.SelectSourceStores(stores, l.filters, cluster.GetOpts()), false, scoreFunc) + targetCandidate := newCandidateStores(filter.SelectTargetStores(stores, l.filters, cluster.GetOpts()), true, scoreFunc) usedRegions := make(map[uint64]struct{}) result := make([]*operator.Operator, 0, batch) @@ -395,26 +428,17 @@ func createTransferLeaderOperator(cs *candidateStores, dir string, l *balanceLea func makeInfluence(op *operator.Operator, plan *balancePlan, usedRegions map[uint64]struct{}, candidates ...*candidateStores) { usedRegions[op.RegionID()] = struct{}{} - schedule.AddOpInfluence(op, plan.opInfluence, plan.Cluster) - for _, candidate := range candidates { - candidate.reSort(plan.source, plan.target) - } -} - -// resortStores is used to sort stores again after creating an operator. -// It will repeatedly swap the specific store and next store if they are in wrong order. -// In general, it has very few swaps. In the worst case, the time complexity is O(n). -func resortStores(stores []*core.StoreInfo, index map[uint64]int, pos int, less func(i, j int) bool) { - swapper := func(i, j int) { stores[i], stores[j] = stores[j], stores[i] } - for ; pos+1 < len(stores) && !less(pos, pos+1); pos++ { - swapper(pos, pos+1) - index[stores[pos].GetID()] = pos + candidateUpdateStores := make([][]int, len(candidates)) + for id, candidate := range candidates { + storesIDs := candidate.binarySearchStores(plan.source, plan.target) + candidateUpdateStores[id] = storesIDs } - for ; pos > 1 && less(pos, pos-1); pos-- { - swapper(pos, pos-1) - index[stores[pos].GetID()] = pos + schedule.AddOpInfluence(op, plan.opInfluence, plan.Cluster) + for id, candidate := range candidates { + for _, pos := range candidateUpdateStores[id] { + candidate.resortStoreWithPos(pos) + } } - index[stores[pos].GetID()] = pos } // transferLeaderOut transfers leader from the source store. 
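The binarySearch/resortStoreWithPos pair introduced above boils down to locating an element in an already sorted slice and then bubbling it to its new position with adjacent swaps once its score changes. A minimal, self-contained Go sketch of that idea, separate from the patch itself (the resortOne helper and the sample scores are invented for illustration and are not PD code):

package main

import (
	"fmt"
	"sort"
)

// resortOne restores ascending order after the element at pos changed,
// by swapping it with its neighbors; usually only a few swaps are needed.
func resortOne(scores []float64, pos int) {
	for ; pos+1 < len(scores) && scores[pos] > scores[pos+1]; pos++ {
		scores[pos], scores[pos+1] = scores[pos+1], scores[pos]
	}
	for ; pos > 0 && scores[pos] < scores[pos-1]; pos-- {
		scores[pos], scores[pos-1] = scores[pos-1], scores[pos]
	}
}

func main() {
	scores := []float64{1, 3, 5, 7, 9}
	// Locate the element before its score changes, mirroring how
	// binarySearchStores runs before AddOpInfluence in makeInfluence above.
	pos := sort.SearchFloat64s(scores, 5)
	scores[pos] = 10 // the operator's influence changes this store's score
	resortOne(scores, pos)
	fmt.Println(scores) // Output: [1 3 7 9 10]
}

The patch applies the same adjacent-swap step through resortStoreWithPos after AddOpInfluence, so the candidate slice never needs a full re-sort for each generated operator.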
diff --git a/server/schedulers/balance_leader_test.go b/server/schedulers/balance_leader_test.go index a74709de640..06257ff072d 100644 --- a/server/schedulers/balance_leader_test.go +++ b/server/schedulers/balance_leader_test.go @@ -15,9 +15,14 @@ package schedulers import ( + "context" + "math/rand" "testing" "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/mock/mockcluster" + "github.com/tikv/pd/server/config" + "github.com/tikv/pd/server/core" ) func TestBalanceLeaderSchedulerConfigClone(t *testing.T) { @@ -36,3 +41,37 @@ func TestBalanceLeaderSchedulerConfigClone(t *testing.T) { conf2.Ranges[1] = keyRanges2[1] re.NotEqual(conf.Ranges, conf2.Ranges) } + +func BenchmarkCandidateStores(b *testing.B) { + ctx := context.Background() + opt := config.NewTestOptions() + tc := mockcluster.NewCluster(ctx, opt) + + for id := uint64(1); id < uint64(10000); id++ { + leaderCount := int(rand.Int31n(10000)) + tc.AddLeaderStore(id, leaderCount) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + updateAndResortStoresInCandidateStores(tc) + } +} + +func updateAndResortStoresInCandidateStores(tc *mockcluster.Cluster) { + deltaMap := make(map[uint64]int64) + getScore := func(store *core.StoreInfo) float64 { + return store.LeaderScore(0, deltaMap[store.GetID()]) + } + cs := newCandidateStores(tc.GetStores(), false, getScore) + stores := tc.GetStores() + // update score for store and reorder + for id, store := range stores { + offsets := cs.binarySearchStores(store) + if id%2 == 1 { + deltaMap[store.GetID()] = int64(rand.Int31n(10000)) + } else { + deltaMap[store.GetID()] = int64(-rand.Int31n(10000)) + } + cs.resortStoreWithPos(offsets[0]) + } +} diff --git a/server/schedulers/balance_test.go b/server/schedulers/balance_test.go index 8d70747429e..d4a001ac84c 100644 --- a/server/schedulers/balance_test.go +++ b/server/schedulers/balance_test.go @@ -620,43 +620,55 @@ func (s *testBalanceLeaderRangeSchedulerSuite) TestReSortStores(c *C) { s.tc.AddLeaderStore(5, 100) s.tc.AddLeaderStore(6, 0) stores := s.tc.Stores.GetStores() + sort.Slice(stores, func(i, j int) bool { + return stores[i].GetID() < stores[j].GetID() + }) deltaMap := make(map[uint64]int64) - less := func(i, j int) bool { - iOp := deltaMap[stores[i].GetID()] - jOp := deltaMap[stores[j].GetID()] - return stores[i].LeaderScore(0, iOp) > stores[j].LeaderScore(0, jOp) + getScore := func(store *core.StoreInfo) float64 { + return store.LeaderScore(0, deltaMap[store.GetID()]) } - - sort.Slice(stores, less) - storeIndexMap := map[uint64]int{} - for i := 0; i < len(stores); i++ { - storeIndexMap[stores[i].GetID()] = i - } - c.Assert(stores[0].GetID(), Equals, uint64(1)) - c.Assert(storeIndexMap[uint64(1)], Equals, 0) - deltaMap[1] = -1 - - resortStores(stores, storeIndexMap, storeIndexMap[uint64(1)], less) - c.Assert(stores[0].GetID(), Equals, uint64(1)) - c.Assert(storeIndexMap[uint64(1)], Equals, 0) + candidateStores := make([]*core.StoreInfo, 0) + // order by score desc. + cs := newCandidateStores(append(candidateStores, stores...), false, getScore) + // in candidate,the order stores:1(104),5(100),4(100),6,3,2 + // store 4 should in pos 2 + c.Assert(cs.binarySearch(stores[3]), Equals, 2) + + // store 1 should in pos 0 + store1 := stores[0] + c.Assert(cs.binarySearch(store1), Equals, 0) + deltaMap[store1.GetID()] = -1 // store 1 + cs.resortStoreWithPos(0) + // store 1 should still in pos 0. 
+ c.Assert(cs.stores[0].GetID(), Equals, uint64(1)) + curIndx := cs.binarySearch(store1) + c.Assert(0, Equals, curIndx) deltaMap[1] = -4 - resortStores(stores, storeIndexMap, storeIndexMap[uint64(1)], less) - c.Assert(stores[2].GetID(), Equals, uint64(1)) - c.Assert(storeIndexMap[uint64(1)], Equals, 2) - topID := stores[0].GetID() - deltaMap[topID] = -1 - resortStores(stores, storeIndexMap, storeIndexMap[topID], less) - c.Assert(stores[1].GetID(), Equals, uint64(1)) - c.Assert(storeIndexMap[uint64(1)], Equals, 1) - c.Assert(stores[2].GetID(), Equals, topID) - c.Assert(storeIndexMap[topID], Equals, 2) - - bottomID := stores[5].GetID() - deltaMap[bottomID] = 4 - resortStores(stores, storeIndexMap, storeIndexMap[bottomID], less) - c.Assert(stores[3].GetID(), Equals, bottomID) - c.Assert(storeIndexMap[bottomID], Equals, 3) + // store 1 update the scores to 104-4=100 + // the order stores should be:5(100),4(100),1(100),6,3,2 + cs.resortStoreWithPos(curIndx) + c.Assert(cs.stores[2].GetID(), Equals, uint64(1)) + c.Assert(cs.binarySearch(store1), Equals, 2) + // the top store is : 5(100) + topStore := cs.stores[0] + topStorePos := cs.binarySearch(topStore) + deltaMap[topStore.GetID()] = -1 + cs.resortStoreWithPos(topStorePos) + + // after recorder, the order stores should be: 4(100),1(100),5(99),6,3,2 + c.Assert(cs.stores[1].GetID(), Equals, uint64(1)) + c.Assert(cs.binarySearch(store1), Equals, 1) + c.Assert(cs.stores[2].GetID(), Equals, topStore.GetID()) + c.Assert(cs.binarySearch(topStore), Equals, 2) + + bottomStore := cs.stores[5] + deltaMap[bottomStore.GetID()] = 4 + cs.resortStoreWithPos(5) + + // the order stores should be: 4(100),1(100),5(99),2(5),6,3 + c.Assert(cs.stores[3].GetID(), Equals, bottomStore.GetID()) + c.Assert(cs.binarySearch(bottomStore), Equals, 3) } var _ = Suite(&testBalanceRegionSchedulerSuite{}) From 5dc97f8cfb1dcc41533e05929f4d3fcfcf91d491 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 24 Jun 2022 11:04:37 +0800 Subject: [PATCH 78/82] *: clean up the surrounding code of check pkg (#5225) ref tikv/pd#4813, close tikv/pd#5105 Clean up the surrounding code of check pkg. Signed-off-by: JmPotato --- client/testutil/testutil.go | 28 ++++----- pkg/cache/cache_test.go | 4 +- pkg/testutil/testutil.go | 58 +++++-------------- scripts/check-test.sh | 23 -------- server/api/tso_test.go | 2 +- .../schedule/operator/status_tracker_test.go | 4 +- tests/client/client_test.go | 8 +-- tests/server/member/member_test.go | 2 +- tests/server/tso/allocator_test.go | 2 +- tests/server/tso/manager_test.go | 2 +- 10 files changed, 39 insertions(+), 94 deletions(-) diff --git a/client/testutil/testutil.go b/client/testutil/testutil.go index 79a3c9eb913..de9821024e2 100644 --- a/client/testutil/testutil.go +++ b/client/testutil/testutil.go @@ -21,34 +21,34 @@ import ( ) const ( - defaultWaitFor = time.Second * 20 - defaultSleepInterval = time.Millisecond * 100 + defaultWaitFor = time.Second * 20 + defaultTickInterval = time.Millisecond * 100 ) // WaitOp represents available options when execute Eventually. type WaitOp struct { - waitFor time.Duration - sleepInterval time.Duration + waitFor time.Duration + tickInterval time.Duration } -// WaitOption configures WaitOp +// WaitOption configures WaitOp. type WaitOption func(op *WaitOp) -// WithSleepInterval specify the sleep duration -func WithSleepInterval(sleep time.Duration) WaitOption { - return func(op *WaitOp) { op.sleepInterval = sleep } -} - -// WithWaitFor specify the max wait for duration +// WithWaitFor specify the max wait duration. 
func WithWaitFor(waitFor time.Duration) WaitOption { return func(op *WaitOp) { op.waitFor = waitFor } } +// WithTickInterval specify the tick interval to check the condition. +func WithTickInterval(tickInterval time.Duration) WaitOption { + return func(op *WaitOp) { op.tickInterval = tickInterval } +} + // Eventually asserts that given condition will be met in a period of time. func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOption) { option := &WaitOp{ - waitFor: defaultWaitFor, - sleepInterval: defaultSleepInterval, + waitFor: defaultWaitFor, + tickInterval: defaultTickInterval, } for _, opt := range opts { opt(option) @@ -56,6 +56,6 @@ func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOptio re.Eventually( condition, option.waitFor, - option.sleepInterval, + option.tickInterval, ) } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index f708f966401..5bcff65f704 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -50,7 +50,7 @@ func TestExpireRegionCache(t *testing.T) { re.True(ok) re.Equal(expV, v2.(string)) - cache.PutWithTTL(11, "11", 1*time.Second) + cache.PutWithTTL(11, "11", time.Second) time.Sleep(5 * time.Second) k, v, success = cache.pop() re.False(success) @@ -58,7 +58,7 @@ func TestExpireRegionCache(t *testing.T) { re.Nil(v) // Test Get - cache.PutWithTTL(1, 1, 1*time.Second) + cache.PutWithTTL(1, 1, time.Second) cache.PutWithTTL(2, "v2", 5*time.Second) cache.PutWithTTL(3, 3.0, 5*time.Second) diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index 236438ecfd1..abe714384a5 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -19,72 +19,40 @@ import ( "strings" "time" - "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/stretchr/testify/require" "google.golang.org/grpc" ) const ( - defaultWaitRetryTimes = 200 - defaultSleepInterval = time.Millisecond * 100 - defaultWaitFor = time.Second * 20 + defaultWaitFor = time.Second * 20 + defaultTickInterval = time.Millisecond * 100 ) -// CheckFunc is a condition checker that passed to WaitUntil. Its implementation -// may call c.Fatal() to abort the test, or c.Log() to add more information. -type CheckFunc func() bool - -// WaitOp represents available options when execute WaitUntil +// WaitOp represents available options when execute Eventually. type WaitOp struct { - retryTimes int - sleepInterval time.Duration - waitFor time.Duration + waitFor time.Duration + tickInterval time.Duration } -// WaitOption configures WaitOp +// WaitOption configures WaitOp. type WaitOption func(op *WaitOp) -// WithRetryTimes specify the retry times -func WithRetryTimes(retryTimes int) WaitOption { - return func(op *WaitOp) { op.retryTimes = retryTimes } -} - -// WithSleepInterval specify the sleep duration -func WithSleepInterval(sleep time.Duration) WaitOption { - return func(op *WaitOp) { op.sleepInterval = sleep } -} - -// WithWaitFor specify the max wait for duration +// WithWaitFor specify the max wait duration. func WithWaitFor(waitFor time.Duration) WaitOption { return func(op *WaitOp) { op.waitFor = waitFor } } -// WaitUntil repeatedly evaluates f() for a period of time, util it returns true. -// NOTICE: this function will be removed soon, please use `Eventually` instead. 
-func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { - c.Log("wait start") - option := &WaitOp{ - retryTimes: defaultWaitRetryTimes, - sleepInterval: defaultSleepInterval, - } - for _, opt := range opts { - opt(option) - } - for i := 0; i < option.retryTimes; i++ { - if f() { - return - } - time.Sleep(option.sleepInterval) - } - c.Fatal("wait timeout") +// WithTickInterval specify the tick interval to check the condition. +func WithTickInterval(tickInterval time.Duration) WaitOption { + return func(op *WaitOp) { op.tickInterval = tickInterval } } // Eventually asserts that given condition will be met in a period of time. func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOption) { option := &WaitOp{ - waitFor: defaultWaitFor, - sleepInterval: defaultSleepInterval, + waitFor: defaultWaitFor, + tickInterval: defaultTickInterval, } for _, opt := range opts { opt(option) @@ -92,7 +60,7 @@ func Eventually(re *require.Assertions, condition func() bool, opts ...WaitOptio re.Eventually( condition, option.waitFor, - option.sleepInterval, + option.tickInterval, ) } diff --git a/scripts/check-test.sh b/scripts/check-test.sh index 867d234cce6..c3168066e3d 100755 --- a/scripts/check-test.sh +++ b/scripts/check-test.sh @@ -1,28 +1,5 @@ #!/bin/bash -# TODO: remove this script after migrating all tests to the new test framework. - -# Check if there are any packages foget to add `TestingT` when use "github.com/pingcap/check". - -res=$(diff <(grep -rl --include=\*_test.go "github.com/pingcap/check" . | xargs -L 1 dirname | sort -u) \ - <(grep -rl --include=\*_test.go -E "^\s*(check\.)?TestingT\(" . | xargs -L 1 dirname | sort -u)) - -if [ "$res" ]; then - echo "following packages may be lost TestingT:" - echo "$res" | awk '{if(NF>1){print $2}}' - exit 1 -fi - -# Check if there are duplicated `TestingT` in package. - -res=$(grep -r --include=\*_test.go "TestingT(t)" . | cut -f1 | xargs -L 1 dirname | sort | uniq -d) - -if [ "$res" ]; then - echo "following packages may have duplicated TestingT:" - echo "$res" - exit 1 -fi - # Check if there is any inefficient assert function usage in package. res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(True|False)\((t, )?reflect\.DeepEqual\(" . 
| sort -u) \ diff --git a/server/api/tso_test.go b/server/api/tso_test.go index 2f59a2ecf07..07770b130d8 100644 --- a/server/api/tso_test.go +++ b/server/api/tso_test.go @@ -58,7 +58,7 @@ func (suite *tsoTestSuite) TestTransferAllocator() { suite.svr.GetTSOAllocatorManager().ClusterDCLocationChecker() _, err := suite.svr.GetTSOAllocatorManager().GetAllocator("dc-1") return err == nil - }, tu.WithRetryTimes(5), tu.WithSleepInterval(3*time.Second)) + }, tu.WithWaitFor(15*time.Second), tu.WithTickInterval(3*time.Second)) addr := suite.urlPrefix + "/tso/allocator/transfer/pd1?dcLocation=dc-1" err := tu.CheckPostJSON(testDialClient, addr, nil, tu.StatusOK(re)) suite.NoError(err) diff --git a/server/schedule/operator/status_tracker_test.go b/server/schedule/operator/status_tracker_test.go index d4441b0e7b6..04eb16d6ecd 100644 --- a/server/schedule/operator/status_tracker_test.go +++ b/server/schedule/operator/status_tracker_test.go @@ -123,11 +123,11 @@ func TestCheckStepTimeout(t *testing.T) { status OpStatus }{{ step: AddLearner{}, - start: time.Now().Add(-(SlowOperatorWaitTime - 1*time.Second)), + start: time.Now().Add(-(SlowOperatorWaitTime - time.Second)), status: STARTED, }, { step: AddLearner{}, - start: time.Now().Add(-(SlowOperatorWaitTime + 1*time.Second)), + start: time.Now().Add(-(SlowOperatorWaitTime + time.Second)), status: TIMEOUT, }} diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 8e67cfb4949..2709c7cbc24 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -424,14 +424,14 @@ func TestCustomTimeout(t *testing.T) { defer cluster.Destroy() endpoints := runServer(re, cluster) - cli := setupCli(re, ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) + cli := setupCli(re, ctx, endpoints, pd.WithCustomTimeoutOption(time.Second)) start := time.Now() re.NoError(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)")) _, err = cli.GetAllStores(context.TODO()) re.NoError(failpoint.Disable("github.com/tikv/pd/server/customTimeout")) re.Error(err) - re.GreaterOrEqual(time.Since(start), 1*time.Second) + re.GreaterOrEqual(time.Since(start), time.Second) re.Less(time.Since(start), 2*time.Second) } @@ -1306,7 +1306,7 @@ func (suite *clientTestSuite) TestScatterRegion() { return resp.GetRegionId() == regionID && string(resp.GetDesc()) == "scatter-region" && resp.GetStatus() == pdpb.OperatorStatus_RUNNING - }, testutil.WithSleepInterval(1*time.Second)) + }, testutil.WithTickInterval(time.Second)) // Test interface `ScatterRegion`. // TODO: Deprecate interface `ScatterRegion`. 
@@ -1323,5 +1323,5 @@ func (suite *clientTestSuite) TestScatterRegion() { return resp.GetRegionId() == regionID && string(resp.GetDesc()) == "scatter-region" && resp.GetStatus() == pdpb.OperatorStatus_RUNNING - }, testutil.WithSleepInterval(1*time.Second)) + }, testutil.WithTickInterval(time.Second)) } diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 229b4756045..15819bb4bf4 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -199,7 +199,7 @@ func waitEtcdLeaderChange(re *require.Assertions, server *tests.TestServer, old return false } return leader != old - }, testutil.WithWaitFor(time.Second*90), testutil.WithSleepInterval(time.Second)) + }, testutil.WithWaitFor(90*time.Second), testutil.WithTickInterval(time.Second)) return leader } diff --git a/tests/server/tso/allocator_test.go b/tests/server/tso/allocator_test.go index 59cedea0783..4640a74239a 100644 --- a/tests/server/tso/allocator_test.go +++ b/tests/server/tso/allocator_test.go @@ -166,7 +166,7 @@ func TestPriorityAndDifferentLocalTSO(t *testing.T) { defer wg.Done() testutil.Eventually(re, func() bool { return cluster.WaitAllocatorLeader(dc) == serName - }, testutil.WithWaitFor(time.Second*90), testutil.WithSleepInterval(time.Second)) + }, testutil.WithWaitFor(90*time.Second), testutil.WithTickInterval(time.Second)) }(serverName, dcLocation) } wg.Wait() diff --git a/tests/server/tso/manager_test.go b/tests/server/tso/manager_test.go index 00278544f55..5d19b3edea3 100644 --- a/tests/server/tso/manager_test.go +++ b/tests/server/tso/manager_test.go @@ -182,7 +182,7 @@ func TestNextLeaderKey(t *testing.T) { cluster.CheckClusterDCLocation() currName := cluster.WaitAllocatorLeader("dc-1") return currName == name - }, testutil.WithSleepInterval(1*time.Second)) + }, testutil.WithTickInterval(time.Second)) return } } From 1bcdafe1285fab98b0804db325e25c675db2495f Mon Sep 17 00:00:00 2001 From: buffer <1045931706@qq.com> Date: Fri, 24 Jun 2022 11:14:37 +0800 Subject: [PATCH 79/82] pd-simulator: clone before request (#5223) close tikv/pd#5222 Signed-off-by: bufferflies <1045931706@qq.com> Co-authored-by: Ti Chi Robot --- tools/pd-simulator/simulator/client.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tools/pd-simulator/simulator/client.go b/tools/pd-simulator/simulator/client.go index 3cbad1c6fee..13a4b7c0bf1 100644 --- a/tools/pd-simulator/simulator/client.go +++ b/tools/pd-simulator/simulator/client.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "github.com/gogo/protobuf/proto" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" @@ -198,7 +199,8 @@ func (c *client) reportRegionHeartbeat(ctx context.Context, stream pdpb.PD_Regio defer wg.Done() for { select { - case region := <-c.reportRegionHeartbeatCh: + case r := <-c.reportRegionHeartbeatCh: + region := r.Clone() request := &pdpb.RegionHeartbeatRequest{ Header: c.requestHeader(), Region: region.GetMeta(), @@ -263,8 +265,8 @@ func (c *client) Bootstrap(ctx context.Context, store *metapb.Store, region *met } _, err = c.pdClient().Bootstrap(ctx, &pdpb.BootstrapRequest{ Header: c.requestHeader(), - Store: store, - Region: region, + Store: proto.Clone(store).(*metapb.Store), + Region: proto.Clone(region).(*metapb.Region), }) if err != nil { return err @@ -276,7 +278,7 @@ func (c *client) PutStore(ctx context.Context, store *metapb.Store) error { ctx, cancel := context.WithTimeout(ctx, pdTimeout) resp, err := 
c.pdClient().PutStore(ctx, &pdpb.PutStoreRequest{ Header: c.requestHeader(), - Store: store, + Store: proto.Clone(store).(*metapb.Store), }) cancel() if err != nil { @@ -293,7 +295,7 @@ func (c *client) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) err ctx, cancel := context.WithTimeout(ctx, pdTimeout) resp, err := c.pdClient().StoreHeartbeat(ctx, &pdpb.StoreHeartbeatRequest{ Header: c.requestHeader(), - Stats: stats, + Stats: proto.Clone(stats).(*pdpb.StoreStats), }) cancel() if err != nil { From 2eecaeffb45254d9299c6e848754ef927166a1be Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Fri, 24 Jun 2022 13:54:38 +0800 Subject: [PATCH 80/82] *: unify the must wait leader (#5226) ref tikv/pd#4813 Signed-off-by: Ryan Leung --- server/api/admin_test.go | 2 +- server/api/checker_test.go | 2 +- server/api/cluster_test.go | 2 +- server/api/config_test.go | 2 +- server/api/hot_status_test.go | 2 +- server/api/label_test.go | 4 ++-- server/api/log_test.go | 2 +- server/api/min_resolved_ts_test.go | 2 +- server/api/operator_test.go | 4 ++-- server/api/pprof_test.go | 2 +- server/api/region_label_test.go | 2 +- server/api/region_test.go | 8 ++++---- server/api/rule_test.go | 2 +- server/api/scheduler_test.go | 2 +- server/api/server_test.go | 21 ++------------------- server/api/service_gc_safepoint_test.go | 2 +- server/api/service_middleware_test.go | 4 ++-- server/api/stats_test.go | 2 +- server/api/store_test.go | 2 +- server/api/trend_test.go | 2 +- server/api/tso_test.go | 2 +- server/api/unsafe_operation_test.go | 2 +- server/server_test.go | 16 +--------------- server/testutil.go | 20 ++++++++++++++++++++ tests/client/client_test.go | 15 +-------------- tests/server/api/api_test.go | 20 +++++--------------- tests/server/config/config_test.go | 21 +++++---------------- tests/server/member/member_test.go | 16 +--------------- 28 files changed, 62 insertions(+), 121 deletions(-) diff --git a/server/api/admin_test.go b/server/api/admin_test.go index ba9aaa875a4..f8fd0bcf74f 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -42,7 +42,7 @@ func TestAdminTestSuite(t *testing.T) { func (suite *adminTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/checker_test.go b/server/api/checker_test.go index a3ab815ffb7..d40a61f93a8 100644 --- a/server/api/checker_test.go +++ b/server/api/checker_test.go @@ -40,7 +40,7 @@ func TestCheckerTestSuite(t *testing.T) { func (suite *checkerTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/checker", addr, apiPrefix) diff --git a/server/api/cluster_test.go b/server/api/cluster_test.go index 496d75e6f38..5ef1e5584bf 100644 --- a/server/api/cluster_test.go +++ b/server/api/cluster_test.go @@ -41,7 +41,7 @@ func TestClusterTestSuite(t *testing.T) { func (suite *clusterTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git 
a/server/api/config_test.go b/server/api/config_test.go index 144e511979a..6480e8f3967 100644 --- a/server/api/config_test.go +++ b/server/api/config_test.go @@ -44,7 +44,7 @@ func (suite *configTestSuite) SetupSuite() { suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/hot_status_test.go b/server/api/hot_status_test.go index 66a4e29afb7..664d935563b 100644 --- a/server/api/hot_status_test.go +++ b/server/api/hot_status_test.go @@ -43,7 +43,7 @@ func TestHotStatusTestSuite(t *testing.T) { func (suite *hotStatusTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/hotspot", addr, apiPrefix) diff --git a/server/api/label_test.go b/server/api/label_test.go index 6729abe45ea..93720525686 100644 --- a/server/api/label_test.go +++ b/server/api/label_test.go @@ -119,7 +119,7 @@ func (suite *labelsStoreTestSuite) SetupSuite() { suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.StrictlyMatchLabel = false }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) @@ -205,7 +205,7 @@ func (suite *strictlyLabelsStoreTestSuite) SetupSuite() { cfg.Replication.StrictlyMatchLabel = true cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) suite.grpcSvr = &server.GrpcServer{Server: suite.svr} addr := suite.svr.GetAddr() diff --git a/server/api/log_test.go b/server/api/log_test.go index f03472b8146..e85ed9fa120 100644 --- a/server/api/log_test.go +++ b/server/api/log_test.go @@ -39,7 +39,7 @@ func TestLogTestSuite(t *testing.T) { func (suite *logTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/admin", addr, apiPrefix) diff --git a/server/api/min_resolved_ts_test.go b/server/api/min_resolved_ts_test.go index 47c47713bff..a8a7252ed33 100644 --- a/server/api/min_resolved_ts_test.go +++ b/server/api/min_resolved_ts_test.go @@ -42,7 +42,7 @@ func (suite *minResolvedTSTestSuite) SetupSuite() { re := suite.Require() cluster.DefaultMinResolvedTSPersistenceInterval = time.Microsecond suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/operator_test.go b/server/api/operator_test.go index ba08b890b9b..871cb8cc355 100644 --- a/server/api/operator_test.go +++ b/server/api/operator_test.go @@ -55,7 +55,7 @@ func (suite *operatorTestSuite) SetupSuite() { re := suite.Require() suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)")) suite.svr, suite.cleanup = 
mustNewServer(re, func(cfg *config.Config) { cfg.Replication.MaxReplicas = 1 }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) @@ -178,7 +178,7 @@ func (suite *transferRegionOperatorTestSuite) SetupSuite() { re := suite.Require() suite.NoError(failpoint.Enable("github.com/tikv/pd/server/schedule/unexpectedOperator", "return(true)")) suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.MaxReplicas = 3 }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/pprof_test.go b/server/api/pprof_test.go index 3d80a325758..a0165aaf27d 100644 --- a/server/api/pprof_test.go +++ b/server/api/pprof_test.go @@ -38,7 +38,7 @@ func TestProfTestSuite(t *testing.T) { func (suite *profTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/debug", addr, apiPrefix) diff --git a/server/api/region_label_test.go b/server/api/region_label_test.go index 0165ec7d37e..b0cc1b60570 100644 --- a/server/api/region_label_test.go +++ b/server/api/region_label_test.go @@ -42,7 +42,7 @@ func TestRegionLabelTestSuite(t *testing.T) { func (suite *regionLabelTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/config/region-label/", addr, apiPrefix) diff --git a/server/api/region_test.go b/server/api/region_test.go index 168edd0e419..3bccc1f0af7 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -96,7 +96,7 @@ func TestRegionTestSuite(t *testing.T) { func (suite *regionTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) @@ -437,7 +437,7 @@ func TestGetRegionTestSuite(t *testing.T) { func (suite *getRegionTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) @@ -545,7 +545,7 @@ func TestGetRegionRangeHolesTestSuite(t *testing.T) { func (suite *getRegionRangeHolesTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) mustBootstrapCluster(re, suite.svr) @@ -594,7 +594,7 @@ func TestRegionsReplicatedTestSuite(t *testing.T) { func (suite *regionsReplicatedTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, 
[]*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/rule_test.go b/server/api/rule_test.go index 63af8b19c1c..669fca34489 100644 --- a/server/api/rule_test.go +++ b/server/api/rule_test.go @@ -43,7 +43,7 @@ func TestRuleTestSuite(t *testing.T) { func (suite *ruleTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/config", addr, apiPrefix) diff --git a/server/api/scheduler_test.go b/server/api/scheduler_test.go index c4b30595967..bdcf85c2ea1 100644 --- a/server/api/scheduler_test.go +++ b/server/api/scheduler_test.go @@ -45,7 +45,7 @@ func TestScheduleTestSuite(t *testing.T) { func (suite *scheduleTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/schedulers", addr, apiPrefix) diff --git a/server/api/server_test.go b/server/api/server_test.go index b82dfc5ea21..8693b4b87ca 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -106,7 +106,7 @@ func mustNewCluster(re *require.Assertions, num int, opts ...func(cfg *config.Co } close(ch) // wait etcd and http servers - mustWaitLeader(re, svrs) + server.MustWaitLeader(re, svrs) // clean up clean := func() { @@ -122,23 +122,6 @@ func mustNewCluster(re *require.Assertions, num int, opts ...func(cfg *config.Co return cfgs, svrs, clean } -func mustWaitLeader(re *require.Assertions, svrs []*server.Server) { - testutil.Eventually(re, func() bool { - var leader *pdpb.Member - for _, svr := range svrs { - l := svr.GetLeader() - // All servers' GetLeader should return the same leader. 
- if l == nil || (leader != nil && l.GetMemberId() != leader.GetMemberId()) { - return false - } - if leader == nil { - leader = l - } - } - return true - }) -} - func mustBootstrapCluster(re *require.Assertions, s *server.Server) { grpcPDClient := testutil.MustNewGrpcClient(re, s.GetAddr()) req := &pdpb.BootstrapRequest{ @@ -164,7 +147,7 @@ func TestServiceTestSuite(t *testing.T) { func (suite *serviceTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) mustBootstrapCluster(re, suite.svr) mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) diff --git a/server/api/service_gc_safepoint_test.go b/server/api/service_gc_safepoint_test.go index 291bba0fcaf..e1d1a451922 100644 --- a/server/api/service_gc_safepoint_test.go +++ b/server/api/service_gc_safepoint_test.go @@ -41,7 +41,7 @@ func TestServiceGCSafepointTestSuite(t *testing.T) { func (suite *serviceGCSafepointTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/service_middleware_test.go b/server/api/service_middleware_test.go index ac188dd8759..ef1ab6b941a 100644 --- a/server/api/service_middleware_test.go +++ b/server/api/service_middleware_test.go @@ -44,7 +44,7 @@ func (suite *auditMiddlewareTestSuite) SetupSuite() { suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) @@ -125,7 +125,7 @@ func TestRateLimitConfigTestSuite(t *testing.T) { func (suite *rateLimitConfigTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) mustBootstrapCluster(re, suite.svr) suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", suite.svr.GetAddr(), apiPrefix) } diff --git a/server/api/stats_test.go b/server/api/stats_test.go index 77c35b19679..c2a53597a90 100644 --- a/server/api/stats_test.go +++ b/server/api/stats_test.go @@ -41,7 +41,7 @@ func TestStatsTestSuite(t *testing.T) { func (suite *statsTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/store_test.go b/server/api/store_test.go index 64cb500164d..9f056ba111b 100644 --- a/server/api/store_test.go +++ b/server/api/store_test.go @@ -98,7 +98,7 @@ func (suite *storeTestSuite) SetupSuite() { // TODO: enable placmentrules re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { cfg.Replication.EnablePlacementRules = false }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.grpcSvr = &server.GrpcServer{Server: suite.svr} diff --git a/server/api/trend_test.go b/server/api/trend_test.go index 
972af465ef9..d8f0abcdac2 100644 --- a/server/api/trend_test.go +++ b/server/api/trend_test.go @@ -31,7 +31,7 @@ func TestTrend(t *testing.T) { re := require.New(t) svr, cleanup := mustNewServer(re) defer cleanup() - mustWaitLeader(re, []*server.Server{svr}) + server.MustWaitLeader(re, []*server.Server{svr}) mustBootstrapCluster(re, svr) for i := 1; i <= 3; i++ { diff --git a/server/api/tso_test.go b/server/api/tso_test.go index 07770b130d8..5129a7a2209 100644 --- a/server/api/tso_test.go +++ b/server/api/tso_test.go @@ -42,7 +42,7 @@ func (suite *tsoTestSuite) SetupSuite() { cfg.EnableLocalTSO = true cfg.Labels[config.ZoneLabel] = "dc-1" }) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) diff --git a/server/api/unsafe_operation_test.go b/server/api/unsafe_operation_test.go index 62df25c6b68..77c4149ec3b 100644 --- a/server/api/unsafe_operation_test.go +++ b/server/api/unsafe_operation_test.go @@ -40,7 +40,7 @@ func TestUnsafeOperationTestSuite(t *testing.T) { func (suite *unsafeOperationTestSuite) SetupSuite() { re := suite.Require() suite.svr, suite.cleanup = mustNewServer(re) - mustWaitLeader(re, []*server.Server{suite.svr}) + server.MustWaitLeader(re, []*server.Server{suite.svr}) addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/admin/unsafe", addr, apiPrefix) diff --git a/server/server_test.go b/server/server_test.go index f520314a5b1..17a7e330aff 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -49,20 +49,6 @@ func TestLeaderServerTestSuite(t *testing.T) { suite.Run(t, new(leaderServerTestSuite)) } -func (suite *leaderServerTestSuite) mustWaitLeader(svrs []*Server) *Server { - var leader *Server - testutil.Eventually(suite.Require(), func() bool { - for _, s := range svrs { - if !s.IsClosed() && s.member.IsLeader() { - leader = s - return true - } - } - return false - }) - return leader -} - func (suite *leaderServerTestSuite) SetupSuite() { suite.ctx, suite.cancel = context.WithCancel(context.Background()) suite.svrs = make(map[string]*Server) @@ -125,7 +111,7 @@ func (suite *leaderServerTestSuite) newTestServersWithCfgs(ctx context.Context, suite.NotNil(svr) svrs = append(svrs, svr) } - suite.mustWaitLeader(svrs) + MustWaitLeader(suite.Require(), svrs) cleanup := func() { for _, svr := range svrs { diff --git a/server/testutil.go b/server/testutil.go index 2534008a5b5..9173e9ab2fc 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -23,8 +23,10 @@ import ( "time" "github.com/pingcap/log" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/tempurl" + "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/config" "go.etcd.io/etcd/embed" @@ -115,3 +117,21 @@ func NewTestMultiConfig(c *assertutil.Checker, count int) []*config.Config { return cfgs } + +// MustWaitLeader return the leader until timeout. +func MustWaitLeader(re *require.Assertions, svrs []*Server) *Server { + var leader *Server + testutil.Eventually(re, func() bool { + for _, svr := range svrs { + // All servers' GetLeader should return the same leader. 
+ if svr.GetLeader() == nil || (leader != nil && svr.GetLeader().GetMemberId() != leader.GetLeader().GetMemberId()) { + return false + } + if leader == nil && !svr.IsClosed() { + leader = svr + } + } + return true + }) + return leader +} diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 2709c7cbc24..3c79fbbdbfd 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -695,7 +695,7 @@ func (suite *clientTestSuite) SetupSuite() { suite.grpcPDClient = testutil.MustNewGrpcClient(re, suite.srv.GetAddr()) suite.grpcSvr = &server.GrpcServer{Server: suite.srv} - suite.mustWaitLeader(map[string]*server.Server{suite.srv.GetAddr(): suite.srv}) + server.MustWaitLeader(re, []*server.Server{suite.srv}) suite.bootstrapServer(newHeader(suite.srv), suite.grpcPDClient) suite.ctx, suite.clean = context.WithCancel(context.Background()) @@ -728,19 +728,6 @@ func (suite *clientTestSuite) TearDownSuite() { suite.cleanup() } -func (suite *clientTestSuite) mustWaitLeader(svrs map[string]*server.Server) *server.Server { - for i := 0; i < 500; i++ { - for _, s := range svrs { - if !s.IsClosed() && s.GetMember().IsLeader() { - return s - } - } - time.Sleep(100 * time.Millisecond) - } - suite.FailNow("no leader") - return nil -} - func newHeader(srv *server.Server) *pdpb.RequestHeader { return &pdpb.RequestHeader{ ClusterId: srv.ClusterID(), diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 012c09b287e..58323a9d236 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -274,7 +274,11 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { // resign to test persist config oldLeaderName := leader.GetServer().Name() leader.GetServer().GetMember().ResignEtcdLeader(leader.GetServer().Context(), oldLeaderName, "") - suite.mustWaitLeader() + var servers []*server.Server + for _, s := range suite.cluster.GetServers() { + servers = append(servers, s.GetServer()) + } + server.MustWaitLeader(suite.Require(), servers) leader = suite.cluster.GetServer(suite.cluster.GetLeader()) timeUnix = time.Now().Unix() - 20 @@ -740,17 +744,3 @@ func sendRequest(re *require.Assertions, url string, method string, statusCode i resp.Body.Close() return output } - -func (suite *middlewareTestSuite) mustWaitLeader() *server.Server { - var leader *server.Server - testutil.Eventually(suite.Require(), func() bool { - for _, s := range suite.cluster.GetServers() { - if !s.GetServer().IsClosed() && s.GetServer().GetMember().IsLeader() { - leader = s.GetServer() - return true - } - } - return false - }) - return leader -} diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go index 9f46f5a6676..57882d5e965 100644 --- a/tests/server/config/config_test.go +++ b/tests/server/config/config_test.go @@ -23,7 +23,6 @@ import ( "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/ratelimit" - "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/tests" ) @@ -65,23 +64,13 @@ func TestRateLimitConfigReload(t *testing.T) { oldLeaderName := leader.GetServer().Name() leader.GetServer().GetMember().ResignEtcdLeader(leader.GetServer().Context(), oldLeaderName, "") - mustWaitLeader(re, cluster.GetServers()) + var servers []*server.Server + for _, s := range cluster.GetServers() { + servers = append(servers, s.GetServer()) + } + server.MustWaitLeader(re, servers) leader = cluster.GetServer(cluster.GetLeader()) re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) 
re.Len(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, 1) } - -func mustWaitLeader(re *require.Assertions, svrs map[string]*tests.TestServer) *server.Server { - var leader *server.Server - testutil.Eventually(re, func() bool { - for _, svr := range svrs { - if !svr.GetServer().IsClosed() && svr.GetServer().GetMember().IsLeader() { - leader = svr.GetServer() - return true - } - } - return false - }) - return leader -} diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 15819bb4bf4..023f58103ef 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -315,7 +315,7 @@ func TestGetLeader(t *testing.T) { go sendRequest(re, wg, done, cfg.ClientUrls) time.Sleep(100 * time.Millisecond) - mustWaitLeader(re, []*server.Server{svr}) + server.MustWaitLeader(re, []*server.Server{svr}) re.NotNil(svr.GetLeader()) @@ -345,17 +345,3 @@ func sendRequest(re *require.Assertions, wg *sync.WaitGroup, done <-chan bool, a time.Sleep(10 * time.Millisecond) } } - -func mustWaitLeader(re *require.Assertions, svrs []*server.Server) *server.Server { - var leader *server.Server - testutil.Eventually(re, func() bool { - for _, s := range svrs { - if !s.IsClosed() && s.GetMember().IsLeader() { - leader = s - return true - } - } - return false - }) - return leader -} From b5e113436c85f81d0e6ca2b5b592271114089db2 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 24 Jun 2022 14:36:37 +0800 Subject: [PATCH 81/82] tests: testify the global_config tests (#5227) ref tikv/pd#4813 Testify the global_config tests. Signed-off-by: JmPotato --- .../global_config/global_config_test.go | 181 +++++++++--------- 1 file changed, 91 insertions(+), 90 deletions(-) diff --git a/tests/server/global_config/global_config_test.go b/tests/server/global_config/global_config_test.go index c2b12353eea..f821d664b7a 100644 --- a/tests/server/global_config/global_config_test.go +++ b/tests/server/global_config/global_config_test.go @@ -21,9 +21,10 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" @@ -32,120 +33,120 @@ import ( "google.golang.org/grpc" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) 
} -var _ = Suite(&GlobalConfigTestSuite{}) -var globalConfigPath = "/global/config/" - -type GlobalConfigTestSuite struct { - server *server.GrpcServer - client *grpc.ClientConn - cleanup server.CleanupFunc -} +const globalConfigPath = "/global/config/" -type TestReceiver struct { - c *C +type testReceiver struct { + re *require.Assertions grpc.ServerStream } -func (s TestReceiver) Send(m *pdpb.WatchGlobalConfigResponse) error { +func (s testReceiver) Send(m *pdpb.WatchGlobalConfigResponse) error { log.Info("received", zap.Any("received", m.GetChanges())) for _, change := range m.GetChanges() { - s.c.Assert(change.Name, Equals, globalConfigPath+change.Value) + s.re.Equal(globalConfigPath+change.Value, change.Name) } return nil } -func (s *GlobalConfigTestSuite) SetUpSuite(c *C) { +type globalConfigTestSuite struct { + suite.Suite + server *server.GrpcServer + client *grpc.ClientConn + cleanup server.CleanupFunc +} + +func TestGlobalConfigTestSuite(t *testing.T) { + suite.Run(t, new(globalConfigTestSuite)) +} + +func (suite *globalConfigTestSuite) SetupSuite() { var err error var gsi *server.Server checker := assertutil.NewChecker() checker.FailNow = func() {} - gsi, s.cleanup, err = server.NewTestServer(checker) - s.server = &server.GrpcServer{Server: gsi} - c.Assert(err, IsNil) - addr := s.server.GetAddr() - s.client, err = grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) - c.Assert(err, IsNil) + gsi, suite.cleanup, err = server.NewTestServer(checker) + suite.server = &server.GrpcServer{Server: gsi} + suite.NoError(err) + addr := suite.server.GetAddr() + suite.client, err = grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) + suite.NoError(err) } -func (s *GlobalConfigTestSuite) TearDownSuite(c *C) { - s.client.Close() - s.cleanup() +func (suite *globalConfigTestSuite) TearDownSuite() { + suite.client.Close() + suite.cleanup() } -func (s *GlobalConfigTestSuite) TestLoad(c *C) { +func (suite *globalConfigTestSuite) TestLoad() { defer func() { // clean up - _, err := s.server.GetClient().Delete(s.server.Context(), globalConfigPath+"test") - c.Assert(err, IsNil) + _, err := suite.server.GetClient().Delete(suite.server.Context(), globalConfigPath+"test") + suite.NoError(err) }() - _, err := s.server.GetClient().Put(s.server.Context(), globalConfigPath+"test", "test") - c.Assert(err, IsNil) - res, err := s.server.LoadGlobalConfig(s.server.Context(), &pdpb.LoadGlobalConfigRequest{Names: []string{"test"}}) - c.Assert(err, IsNil) - c.Assert(len(res.Items), Equals, 1) - c.Assert(res.Items[0].Value, Equals, "test") + _, err := suite.server.GetClient().Put(suite.server.Context(), globalConfigPath+"test", "test") + suite.NoError(err) + res, err := suite.server.LoadGlobalConfig(suite.server.Context(), &pdpb.LoadGlobalConfigRequest{Names: []string{"test"}}) + suite.NoError(err) + suite.Len(res.Items, 1) + suite.Equal("test", res.Items[0].Value) } -func (s *GlobalConfigTestSuite) TestLoadError(c *C) { - res, err := s.server.LoadGlobalConfig(s.server.Context(), &pdpb.LoadGlobalConfigRequest{Names: []string{"test"}}) - c.Assert(err, IsNil) - c.Assert(res.Items[0].Error, Not(Equals), nil) +func (suite *globalConfigTestSuite) TestLoadError() { + res, err := suite.server.LoadGlobalConfig(suite.server.Context(), &pdpb.LoadGlobalConfigRequest{Names: []string{"test"}}) + suite.NoError(err) + suite.NotNil(res.Items[0].Error) } -func (s *GlobalConfigTestSuite) TestStore(c *C) { +func (suite *globalConfigTestSuite) TestStore() { defer func() { for i := 1; i <= 3; i++ { - _, err 
:= s.server.GetClient().Delete(s.server.Context(), globalConfigPath+strconv.Itoa(i)) - c.Assert(err, IsNil) + _, err := suite.server.GetClient().Delete(suite.server.Context(), globalConfigPath+strconv.Itoa(i)) + suite.NoError(err) } }() changes := []*pdpb.GlobalConfigItem{{Name: "1", Value: "1"}, {Name: "2", Value: "2"}, {Name: "3", Value: "3"}} - _, err := s.server.StoreGlobalConfig(s.server.Context(), &pdpb.StoreGlobalConfigRequest{Changes: changes}) - c.Assert(err, IsNil) + _, err := suite.server.StoreGlobalConfig(suite.server.Context(), &pdpb.StoreGlobalConfigRequest{Changes: changes}) + suite.NoError(err) for i := 1; i <= 3; i++ { - res, err := s.server.GetClient().Get(s.server.Context(), globalConfigPath+strconv.Itoa(i)) - c.Assert(err, IsNil) - c.Assert(string(res.Kvs[0].Key), Equals, globalConfigPath+string(res.Kvs[0].Value)) + res, err := suite.server.GetClient().Get(suite.server.Context(), globalConfigPath+strconv.Itoa(i)) + suite.NoError(err) + suite.Equal(globalConfigPath+string(res.Kvs[0].Value), string(res.Kvs[0].Key)) } } -func (s *GlobalConfigTestSuite) TestWatch(c *C) { +func (suite *globalConfigTestSuite) TestWatch() { defer func() { for i := 0; i < 3; i++ { // clean up - _, err := s.server.GetClient().Delete(s.server.Context(), globalConfigPath+strconv.Itoa(i)) - c.Assert(err, IsNil) + _, err := suite.server.GetClient().Delete(suite.server.Context(), globalConfigPath+strconv.Itoa(i)) + suite.NoError(err) } }() - server := TestReceiver{c: c} - go s.server.WatchGlobalConfig(&pdpb.WatchGlobalConfigRequest{}, server) + server := testReceiver{re: suite.Require()} + go suite.server.WatchGlobalConfig(&pdpb.WatchGlobalConfigRequest{}, server) for i := 0; i < 3; i++ { - _, err := s.server.GetClient().Put(s.server.Context(), globalConfigPath+strconv.Itoa(i), strconv.Itoa(i)) - c.Assert(err, IsNil) + _, err := suite.server.GetClient().Put(suite.server.Context(), globalConfigPath+strconv.Itoa(i), strconv.Itoa(i)) + suite.NoError(err) } } -func (s *GlobalConfigTestSuite) loadGlobalConfig(ctx context.Context, names []string) ([]*pdpb.GlobalConfigItem, error) { - res, err := pdpb.NewPDClient(s.client).LoadGlobalConfig(ctx, &pdpb.LoadGlobalConfigRequest{Names: names}) +func (suite *globalConfigTestSuite) loadGlobalConfig(ctx context.Context, names []string) ([]*pdpb.GlobalConfigItem, error) { + res, err := pdpb.NewPDClient(suite.client).LoadGlobalConfig(ctx, &pdpb.LoadGlobalConfigRequest{Names: names}) return res.GetItems(), err } -func (s *GlobalConfigTestSuite) storeGlobalConfig(ctx context.Context, changes []*pdpb.GlobalConfigItem) error { - _, err := pdpb.NewPDClient(s.client).StoreGlobalConfig(ctx, &pdpb.StoreGlobalConfigRequest{Changes: changes}) +func (suite *globalConfigTestSuite) storeGlobalConfig(ctx context.Context, changes []*pdpb.GlobalConfigItem) error { + _, err := pdpb.NewPDClient(suite.client).StoreGlobalConfig(ctx, &pdpb.StoreGlobalConfigRequest{Changes: changes}) return err } -func (s *GlobalConfigTestSuite) watchGlobalConfig(ctx context.Context) (chan []*pdpb.GlobalConfigItem, error) { +func (suite *globalConfigTestSuite) watchGlobalConfig(ctx context.Context) (chan []*pdpb.GlobalConfigItem, error) { globalConfigWatcherCh := make(chan []*pdpb.GlobalConfigItem, 16) - res, err := pdpb.NewPDClient(s.client).WatchGlobalConfig(ctx, &pdpb.WatchGlobalConfigRequest{}) + res, err := pdpb.NewPDClient(suite.client).WatchGlobalConfig(ctx, &pdpb.WatchGlobalConfigRequest{}) if err != nil { close(globalConfigWatcherCh) return nil, err @@ -177,53 +178,53 @@ func (s 
*GlobalConfigTestSuite) watchGlobalConfig(ctx context.Context) (chan []* return globalConfigWatcherCh, err } -func (s *GlobalConfigTestSuite) TestClientLoad(c *C) { +func (suite *globalConfigTestSuite) TestClientLoad() { defer func() { - _, err := s.server.GetClient().Delete(s.server.Context(), globalConfigPath+"test") - c.Assert(err, IsNil) + _, err := suite.server.GetClient().Delete(suite.server.Context(), globalConfigPath+"test") + suite.NoError(err) }() - _, err := s.server.GetClient().Put(s.server.Context(), globalConfigPath+"test", "test") - c.Assert(err, IsNil) - res, err := s.loadGlobalConfig(s.server.Context(), []string{"test"}) - c.Assert(err, IsNil) - c.Assert(len(res), Equals, 1) - c.Assert(res[0], DeepEquals, &pdpb.GlobalConfigItem{Name: "test", Value: "test", Error: nil}) + _, err := suite.server.GetClient().Put(suite.server.Context(), globalConfigPath+"test", "test") + suite.NoError(err) + res, err := suite.loadGlobalConfig(suite.server.Context(), []string{"test"}) + suite.NoError(err) + suite.Len(res, 1) + suite.Equal(&pdpb.GlobalConfigItem{Name: "test", Value: "test", Error: nil}, res[0]) } -func (s *GlobalConfigTestSuite) TestClientLoadError(c *C) { - res, err := s.loadGlobalConfig(s.server.Context(), []string{"test"}) - c.Assert(err, IsNil) - c.Assert(res[0].Error, Not(Equals), nil) +func (suite *globalConfigTestSuite) TestClientLoadError() { + res, err := suite.loadGlobalConfig(suite.server.Context(), []string{"test"}) + suite.NoError(err) + suite.NotNil(res[0].Error) } -func (s *GlobalConfigTestSuite) TestClientStore(c *C) { +func (suite *globalConfigTestSuite) TestClientStore() { defer func() { for i := 1; i <= 3; i++ { - _, err := s.server.GetClient().Delete(s.server.Context(), globalConfigPath+strconv.Itoa(i)) - c.Assert(err, IsNil) + _, err := suite.server.GetClient().Delete(suite.server.Context(), globalConfigPath+strconv.Itoa(i)) + suite.NoError(err) } }() - err := s.storeGlobalConfig(s.server.Context(), []*pdpb.GlobalConfigItem{{Name: "1", Value: "1"}, {Name: "2", Value: "2"}, {Name: "3", Value: "3"}}) - c.Assert(err, IsNil) + err := suite.storeGlobalConfig(suite.server.Context(), []*pdpb.GlobalConfigItem{{Name: "1", Value: "1"}, {Name: "2", Value: "2"}, {Name: "3", Value: "3"}}) + suite.NoError(err) for i := 1; i <= 3; i++ { - res, err := s.server.GetClient().Get(s.server.Context(), globalConfigPath+strconv.Itoa(i)) - c.Assert(err, IsNil) - c.Assert(string(res.Kvs[0].Key), Equals, globalConfigPath+string(res.Kvs[0].Value)) + res, err := suite.server.GetClient().Get(suite.server.Context(), globalConfigPath+strconv.Itoa(i)) + suite.NoError(err) + suite.Equal(globalConfigPath+string(res.Kvs[0].Value), string(res.Kvs[0].Key)) } } -func (s *GlobalConfigTestSuite) TestClientWatch(c *C) { +func (suite *globalConfigTestSuite) TestClientWatch() { defer func() { for i := 0; i < 3; i++ { - _, err := s.server.GetClient().Delete(s.server.Context(), globalConfigPath+strconv.Itoa(i)) - c.Assert(err, IsNil) + _, err := suite.server.GetClient().Delete(suite.server.Context(), globalConfigPath+strconv.Itoa(i)) + suite.NoError(err) } }() - wc, err := s.watchGlobalConfig(s.server.Context()) - c.Assert(err, IsNil) + wc, err := suite.watchGlobalConfig(suite.server.Context()) + suite.NoError(err) for i := 0; i < 3; i++ { - _, err = s.server.GetClient().Put(s.server.Context(), globalConfigPath+strconv.Itoa(i), strconv.Itoa(i)) - c.Assert(err, IsNil) + _, err = suite.server.GetClient().Put(suite.server.Context(), globalConfigPath+strconv.Itoa(i), strconv.Itoa(i)) + suite.NoError(err) } 
for { select { @@ -231,7 +232,7 @@ func (s *GlobalConfigTestSuite) TestClientWatch(c *C) { return case res := <-wc: for _, r := range res { - c.Assert(r.Name, Equals, globalConfigPath+r.Value) + suite.Equal(globalConfigPath+r.Value, r.Name) } } } From 844f5ed8fbc331eb86f702a43a1aa46d9e767430 Mon Sep 17 00:00:00 2001 From: Fu Yu <38561589+njuwelkin@users.noreply.github.com> Date: Mon, 27 Jun 2022 18:56:39 +0800 Subject: [PATCH 82/82] =?UTF-8?q?server/scheduler:=20check=20whether=20is?= =?UTF-8?q?=20valid=20for=20config=20of=20hot=20scheduler=20=E2=80=A6=20(#?= =?UTF-8?q?5208)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit close tikv/pd#3952 Signed-off-by: kevin fu Co-authored-by: Ti Chi Robot --- server/schedulers/hot_region_config.go | 43 ++++++++++++++++++++++++++ server/schedulers/hot_region_test.go | 31 +++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/server/schedulers/hot_region_config.go b/server/schedulers/hot_region_config.go index 53228cf6a79..11c20c4279a 100644 --- a/server/schedulers/hot_region_config.go +++ b/server/schedulers/hot_region_config.go @@ -17,12 +17,14 @@ package schedulers import ( "bytes" "encoding/json" + "errors" "io" "net/http" "time" "github.com/gorilla/mux" "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/reflectutil" "github.com/tikv/pd/pkg/slice" "github.com/tikv/pd/pkg/syncutil" @@ -299,6 +301,38 @@ func (conf *hotRegionSchedulerConfig) handleGetConfig(w http.ResponseWriter, r * rd.JSON(w, http.StatusOK, conf.getValidConf()) } +func (conf *hotRegionSchedulerConfig) validPriority() error { + isValid := func(priorities []string) (map[string]bool, error) { + priorityMap := map[string]bool{} + for _, p := range priorities { + if p != BytePriority && p != KeyPriority && p != QueryPriority { + return nil, errs.ErrSchedulerConfig.FastGenByArgs("invalid scheduling dimensions.") + } + priorityMap[p] = true + } + if len(priorityMap) != len(priorities) { + return nil, errors.New("priorities shouldn't be repeated") + } + if len(priorityMap) != 0 && len(priorityMap) < 2 { + return nil, errors.New("priorities should have at least 2 dimensions") + } + return priorityMap, nil + } + if _, err := isValid(conf.ReadPriorities); err != nil { + return err + } + if _, err := isValid(conf.WriteLeaderPriorities); err != nil { + return err + } + pm, err := isValid(conf.WritePeerPriorities) + if err != nil { + return err + } else if pm[QueryPriority] { + return errors.New("qps is not allowed to be set in priorities for write-peer-priorities") + } + return nil +} + func (conf *hotRegionSchedulerConfig) handleSetConfig(w http.ResponseWriter, r *http.Request) { conf.Lock() defer conf.Unlock() @@ -315,6 +349,15 @@ func (conf *hotRegionSchedulerConfig) handleSetConfig(w http.ResponseWriter, r * rd.JSON(w, http.StatusInternalServerError, err.Error()) return } + if err := conf.validPriority(); err != nil { + // revert to old version + if err2 := json.Unmarshal(oldc, conf); err2 != nil { + rd.JSON(w, http.StatusInternalServerError, err2.Error()) + } else { + rd.JSON(w, http.StatusBadRequest, err.Error()) + } + return + } newc, _ := json.Marshal(conf) if !bytes.Equal(oldc, newc) { conf.persistLocked() diff --git a/server/schedulers/hot_region_test.go b/server/schedulers/hot_region_test.go index a0abd95e2ba..0ebbe9c3090 100644 --- a/server/schedulers/hot_region_test.go +++ b/server/schedulers/hot_region_test.go @@ -2113,6 +2113,37 @@ func (s *testHotSchedulerSuite) TestCompatibilityConfig(c *C) { }) 
} +func (s *testHotSchedulerSuite) TestConfigValidation(c *C) { + // priority should be one of byte/query/key + hc := initHotRegionScheduleConfig() + hc.ReadPriorities = []string{"byte", "error"} + err := hc.validPriority() + c.Assert(err, NotNil) + + // priorities should have at least 2 dimensions + hc = initHotRegionScheduleConfig() + hc.WriteLeaderPriorities = []string{"byte"} + err = hc.validPriority() + c.Assert(err, NotNil) + + // qps is not allowed to be set in priorities for write-peer-priorities + hc = initHotRegionScheduleConfig() + hc.WritePeerPriorities = []string{"query", "byte"} + err = hc.validPriority() + c.Assert(err, NotNil) + + // priorities shouldn't be repeated + hc = initHotRegionScheduleConfig() + hc.WritePeerPriorities = []string{"byte", "byte"} + err = hc.validPriority() + c.Assert(err, NotNil) + + hc = initHotRegionScheduleConfig() + hc.WritePeerPriorities = []string{"byte", "key"} + err = hc.validPriority() + c.Assert(err, IsNil) +} + func checkPriority(c *C, hb *hotScheduler, tc *mockcluster.Cluster, dims [3][2]int) { readSolver := newBalanceSolver(hb, tc, statistics.Read, transferLeader) writeLeaderSolver := newBalanceSolver(hb, tc, statistics.Write, transferLeader)
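To read the new validation rules in isolation, here is a minimal standalone sketch. checkPriorities, its error strings, and the allowQuery flag are hypothetical and do not reproduce the scheduler's actual validPriority method, but the checks mirror the ones added in this patch: only the byte, key, and query dimensions are accepted, entries must not repeat, a non-empty list needs at least two dimensions, and the query dimension can be rejected where it is disallowed (as for write-peer priorities).

package main

import (
	"errors"
	"fmt"
)

// checkPriorities validates a priority list: each entry must be byte/key/query,
// entries must be unique, a non-empty list needs at least two dimensions, and
// "query" may be forbidden (as it is for write-peer priorities).
func checkPriorities(priorities []string, allowQuery bool) error {
	seen := map[string]bool{}
	for _, p := range priorities {
		if p != "byte" && p != "key" && p != "query" {
			return fmt.Errorf("invalid scheduling dimension %q", p)
		}
		if seen[p] {
			return errors.New("priorities shouldn't be repeated")
		}
		seen[p] = true
	}
	if len(seen) != 0 && len(seen) < 2 {
		return errors.New("priorities should have at least 2 dimensions")
	}
	if !allowQuery && seen["query"] {
		return errors.New("query is not allowed here")
	}
	return nil
}

func main() {
	fmt.Println(checkPriorities([]string{"byte", "key"}, true))    // <nil>
	fmt.Println(checkPriorities([]string{"byte"}, true))           // needs at least 2 dimensions
	fmt.Println(checkPriorities([]string{"query", "byte"}, false)) // query rejected, as for write-peer
	fmt.Println(checkPriorities([]string{"byte", "byte"}, true))   // repeated dimension
}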