Skip to content

Commit

Permalink
Reorder grpc interceptors (#3423)
Browse files Browse the repository at this point in the history
* Reorder server interceptor
  • Loading branch information
yux0 authored and dnr committed Sep 29, 2022
1 parent ac2593c commit ca5d0fd
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 43 deletions.
39 changes: 31 additions & 8 deletions common/rpc/interceptor/namespace_validator.go
Expand Up @@ -39,21 +39,18 @@ import (
)

type (
// NamespaceValidatorInterceptor validates:
// 1. Namespace is specified in task token if there is a `task_token` field.
// 2. Namespace is specified in request if there is a `namespace` field and no `task_token` field.
// 3. Namespace exists.
// 4. Namespace from request match namespace from task token, if check is enabled with dynamic config.
// 5. Namespace is in correct state.
// NamespaceValidatorInterceptor contains LengthValidationIntercept and StateValidationIntercept
NamespaceValidatorInterceptor struct {
namespaceRegistry namespace.Registry
tokenSerializer common.TaskTokenSerializer
enableTokenNamespaceEnforcement dynamicconfig.BoolPropertyFn
maxNamespaceLength dynamicconfig.IntPropertyFn
}
)

var (
ErrNamespaceNotSet = serviceerror.NewInvalidArgument("Namespace not set on request.")
errNamespaceTooLong = serviceerror.NewInvalidArgument("Namespace length exceeds limit.")
errNamespaceHandover = serviceerror.NewUnavailable(fmt.Sprintf("Namespace replication in %s state.", enumspb.REPLICATION_STATE_HANDOVER.String()))
errTaskTokenNotSet = serviceerror.NewInvalidArgument("Task token not set on request.")
errTaskTokenNamespaceMismatch = serviceerror.NewInvalidArgument("Operation requested with a token from a different namespace.")
Expand All @@ -73,20 +70,46 @@ var (
}
)

var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).Intercept
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).StateValidationIntercept
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).LengthValidationIntercept

func NewNamespaceValidatorInterceptor(
namespaceRegistry namespace.Registry,
enableTokenNamespaceEnforcement dynamicconfig.BoolPropertyFn,
maxNamespaceLength dynamicconfig.IntPropertyFn,
) *NamespaceValidatorInterceptor {
return &NamespaceValidatorInterceptor{
namespaceRegistry: namespaceRegistry,
tokenSerializer: common.NewProtoTaskTokenSerializer(),
enableTokenNamespaceEnforcement: enableTokenNamespaceEnforcement,
maxNamespaceLength: maxNamespaceLength,
}
}

func (ni *NamespaceValidatorInterceptor) Intercept(
func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqWithNamespace, hasNamespace := req.(NamespaceNameGetter)
if hasNamespace {
namespaceName := namespace.Name(reqWithNamespace.GetNamespace())
if len(namespaceName) > ni.maxNamespaceLength() {
return nil, errNamespaceTooLong
}
}

return handler(ctx, req)
}

// StateValidationIntercept validates:
// 1. Namespace is specified in task token if there is a `task_token` field.
// 2. Namespace is specified in request if there is a `namespace` field and no `task_token` field.
// 3. Namespace exists.
// 4. Namespace from request match namespace from task token, if check is enabled with dynamic config.
// 5. Namespace is in correct state.
func (ni *NamespaceValidatorInterceptor) StateValidationIntercept(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
Expand Down
106 changes: 79 additions & 27 deletions common/rpc/interceptor/namespace_validator_test.go
Expand Up @@ -76,13 +76,16 @@ func (s *namespaceValidatorSuite) TearDownTest() {
s.controller.Finish()
}

func (s *namespaceValidatorSuite) Test_Intercept_NamespaceNotSet() {
func (s *namespaceValidatorSuite) Test_StateValidationIntercept_NamespaceNotSet() {
taskToken, _ := common.NewProtoTaskTokenSerializer().Serialize(&tokenspb.Task{
NamespaceId: "",
WorkflowId: "wid",
})

nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false))
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}
Expand Down Expand Up @@ -113,7 +116,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_NamespaceNotSet() {

for _, testCase := range testCases {
handlerCalled := false
_, err := nvi.Intercept(context.Background(), testCase.req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), testCase.req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
Expand All @@ -128,17 +131,20 @@ func (s *namespaceValidatorSuite) Test_Intercept_NamespaceNotSet() {
}
}

func (s *namespaceValidatorSuite) Test_Intercept_NamespaceNotFound() {
func (s *namespaceValidatorSuite) Test_StateValidationIntercept_NamespaceNotFound() {

nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false))
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}

s.mockRegistry.EXPECT().GetNamespace(namespace.Name("not-found-namespace")).Return(nil, serviceerror.NewNamespaceNotFound("missing-namespace"))
req := &workflowservice.StartWorkflowExecutionRequest{Namespace: "not-found-namespace"}
handlerCalled := false
_, err := nvi.Intercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
Expand All @@ -155,7 +161,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_NamespaceNotFound() {
TaskToken: taskToken,
}
handlerCalled = false
_, err = nvi.Intercept(context.Background(), tokenReq, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err = nvi.StateValidationIntercept(context.Background(), tokenReq, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.RespondWorkflowTaskCompletedResponse{}, nil
})
Expand All @@ -164,7 +170,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_NamespaceNotFound() {
s.False(handlerCalled)
}

func (s *namespaceValidatorSuite) Test_Intercept_StatusFromNamespace() {
func (s *namespaceValidatorSuite) Test_StateValidationIntercept_StatusFromNamespace() {
testCases := []struct {
state enumspb.NamespaceState
replicationState enumspb.ReplicationState
Expand Down Expand Up @@ -296,13 +302,16 @@ func (s *namespaceValidatorSuite) Test_Intercept_StatusFromNamespace() {
}), nil)
}

nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false))
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: testCase.method,
}

handlerCalled := false
_, err := nvi.Intercept(context.Background(), testCase.req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), testCase.req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
Expand All @@ -318,7 +327,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_StatusFromNamespace() {
}
}

func (s *namespaceValidatorSuite) Test_Intercept_StatusFromToken() {
func (s *namespaceValidatorSuite) Test_StateValidationIntercept_StatusFromToken() {
taskToken, _ := common.NewProtoTaskTokenSerializer().Serialize(&tokenspb.Task{
NamespaceId: "test-namespace-id",
})
Expand Down Expand Up @@ -368,13 +377,16 @@ func (s *namespaceValidatorSuite) Test_Intercept_StatusFromToken() {
},
}), nil)

nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false))
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: testCase.method,
}

handlerCalled := false
_, err := nvi.Intercept(context.Background(), testCase.req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), testCase.req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.RespondWorkflowTaskCompletedResponse{}, nil
})
Expand All @@ -389,15 +401,18 @@ func (s *namespaceValidatorSuite) Test_Intercept_StatusFromToken() {
}
}

func (s *namespaceValidatorSuite) Test_Intercept_DescribeNamespace_Id() {
nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false))
func (s *namespaceValidatorSuite) Test_StateValidationIntercept_DescribeNamespace_Id() {
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}

req := &workflowservice.DescribeNamespaceRequest{Id: "test-namespace-id"}
handlerCalled := false
_, err := nvi.Intercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.DescribeNamespaceResponse{}, nil
})
Expand All @@ -407,7 +422,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_DescribeNamespace_Id() {

req = &workflowservice.DescribeNamespaceRequest{}
handlerCalled = false
_, err = nvi.Intercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err = nvi.StateValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.DescribeNamespaceResponse{}, nil
})
Expand All @@ -416,16 +431,19 @@ func (s *namespaceValidatorSuite) Test_Intercept_DescribeNamespace_Id() {
s.False(handlerCalled)
}

func (s *namespaceValidatorSuite) Test_Intercept_GetClusterInfo() {
nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false))
func (s *namespaceValidatorSuite) Test_StateValidationIntercept_GetClusterInfo() {
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}

// Example of API which doesn't have namespace field.
req := &workflowservice.GetClusterInfoRequest{}
handlerCalled := false
_, err := nvi.Intercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.GetClusterInfoResponse{}, nil
})
Expand All @@ -435,14 +453,17 @@ func (s *namespaceValidatorSuite) Test_Intercept_GetClusterInfo() {
}

func (s *namespaceValidatorSuite) Test_Intercept_RegisterNamespace() {
nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false))
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}

req := &workflowservice.RegisterNamespaceRequest{Namespace: "new-namespace"}
handlerCalled := false
_, err := nvi.Intercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.RegisterNamespaceResponse{}, nil
})
Expand All @@ -452,7 +473,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_RegisterNamespace() {

req = &workflowservice.RegisterNamespaceRequest{}
handlerCalled = false
_, err = nvi.Intercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err = nvi.StateValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.RegisterNamespaceResponse{}, nil
})
Expand All @@ -461,7 +482,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_RegisterNamespace() {
s.False(handlerCalled)
}

func (s *namespaceValidatorSuite) Test_Interceptor_TokenNamespaceEnforcement() {
func (s *namespaceValidatorSuite) Test_StateValidationIntercept_TokenNamespaceEnforcement() {
testCases := []struct {
tokenNamespaceID namespace.ID
tokenNamespaceName namespace.Name
Expand Down Expand Up @@ -546,17 +567,20 @@ func (s *namespaceValidatorSuite) Test_Interceptor_TokenNamespaceEnforcement() {
s.mockRegistry.EXPECT().GetNamespace(testCase.requestNamespaceName).Return(requestNamespace, nil).Times(2)
s.mockRegistry.EXPECT().GetNamespaceByID(testCase.tokenNamespaceID).Return(tokenNamespace, nil).Times(2)

nvi := NewNamespaceValidatorInterceptor(s.mockRegistry, dynamicconfig.GetBoolPropertyFn(testCase.enableTokenNamespaceEnforcement))
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(testCase.enableTokenNamespaceEnforcement),
dynamicconfig.GetIntPropertyFn(100))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/RandomMethod",
}

handlerCalled := false
_, err := nvi.Intercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.StateValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.RespondWorkflowTaskCompletedResponse{}, nil
})
_, queryErr := nvi.Intercept(context.Background(), queryReq, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, queryErr := nvi.StateValidationIntercept(context.Background(), queryReq, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.RespondQueryTaskCompletedResponse{}, nil
})
Expand All @@ -572,3 +596,31 @@ func (s *namespaceValidatorSuite) Test_Interceptor_TokenNamespaceEnforcement() {
}
}
}

func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() {
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(10))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}

req := &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespace"}
handlerCalled := false
_, err := nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
s.True(handlerCalled)
s.NoError(err)

req = &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespaceTooLong"}
handlerCalled = false
_, err = nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
s.False(handlerCalled)
s.Error(err)
}
3 changes: 2 additions & 1 deletion common/rpc/interceptor/retry.go
Expand Up @@ -27,8 +27,9 @@ package interceptor
import (
"context"

"go.temporal.io/server/common/backoff"
"google.golang.org/grpc"

"go.temporal.io/server/common/backoff"
)

type (
Expand Down
17 changes: 11 additions & 6 deletions service/frontend/fx.go
Expand Up @@ -163,29 +163,33 @@ func GrpcServerOptionsProvider(
logger.Fatal("creating gRPC server options failed", tag.Error(err))
}
interceptors := []grpc.UnaryServerInterceptor{
namespaceLogInterceptor.Intercept,
// Service Error Interceptor should be the most outer interceptor on error handling
rpc.ServiceErrorInterceptor,
namespaceValidatorInterceptor.LengthValidationIntercept,
namespaceLogInterceptor.Intercept, // TODO: Deprecate this with a outer custom interceptor
grpc.UnaryServerInterceptor(traceInterceptor),
metrics.NewServerMetricsContextInjectorInterceptor(),
retryableInterceptor.Intercept,
telemetryInterceptor.Intercept,
namespaceValidatorInterceptor.Intercept,
namespaceCountLimiterInterceptor.Intercept,
namespaceRateLimiterInterceptor.Intercept,
rateLimitInterceptor.Intercept,
authorization.NewAuthorizationInterceptor(
claimMapper,
authorizer,
metricsClient,
logger,
audienceGetter,
),
namespaceValidatorInterceptor.StateValidationIntercept,
namespaceCountLimiterInterceptor.Intercept,
namespaceRateLimiterInterceptor.Intercept,
rateLimitInterceptor.Intercept,
sdkVersionInterceptor.Intercept,
callerInfoInterceptor.Intercept,
}
if len(customInterceptors) > 0 {
// TODO: Deprecate WithChainedFrontendGrpcInterceptors and provide a inner custom interceptor
interceptors = append(interceptors, customInterceptors...)
}
// retry interceptor should be the most inner interceptor
interceptors = append(interceptors, retryableInterceptor.Intercept)

return append(
grpcServerOptions,
Expand Down Expand Up @@ -309,6 +313,7 @@ func NamespaceValidatorInterceptorProvider(
return interceptor.NewNamespaceValidatorInterceptor(
namespaceRegistry,
serviceConfig.EnableTokenNamespaceEnforcement,
serviceConfig.MaxIDLengthLimit,
)
}

Expand Down

0 comments on commit ca5d0fd

Please sign in to comment.