diff --git a/client/base_client.go b/client/base_client.go
index 9916f459acb..5e38370cbb8 100644
--- a/client/base_client.go
+++ b/client/base_client.go
@@ -47,6 +47,8 @@ type baseClient struct {
 	security SecurityOption
 
 	gRPCDialOptions []grpc.DialOption
+
+	timeout time.Duration
 }
 
 // SecurityOption records options about tls
@@ -66,6 +68,18 @@ func WithGRPCDialOptions(opts ...grpc.DialOption) ClientOption {
 	}
 }
 
+// WithCustomTimeoutOption overrides the timeout the client applies to its PD
+// requests, which otherwise stays at defaultPDTimeout. A non-positive timeout
+// is ignored: context.WithTimeout would hand back an already-expired context
+// and every request would fail immediately.
+func WithCustomTimeoutOption(timeout time.Duration) ClientOption {
+	return func(c *baseClient) {
+		if timeout > 0 {
+			c.timeout = timeout
+		}
+	}
+}
+
 // newBaseClient returns a new baseClient.
 func newBaseClient(ctx context.Context, urls []string, security SecurityOption, opts ...ClientOption) (*baseClient, error) {
 	ctx1, cancel := context.WithCancel(ctx)
@@ -75,6 +89,7 @@ func newBaseClient(ctx context.Context, urls []string, security SecurityOption,
 		ctx: ctx1,
 		cancel: cancel,
 		security: security,
+		timeout: defaultPDTimeout,
 	}
 	c.connMu.clientConns = make(map[string]*grpc.ClientConn)
 	for _, opt := range opts {
@@ -163,7 +178,7 @@ func (c *baseClient) initClusterID() error {
 	ctx, cancel := context.WithCancel(c.ctx)
 	defer cancel()
 	for _, u := range c.urls {
-		timeoutCtx, timeoutCancel := context.WithTimeout(ctx, pdTimeout)
+		timeoutCtx, timeoutCancel := context.WithTimeout(ctx, c.timeout)
 		members, err := c.getMembers(timeoutCtx, u)
 		timeoutCancel()
 		if err != nil || members.GetHeader() == nil {
diff --git a/client/client.go b/client/client.go
index 0ecce28eaaf..cdabd30dd8f 100644
--- a/client/client.go
+++ b/client/client.go
@@ -111,7 +111,7 @@ type tsoRequest struct {
 }
 
 const (
-	pdTimeout = 3 * time.Second
+	defaultPDTimeout = 3 * time.Second
 	dialTimeout = 3 * time.Second
 	updateLeaderTimeout = time.Second // Use a shorter timeout to recover faster from network isolation.
maxMergeTSORequests = 10000 @@ -238,7 +238,7 @@ func (c *client) tsLoop() { } done := make(chan struct{}) dl := deadline{ - timer: time.After(pdTimeout), + timer: time.After(c.timeout), done: done, cancel: cancel, } @@ -455,7 +455,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte) (*Region, error) { start := time.Now() defer func() { cmdDurationGetRegion.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().GetRegion(ctx, &pdpb.GetRegionRequest{ Header: c.requestHeader(), RegionKey: key, @@ -478,7 +478,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte) (*Region, error) start := time.Now() defer func() { cmdDurationGetPrevRegion.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().GetPrevRegion(ctx, &pdpb.GetRegionRequest{ Header: c.requestHeader(), RegionKey: key, @@ -501,7 +501,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64) (*Region, e start := time.Now() defer func() { cmdDurationGetRegionByID.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().GetRegionByID(ctx, &pdpb.GetRegionByIDRequest{ Header: c.requestHeader(), RegionId: regionID, @@ -527,7 +527,7 @@ func (c *client) ScanRegions(ctx context.Context, key, endKey []byte, limit int) var cancel context.CancelFunc scanCtx := ctx if _, ok := ctx.Deadline(); !ok { - scanCtx, cancel = context.WithTimeout(ctx, pdTimeout) + scanCtx, cancel = context.WithTimeout(ctx, c.timeout) defer cancel() } @@ -553,7 +553,7 @@ func (c *client) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, e start := time.Now() defer func() { cmdDurationGetStore.Observe(time.Since(start).Seconds()) }() - ctx, 
cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().GetStore(ctx, &pdpb.GetStoreRequest{ Header: c.requestHeader(), StoreId: storeID, @@ -589,7 +589,7 @@ func (c *client) GetAllStores(ctx context.Context, opts ...GetStoreOption) ([]*m start := time.Now() defer func() { cmdDurationGetAllStores.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().GetAllStores(ctx, &pdpb.GetAllStoresRequest{ Header: c.requestHeader(), ExcludeTombstoneStores: options.excludeTombstone, @@ -613,7 +613,7 @@ func (c *client) UpdateGCSafePoint(ctx context.Context, safePoint uint64) (uint6 start := time.Now() defer func() { cmdDurationUpdateGCSafePoint.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().UpdateGCSafePoint(ctx, &pdpb.UpdateGCSafePointRequest{ Header: c.requestHeader(), SafePoint: safePoint, @@ -641,7 +641,7 @@ func (c *client) UpdateServiceGCSafePoint(ctx context.Context, serviceID string, start := time.Now() defer func() { cmdDurationUpdateServiceGCSafePoint.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().UpdateServiceGCSafePoint(ctx, &pdpb.UpdateServiceGCSafePointRequest{ Header: c.requestHeader(), ServiceId: []byte(serviceID), @@ -666,7 +666,7 @@ func (c *client) ScatterRegion(ctx context.Context, regionID uint64) error { start := time.Now() defer func() { cmdDurationScatterRegion.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) resp, err := c.leaderClient().ScatterRegion(ctx, &pdpb.ScatterRegionRequest{ Header: c.requestHeader(), 
RegionId: regionID, @@ -689,7 +689,7 @@ func (c *client) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOpe start := time.Now() defer func() { cmdDurationGetOperator.Observe(time.Since(start).Seconds()) }() - ctx, cancel := context.WithTimeout(ctx, pdTimeout) + ctx, cancel := context.WithTimeout(ctx, c.timeout) defer cancel() return c.leaderClient().GetOperator(ctx, &pdpb.GetOperatorRequest{ Header: c.requestHeader(), diff --git a/pkg/testutil/leak.go b/pkg/testutil/leak.go index aa895810a34..6a4fc0b29f0 100644 --- a/pkg/testutil/leak.go +++ b/pkg/testutil/leak.go @@ -26,6 +26,7 @@ var LeakOptions = []goleak.Option{ goleak.IgnoreTopFunction("google.golang.org/grpc.(*ccResolverWrapper).watcher"), goleak.IgnoreTopFunction("google.golang.org/grpc.(*addrConn).createTransport"), goleak.IgnoreTopFunction("google.golang.org/grpc.(*addrConn).resetTransport"), + goleak.IgnoreTopFunction("google.golang.org/grpc.(*Server).handleRawConn"), goleak.IgnoreTopFunction("go.etcd.io/etcd/pkg/logutil.(*MergeLogger).outputLoop"), // TODO: remove the below options once we fixed the http connection leak problems goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"), diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index a9ae9afb21c..fbeb39d4b4e 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -332,13 +332,13 @@ func (s *testClusterInfoSuite) TestConcurrentRegionHeartbeat(c *C) { var wg sync.WaitGroup wg.Add(1) - c.Assert(failpoint.Enable("github.com/pingcap/pd/server/cluster/concurrentRegionHeartbeat", "return(true)"), IsNil) + c.Assert(failpoint.Enable("github.com/pingcap/pd/v4/server/cluster/concurrentRegionHeartbeat", "return(true)"), IsNil) go func() { defer wg.Done() cluster.processRegionHeartbeat(source) }() time.Sleep(100 * time.Millisecond) - c.Assert(failpoint.Disable("github.com/pingcap/pd/server/cluster/concurrentRegionHeartbeat"), IsNil) + 
c.Assert(failpoint.Disable("github.com/pingcap/pd/v4/server/cluster/concurrentRegionHeartbeat"), IsNil)
 	c.Assert(cluster.processRegionHeartbeat(target), IsNil)
 	wg.Wait()
 	checkRegion(c, cluster.GetRegionByKey([]byte{}), target)
diff --git a/server/grpc_service.go b/server/grpc_service.go
index 4256493e3da..0a12b30d131 100644
--- a/server/grpc_service.go
+++ b/server/grpc_service.go
@@ -21,6 +21,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/kvproto/pkg/pdpb"
 	"github.com/pingcap/log"
@@ -243,6 +244,10 @@ func (s *Server) PutStore(ctx context.Context, request *pdpb.PutStoreRequest) (*
 
 // GetAllStores implements gRPC PDServer.
 func (s *Server) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) {
+	// The customTimeout failpoint stalls this handler for 5s so tests can drive a client past its request timeout.
+	failpoint.Inject("customTimeout", func() {
+		time.Sleep(5 * time.Second)
+	})
 	if err := s.validateRequest(request.GetHeader()); err != nil {
 		return nil, err
 	}
diff --git a/tests/client/client_test.go b/tests/client/client_test.go
index 70e8c75e644..e9c249093d9 100644
--- a/tests/client/client_test.go
+++ b/tests/client/client_test.go
@@ -25,6 +25,7 @@ import (
 
 	"github.com/gogo/protobuf/proto"
 	. "github.com/pingcap/check"
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/kvproto/pkg/pdpb"
 	pd "github.com/pingcap/pd/v4/client"
@@ -187,6 +188,38 @@ func (s *clientTestSuite) TestLeaderTransfer(c *C) {
 	wg.Wait()
 }
 
+// TestCustomTimeout checks that a client built with WithCustomTimeoutOption
+// gives up on a stalled request after its own timeout: the customTimeout
+// failpoint makes the server sleep 5s in GetAllStores while the client is
+// configured with 1s.
+func (s *clientTestSuite) TestCustomTimeout(c *C) {
+	cluster, err := tests.NewTestCluster(s.ctx, 1)
+	c.Assert(err, IsNil)
+	defer cluster.Destroy()
+
+	err = cluster.RunInitialServers()
+	c.Assert(err, IsNil)
+	cluster.WaitLeader()
+
+	var endpoints []string
+	for _, svr := range cluster.GetServers() {
+		endpoints = append(endpoints, svr.GetConfig().AdvertiseClientUrls)
+	}
+	cli, err := pd.NewClientWithContext(s.ctx, endpoints, pd.SecurityOption{}, pd.WithCustomTimeoutOption(1*time.Second))
+	c.Assert(err, IsNil)
+	defer cli.Close()
+
+	c.Assert(failpoint.Enable("github.com/pingcap/pd/v4/server/customTimeout", "return(true)"), IsNil)
+	// Time only the RPC itself so failpoint bookkeeping cannot skew the bounds.
+	start := time.Now()
+	_, err = cli.GetAllStores(context.TODO())
+	elapsed := time.Since(start)
+	c.Assert(failpoint.Disable("github.com/pingcap/pd/v4/server/customTimeout"), IsNil)
+	c.Assert(err, NotNil)
+	c.Assert(elapsed, GreaterEqual, 1*time.Second)
+	c.Assert(elapsed, Less, 2*time.Second)
+}
+
 func (s *clientTestSuite) waitLeader(c *C, cli client, leader string) {
 	testutil.WaitUntil(c, func(c *C) bool {
 		cli.ScheduleCheckLeader()