From b4c2ccaa43f0550a3cf6d6191e6485b01f6d19dc Mon Sep 17 00:00:00 2001 From: pdoerner <122412190+pdoerner@users.noreply.github.com> Date: Fri, 24 May 2024 09:03:28 -0700 Subject: [PATCH] Add support for forwarding Nexus HTTP requests (#5793) ## What changed? * Added a new component `cluster.HttpClientCache` which serves a similar purpose to our gRPC `ClientBean` but provides HTTP clients for remote clusters. * Added logic to forward Nexus requests from standby clusters to active. * Exposed frontend `namespace.Registry` from our test `temporalImpl` so that tests can use `Eventually` functions to wait for namespace data to be updated in-memory without having to use `time.Sleep` ## Why? So that Nexus requests can be forwarded across clusters. ## How did you test it? New unit and functional tests. --- common/cluster/frontend_http_client.go | 114 ++++++ common/collection/oncemap.go | 10 + common/nexus/failure.go | 44 +++ common/resource/fx.go | 8 + common/rpc/interceptor/redirection.go | 24 +- common/rpc/interceptor/redirection_test.go | 14 +- service/frontend/fx.go | 6 + service/frontend/nexus_handler.go | 176 +++++++++- service/frontend/nexus_handler_test.go | 106 +++++- service/frontend/nexus_http_handler.go | 18 +- tests/nexus_api_test.go | 12 +- tests/onebox.go | 4 + tests/xdc/base.go | 4 +- tests/xdc/nexus_request_forwarding_test.go | 384 +++++++++++++++++++++ 14 files changed, 862 insertions(+), 62 deletions(-) create mode 100644 common/cluster/frontend_http_client.go create mode 100644 tests/xdc/nexus_request_forwarding_test.go diff --git a/common/cluster/frontend_http_client.go b/common/cluster/frontend_http_client.go new file mode 100644 index 00000000000..d83335f036f --- /dev/null +++ b/common/cluster/frontend_http_client.go @@ -0,0 +1,114 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cluster + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + + "go.temporal.io/api/serviceerror" + + "go.temporal.io/server/common/collection" +) + +type tlsConfigProvider interface { + GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) +} + +type FrontendHTTPClient struct { + http.Client + Address string +} + +type FrontendHTTPClientCache struct { + metadata Metadata + tlsProvider tlsConfigProvider + clients *collection.FallibleOnceMap[string, *FrontendHTTPClient] +} + +func NewFrontendHTTPClientCache( + metadata Metadata, + tlsProvider tlsConfigProvider, +) *FrontendHTTPClientCache { + cache := &FrontendHTTPClientCache{ + metadata: metadata, + tlsProvider: tlsProvider, + } + cache.clients = collection.NewFallibleOnceMap(cache.newClientForCluster) + metadata.RegisterMetadataChangeCallback(cache, cache.evictionCallback) + return cache +} + +// Get returns a cached HttpClient if available, or constructs a new one for the given cluster name. +func (c *FrontendHTTPClientCache) Get(targetClusterName string) (*FrontendHTTPClient, error) { + return c.clients.Get(targetClusterName) +} + +func (c *FrontendHTTPClientCache) newClientForCluster(targetClusterName string) (*FrontendHTTPClient, error) { + targetInfo, ok := c.metadata.GetAllClusterInfo()[targetClusterName] + if !ok { + return nil, serviceerror.NewNotFound(fmt.Sprintf("could not find cluster metadata for cluster %s", targetClusterName)) + } + + address, err := url.Parse(targetInfo.HTTPAddress) + if err != nil { + return nil, err + } + + client := http.Client{} + + if c.tlsProvider != nil { + tlsClientConfig, err := c.tlsProvider.GetRemoteClusterClientConfig(address.Hostname()) + if err != nil { + return nil, err + } + client.Transport = &http.Transport{TLSClientConfig: tlsClientConfig} + } + + return &FrontendHTTPClient{ + Address: targetInfo.HTTPAddress, + Client: client, + }, nil +} + +// evictionCallback is invoked by cluster.Metadata when cluster information changes. +// It invalidates clients which are either no longer present or have had their HTTP address changed. +// It is assumed that TLS information has not changed for clusters that are unmodified. +func (c *FrontendHTTPClientCache) evictionCallback(oldClusterMetadata map[string]*ClusterInformation, newClusterMetadata map[string]*ClusterInformation) { + for oldClusterName, oldClusterInfo := range oldClusterMetadata { + if oldClusterName == c.metadata.GetCurrentClusterName() || oldClusterInfo == nil { + continue + } + + newClusterInfo, exists := newClusterMetadata[oldClusterName] + if !exists || oldClusterInfo.HTTPAddress != newClusterInfo.HTTPAddress { + // Cluster was removed or had its HTTP address changed, so invalidate the cached client for that cluster. + client, ok := c.clients.Pop(oldClusterName) + if ok { + client.CloseIdleConnections() + } + } + } +} diff --git a/common/collection/oncemap.go b/common/collection/oncemap.go index 328b37fbb66..f58d76f3101 100644 --- a/common/collection/oncemap.go +++ b/common/collection/oncemap.go @@ -95,3 +95,13 @@ func (p *FallibleOnceMap[K, T]) Get(key K) (T, error) { return value, nil } + +func (p *FallibleOnceMap[K, T]) Pop(key K) (T, bool) { + p.mu.Lock() + defer p.mu.Unlock() + val, ok := p.inner[key] + if ok { + delete(p.inner, key) + } + return val, ok +} diff --git a/common/nexus/failure.go b/common/nexus/failure.go index b4f6489076d..55217ffdae7 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -24,6 +24,7 @@ package nexus import ( "errors" + "net/http" "github.com/nexus-rpc/sdk-go/nexus" commonpb "go.temporal.io/api/common/v1" @@ -192,3 +193,46 @@ func AdaptAuthorizeError(err error) error { } return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied") } + +func HandlerErrorFromClientError(err error) error { + var unexpectedRespErr *nexus.UnexpectedResponseError + if errors.As(err, &unexpectedRespErr) { + failure := unexpectedRespErr.Failure + if unexpectedRespErr.Failure == nil { + failure = &nexus.Failure{ + Message: unexpectedRespErr.Error(), + } + } + handlerErr := &nexus.HandlerError{ + Failure: failure, + } + + switch unexpectedRespErr.Response.StatusCode { + case http.StatusBadRequest: + handlerErr.Type = nexus.HandlerErrorTypeBadRequest + case http.StatusUnauthorized: + handlerErr.Type = nexus.HandlerErrorTypeUnauthenticated + case http.StatusForbidden: + handlerErr.Type = nexus.HandlerErrorTypeUnauthorized + case http.StatusNotFound: + handlerErr.Type = nexus.HandlerErrorTypeNotFound + case http.StatusTooManyRequests: + handlerErr.Type = nexus.HandlerErrorTypeResourceExhausted + case http.StatusInternalServerError: + handlerErr.Type = nexus.HandlerErrorTypeInternal + case http.StatusNotImplemented: + handlerErr.Type = nexus.HandlerErrorTypeNotImplemented + case http.StatusServiceUnavailable: + handlerErr.Type = nexus.HandlerErrorTypeUnavailable + case nexus.StatusDownstreamError: + handlerErr.Type = nexus.HandlerErrorTypeDownstreamError + case nexus.StatusDownstreamTimeout: + handlerErr.Type = nexus.HandlerErrorTypeDownstreamTimeout + } + + return handlerErr + } + + // Let the nexus SDK handle this for us (log and convert to an internal error). + return err +} diff --git a/common/resource/fx.go b/common/resource/fx.go index 20164fe89e9..fa0153bd415 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -122,6 +122,7 @@ var Module = fx.Options( fx.Provide(MatchingRawClientProvider), fx.Provide(MatchingClientProvider), membership.GRPCResolverModule, + fx.Provide(FrontendHTTPClientCacheProvider), fx.Invoke(RegisterBootstrapContainer), fx.Provide(PersistenceConfigProvider), fx.Provide(health.NewServer), @@ -408,6 +409,13 @@ func RPCFactoryProvider( ), nil } +func FrontendHTTPClientCacheProvider( + metadata cluster.Metadata, + tlsConfigProvider encryption.TLSConfigProvider, +) *cluster.FrontendHTTPClientCache { + return cluster.NewFrontendHTTPClientCache(metadata, tlsConfigProvider) +} + func getFrontendConnectionDetails( cfg *config.Config, tlsConfigProvider encryption.TLSConfigProvider, diff --git a/common/rpc/interceptor/redirection.go b/common/rpc/interceptor/redirection.go index f77fd824f8e..1c8fc913c0c 100644 --- a/common/rpc/interceptor/redirection.go +++ b/common/rpc/interceptor/redirection.go @@ -47,8 +47,8 @@ import ( ) const ( - dcRedirectionContextHeaderName = "xdc-redirection" - dcRedirectionApiHeaderName = "xdc-redirection-api" + DCRedirectionContextHeaderName = "xdc-redirection" + DCRedirectionApiHeaderName = "xdc-redirection-api" dcRedirectionMetricsPrefix = "DCRedirection" ) @@ -185,7 +185,7 @@ func (i *Redirection) Intercept( if !strings.HasPrefix(info.FullMethod, api.WorkflowServicePrefix) { return handler(ctx, req) } - if !i.redirectionAllowed(ctx) { + if !i.RedirectionAllowed(ctx) { return handler(ctx, req) } @@ -213,9 +213,9 @@ func (i *Redirection) handleLocalAPIInvocation( handler grpc.UnaryHandler, methodName string, ) (_ any, retError error) { - scope, startTime := i.beforeCall(dcRedirectionMetricsPrefix + methodName) + scope, startTime := i.BeforeCall(dcRedirectionMetricsPrefix + methodName) defer func() { - i.afterCall(scope, startTime, i.currentClusterName, retError) + i.AfterCall(scope, startTime, i.currentClusterName, retError) }() return handler(ctx, req) } @@ -233,9 +233,9 @@ func (i *Redirection) handleRedirectAPIInvocation( var clusterName string var err error - scope, startTime := i.beforeCall(dcRedirectionMetricsPrefix + methodName) + scope, startTime := i.BeforeCall(dcRedirectionMetricsPrefix + methodName) defer func() { - i.afterCall(scope, startTime, clusterName, retError) + i.AfterCall(scope, startTime, clusterName, retError) }() err = i.redirectionPolicy.WithNamespaceRedirect(ctx, namespaceName, methodName, func(targetDC string) error { @@ -248,7 +248,7 @@ func (i *Redirection) handleRedirectAPIInvocation( return err } resp = respCtorFn() - ctx = metadata.AppendToOutgoingContext(ctx, dcRedirectionApiHeaderName, "true") + ctx = metadata.AppendToOutgoingContext(ctx, DCRedirectionApiHeaderName, "true") err = remoteClient.Invoke(ctx, info.FullMethod, req, resp) if err != nil { return err @@ -259,13 +259,13 @@ func (i *Redirection) handleRedirectAPIInvocation( return resp, err } -func (i *Redirection) beforeCall( +func (i *Redirection) BeforeCall( operation string, ) (metrics.Handler, time.Time) { return i.metricsHandler.WithTags(metrics.OperationTag(operation), metrics.ServiceRoleTag(metrics.DCRedirectionRoleTagValue)), i.timeSource.Now() } -func (i *Redirection) afterCall( +func (i *Redirection) AfterCall( metricsHandler metrics.Handler, startTime time.Time, clusterName string, @@ -279,7 +279,7 @@ func (i *Redirection) afterCall( } } -func (i *Redirection) redirectionAllowed( +func (i *Redirection) RedirectionAllowed( ctx context.Context, ) bool { // default to allow dc redirection @@ -287,7 +287,7 @@ func (i *Redirection) redirectionAllowed( if !ok { return true } - values := md.Get(dcRedirectionContextHeaderName) + values := md.Get(DCRedirectionContextHeaderName) if len(values) == 0 { return true } diff --git a/common/rpc/interceptor/redirection_test.go b/common/rpc/interceptor/redirection_test.go index e072beab4f4..0690a0a74ea 100644 --- a/common/rpc/interceptor/redirection_test.go +++ b/common/rpc/interceptor/redirection_test.go @@ -333,31 +333,31 @@ func (s *redirectionInterceptorSuite) TestHandleGlobalAPIInvocation_NamespaceNot func (s *redirectionInterceptorSuite) TestRedirectionAllowed_Empty() { ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{})) - allowed := s.redirector.redirectionAllowed(ctx) + allowed := s.redirector.RedirectionAllowed(ctx) s.True(allowed) } func (s *redirectionInterceptorSuite) TestRedirectionAllowed_Error() { ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ - dcRedirectionContextHeaderName: "?", + DCRedirectionContextHeaderName: "?", })) - allowed := s.redirector.redirectionAllowed(ctx) + allowed := s.redirector.RedirectionAllowed(ctx) s.True(allowed) } func (s *redirectionInterceptorSuite) TestRedirectionAllowed_True() { ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ - dcRedirectionContextHeaderName: "t", + DCRedirectionContextHeaderName: "t", })) - allowed := s.redirector.redirectionAllowed(ctx) + allowed := s.redirector.RedirectionAllowed(ctx) s.True(allowed) } func (s *redirectionInterceptorSuite) TestRedirectionAllowed_False() { ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ - dcRedirectionContextHeaderName: "f", + DCRedirectionContextHeaderName: "f", })) - allowed := s.redirector.redirectionAllowed(ctx) + allowed := s.redirector.RedirectionAllowed(ctx) s.False(allowed) } diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 083108a86ee..d5225dcd048 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -726,9 +726,12 @@ func RegisterNexusHTTPHandler( serviceName primitives.ServiceName, matchingClient resource.MatchingClient, metricsHandler metrics.Handler, + clusterMetadata cluster.Metadata, + clientCache *cluster.FrontendHTTPClientCache, namespaceRegistry namespace.Registry, endpointRegistry *nexus.EndpointRegistry, authInterceptor *authorization.Interceptor, + redirectionInterceptor *interceptor.Redirection, namespaceRateLimiterInterceptor *interceptor.NamespaceRateLimitInterceptor, namespaceCountLimiterInterceptor *interceptor.ConcurrentRequestLimitInterceptor, namespaceValidatorInterceptor *interceptor.NamespaceValidatorInterceptor, @@ -740,9 +743,12 @@ func RegisterNexusHTTPHandler( serviceConfig, matchingClient, metricsHandler, + clusterMetadata, + clientCache, namespaceRegistry, endpointRegistry, authInterceptor, + redirectionInterceptor, namespaceValidatorInterceptor, namespaceRateLimiterInterceptor, namespaceCountLimiterInterceptor, diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index a3ce6c07226..b915d641e45 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -26,7 +26,9 @@ import ( "context" "errors" "fmt" + "net/url" "runtime/debug" + "strconv" "time" "github.com/nexus-rpc/sdk-go/nexus" @@ -39,6 +41,8 @@ import ( "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/authorization" + "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -65,14 +69,17 @@ type nexusContext struct { // Context for a specific Nexus operation, includes a resolved namespace, and a bound metrics handler and logger. type operationContext struct { nexusContext - namespace *namespace.Namespace + clusterMetadata cluster.Metadata + namespace *namespace.Namespace // "Special" metrics handler that should only be passed to interceptors, which require a different set of // pre-baked tags than the "normal" metricsHandler. metricsHandlerForInterceptors metrics.Handler metricsHandler metrics.Handler logger log.Logger auth *authorization.Interceptor - cleanupFunctions []func() + redirectionInterceptor *interceptor.Redirection + forwardingEnabledForNamespace dynamicconfig.BoolPropertyFnWithNamespaceFilter + cleanupFunctions []func(error) } // Panic handler and metrics recording function. @@ -95,7 +102,7 @@ func (c *operationContext) capturePanicAndRecordMetrics(errPtr *error) { c.metricsHandler.Histogram(metrics.NexusLatencyHistogram.Name(), metrics.Milliseconds).Record(time.Since(c.requestStartTime).Milliseconds()) for _, fn := range c.cleanupFunctions { - fn() + fn(*errPtr) } } @@ -110,7 +117,7 @@ func (c *operationContext) matchingRequest(req *nexuspb.Request) *matchingservic func (c *operationContext) interceptRequest(ctx context.Context, request *matchingservice.DispatchNexusTaskRequest, header nexus.Header) error { err := c.auth.Authorize(ctx, c.claims, &authorization.CallTarget{ APIName: c.apiName, - Namespace: c.namespace.Name().String(), + Namespace: c.namespaceName, Request: request, }) if err != nil { @@ -122,10 +129,25 @@ func (c *operationContext) interceptRequest(ctx context.Context, request *matchi c.metricsHandler = c.metricsHandler.WithTags(metrics.NexusOutcomeTag("invalid_namespace_state")) return commonnexus.ConvertGRPCError(err, false) } - // TODO: Redirect if current cluster is passive for this namespace. + + if !c.namespace.ActiveInCluster(c.clusterMetadata.GetCurrentClusterName()) { + notActiveErr := serviceerror.NewNamespaceNotActive(c.namespaceName, c.clusterMetadata.GetCurrentClusterName(), c.namespace.ActiveClusterName()) + if c.shouldForwardRequest(ctx, header) { + // Handler methods should have special logic to forward requests if this method returns a serviceerror.NamespaceNotActive error. + c.metricsHandler = c.metricsHandler.WithTags(metrics.NexusOutcomeTag("request_forwarded")) + var forwardStartTime time.Time + c.metricsHandlerForInterceptors, forwardStartTime = c.redirectionInterceptor.BeforeCall(c.apiName) + c.cleanupFunctions = append(c.cleanupFunctions, func(retErr error) { + c.redirectionInterceptor.AfterCall(c.metricsHandlerForInterceptors, forwardStartTime, c.namespace.ActiveClusterName(), retErr) + }) + return notActiveErr + } + c.metricsHandler = c.metricsHandler.WithTags(metrics.NexusOutcomeTag("namespace_inactive_forwarding_disabled")) + return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") + } cleanup, err := c.namespaceConcurrencyLimitInterceptor.Allow(c.namespace.Name(), c.apiName, c.metricsHandlerForInterceptors, request) - c.cleanupFunctions = append(c.cleanupFunctions, cleanup) + c.cleanupFunctions = append(c.cleanupFunctions, func(error) { cleanup() }) if err != nil { c.metricsHandler = c.metricsHandler.WithTags(metrics.NexusOutcomeTag("namespace_concurrency_limited")) return commonnexus.ConvertGRPCError(err, false) @@ -145,6 +167,23 @@ func (c *operationContext) interceptRequest(ctx context.Context, request *matchi return nil } +// Combines logic from RedirectionInterceptor.redirectionAllowed and some from +// SelectedAPIsForwardingRedirectionPolicy.getTargetClusterAndIsNamespaceNotActiveAutoForwarding so all +// redirection conditions can be checked at once. If either of those methods are updated, this should +// be kept in sync. +func (c *operationContext) shouldForwardRequest(ctx context.Context, header nexus.Header) bool { + redirectHeader := header.Get(interceptor.DCRedirectionContextHeaderName) + redirectAllowed, err := strconv.ParseBool(redirectHeader) + if err != nil { + redirectAllowed = true + } + return redirectAllowed && + c.redirectionInterceptor.RedirectionAllowed(ctx) && + c.namespace.IsGlobalNamespace() && + len(c.namespace.ClusterNames()) > 1 && + c.forwardingEnabledForNamespace(c.namespaceName) +} + // Key to extract a nexusContext object from a context.Context. type nexusContextKey struct{} @@ -152,11 +191,15 @@ type nexusContextKey struct{} // Dispatches Nexus requests as Nexus tasks to workers via matching. type nexusHandler struct { nexus.UnimplementedHandler - logger log.Logger - metricsHandler metrics.Handler - namespaceRegistry namespace.Registry - matchingClient matchingservice.MatchingServiceClient - auth *authorization.Interceptor + logger log.Logger + metricsHandler metrics.Handler + clusterMetadata cluster.Metadata + namespaceRegistry namespace.Registry + matchingClient matchingservice.MatchingServiceClient + auth *authorization.Interceptor + redirectionInterceptor *interceptor.Redirection + forwardingEnabledForNamespace dynamicconfig.BoolPropertyFnWithNamespaceFilter + forwardingClients *cluster.FrontendHTTPClientCache } // Extracts a nexusContext from the given ctx and returns an operationContext with tagged metrics and logging. @@ -166,8 +209,14 @@ func (h *nexusHandler) getOperationContext(ctx context.Context, method string) ( if !ok { return nil, errors.New("no nexus context set on context") //nolint:goerr113 } - oc := operationContext{nexusContext: nc, auth: h.auth, cleanupFunctions: make([]func(), 0)} - + oc := operationContext{ + nexusContext: nc, + clusterMetadata: h.clusterMetadata, + auth: h.auth, + redirectionInterceptor: h.redirectionInterceptor, + forwardingEnabledForNamespace: h.forwardingEnabledForNamespace, + cleanupFunctions: make([]func(error), 0), + } oc.metricsHandlerForInterceptors = h.metricsHandler.WithTags( metrics.OperationTag(nc.apiName), metrics.NamespaceTag(nc.namespaceName), @@ -193,6 +242,7 @@ func (h *nexusHandler) getOperationContext(ctx context.Context, method string) ( } return nil, commonnexus.ConvertGRPCError(err, false) } + oc.forwardingEnabledForNamespace = h.forwardingEnabledForNamespace oc.logger = log.With(h.logger, tag.Operation(method), tag.WorkflowNamespace(nc.namespaceName)) return &oc, nil } @@ -219,7 +269,12 @@ func (h *nexusHandler) StartOperation(ctx context.Context, service, operation st StartOperation: &startOperationRequest, }, }) + if err := oc.interceptRequest(ctx, request, options.Header); err != nil { + var notActiveErr *serviceerror.NamespaceNotActive + if errors.As(err, ¬ActiveErr) { + return h.forwardStartOperation(ctx, service, operation, input, options, oc) + } return nil, err } @@ -273,6 +328,37 @@ func (h *nexusHandler) StartOperation(ctx context.Context, service, operation st return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeDownstreamError, "empty outcome") } +// forwardStartOperation forwards the StartOperation request to the active cluster using an HTTP request. +// Inputs and response values are passed as Reader objects to avoid reading bodies and bypass serialization. +func (h *nexusHandler) forwardStartOperation( + ctx context.Context, + service string, + operation string, + input *nexus.LazyValue, + options nexus.StartOperationOptions, + oc *operationContext, +) (nexus.HandlerStartOperationResult[any], error) { + options.Header[interceptor.DCRedirectionApiHeaderName] = "true" + + client, err := h.nexusClientForActiveCluster(oc, service) + if err != nil { + return nil, err + } + + resp, err := client.StartOperation(ctx, operation, input.Reader, options) + if err != nil { + oc.logger.Error("received error from remote cluster for forwarded Nexus start operation request.", tag.Error(err)) + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.NexusOutcomeTag("forwarded_request_error")) + return nil, commonnexus.HandlerErrorFromClientError(err) + } + + if resp.Successful != nil { + return &nexus.HandlerStartOperationResultSync[any]{Value: resp.Successful.Reader}, nil + } + // If Nexus client did not return an error, one of Successful or Pending will always be set. + return &nexus.HandlerStartOperationResultAsync{OperationID: resp.Pending.ID}, nil +} + func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, id string, options nexus.CancelOperationOptions) (retErr error) { oc, err := h.getOperationContext(ctx, "CancelOperation") if err != nil { @@ -292,6 +378,10 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, }, }) if err := oc.interceptRequest(ctx, request, options.Header); err != nil { + var notActiveErr *serviceerror.NamespaceNotActive + if errors.As(err, ¬ActiveErr) { + return h.forwardCancelOperation(ctx, service, operation, id, options, oc) + } return err } @@ -322,7 +412,65 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, return nexus.HandlerErrorf(nexus.HandlerErrorTypeDownstreamError, "empty outcome") } -// convertNexusHandlerError converts any 5xx user handler error to a downsream error. +func (h *nexusHandler) forwardCancelOperation( + ctx context.Context, + service string, + operation string, + id string, + options nexus.CancelOperationOptions, + oc *operationContext, +) error { + options.Header[interceptor.DCRedirectionApiHeaderName] = "true" + + client, err := h.nexusClientForActiveCluster(oc, service) + if err != nil { + return err + } + + handle, err := client.NewHandle(operation, id) + if err != nil { + oc.logger.Warn("invalid Nexus cancel operation.", tag.Error(err)) + return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation") + } + + err = handle.Cancel(ctx, options) + if err != nil { + oc.logger.Error("received error from remote cluster for forwarded Nexus cancel operation request.", tag.Error(err)) + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.NexusOutcomeTag("forwarded_request_error")) + return commonnexus.HandlerErrorFromClientError(err) + } + + return nil +} + +func (h *nexusHandler) nexusClientForActiveCluster(oc *operationContext, service string) (*nexus.Client, error) { + httpClient, err := h.forwardingClients.Get(oc.namespace.ActiveClusterName()) + if err != nil { + oc.logger.Error("failed to forward Nexus request. error creating HTTP client", tag.Error(err), tag.SourceCluster(oc.namespace.ActiveClusterName()), tag.TargetCluster(oc.namespace.ActiveClusterName())) + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.NexusOutcomeTag("request_forwarding_failed")) + return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") + } + + baseURL, err := url.JoinPath( + httpClient.Address, + commonnexus.RouteDispatchNexusTaskByNamespaceAndTaskQueue.Path(commonnexus.NamespaceAndTaskQueue{ + Namespace: oc.namespaceName, + TaskQueue: oc.taskQueue, + })) + if err != nil { + oc.logger.Error(fmt.Sprintf("failed to forward Nexus request. error constructing ServiceBaseURL. address=%s namespace=%s task_queue=%s", httpClient.Address, oc.namespaceName, oc.taskQueue), tag.Error(err)) + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.NexusOutcomeTag("request_forwarding_failed")) + return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") + } + + return nexus.NewClient(nexus.ClientOptions{ + HTTPCaller: httpClient.Do, + BaseURL: baseURL, + Service: service, + }) +} + +// convertNexusHandlerError converts any 5xx user handler error to a downstream error. func convertNexusHandlerError(t nexus.HandlerErrorType) nexus.HandlerErrorType { switch t { case nexus.HandlerErrorTypeDownstreamTimeout, diff --git a/service/frontend/nexus_handler_test.go b/service/frontend/nexus_handler_test.go index a2d63e57a8d..8f9ebe6a1a6 100644 --- a/service/frontend/nexus_handler_test.go +++ b/service/frontend/nexus_handler_test.go @@ -32,13 +32,20 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/require" enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/serviceerror" + "go.temporal.io/server/api/matchingservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/authorization" + "go.temporal.io/server/common/clock" + "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/cluster/clustertest" + "go.temporal.io/server/common/config" + "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics/metricstest" "go.temporal.io/server/common/namespace" - "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" "go.temporal.io/server/common/rpc/interceptor" ) @@ -84,9 +91,11 @@ func (n mockNamespaceChecker) Exists(name namespace.Name) error { type contextOptions struct { namespaceState enumspb.NamespaceState + namespacePassive bool quota int namespaceRateLimitAllow bool rateLimitAllow bool + redirectAllow bool } func newOperationContext(options contextOptions) *operationContext { @@ -96,18 +105,31 @@ func newOperationContext(options contextOptions) *operationContext { oc.metricsHandlerForInterceptors = mh oc.metricsHandler = mh oc.apiName = "/temporal.api.nexusservice.v1.NexusService/DispatchNexusTask" - oc.namespace = namespace.FromPersistentState(&persistence.GetNamespaceResponse{ - Namespace: &persistencespb.NamespaceDetail{ - Info: &persistencespb.NamespaceInfo{ - Id: uuid.NewString(), - Name: "test", - State: options.namespaceState, - }, - Config: &persistencespb.NamespaceConfig{ - CustomSearchAttributeAliases: make(map[string]string), + + oc.namespaceName = "test-namespace" + activeClusterName := cluster.TestCurrentClusterName + if options.namespacePassive { + activeClusterName = cluster.TestAlternativeClusterName + } + oc.namespace = namespace.NewGlobalNamespaceForTest( + &persistencespb.NamespaceInfo{ + Id: uuid.NewString(), + Name: oc.namespaceName, + State: options.namespaceState, + }, + &persistencespb.NamespaceConfig{ + Retention: timestamp.DurationFromDays(1), + CustomSearchAttributeAliases: make(map[string]string), + }, + &persistencespb.NamespaceReplicationConfig{ + ActiveClusterName: activeClusterName, + Clusters: []string{ + cluster.TestCurrentClusterName, + cluster.TestAlternativeClusterName, }, }, - }) + 1, + ) checker := mockNamespaceChecker(oc.namespace.Name()) oc.auth = authorization.NewInterceptor(nil, mockAuthorizer{}, oc.metricsHandler, oc.logger, checker, nil, "", "") @@ -123,10 +145,15 @@ func newOperationContext(options contextOptions) *operationContext { ) oc.namespaceRateLimitInterceptor = interceptor.NewNamespaceRateLimitInterceptor(nil, mockRateLimiter{options.namespaceRateLimitAllow}, make(map[string]int)) oc.rateLimitInterceptor = interceptor.NewRateLimitInterceptor(mockRateLimiter{options.rateLimitAllow}, make(map[string]int)) + + oc.clusterMetadata = clustertest.NewMetadataForTest(cluster.NewTestClusterMetadataConfig(true, !options.namespacePassive)) + oc.forwardingEnabledForNamespace = dynamicconfig.GetBoolPropertyFnFilteredByNamespace(options.redirectAllow) + oc.redirectionInterceptor = interceptor.NewRedirection(nil, nil, config.DCRedirectionPolicy{Policy: interceptor.DCRedirectionPolicyAllAPIsForwarding}, oc.logger, nil, oc.metricsHandlerForInterceptors, clock.NewRealTimeSource(), oc.clusterMetadata) + return oc } -func TestNexusInterceptRequeset_InvalidNamespaceState_ResultsInBadRequest(t *testing.T) { +func TestNexusInterceptRequest_InvalidNamespaceState_ResultsInBadRequest(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var err error @@ -150,7 +177,7 @@ func TestNexusInterceptRequeset_InvalidNamespaceState_ResultsInBadRequest(t *tes require.Equal(t, map[string]string{"outcome": "invalid_namespace_state"}, snap["test"][0].Tags) } -func TestNexusInterceptRequeset_NamespaceConcurrencyLimited_ResultsInResourceExhausted(t *testing.T) { +func TestNexusInterceptRequest_NamespaceConcurrencyLimited_ResultsInResourceExhausted(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var err error @@ -174,7 +201,7 @@ func TestNexusInterceptRequeset_NamespaceConcurrencyLimited_ResultsInResourceExh require.Equal(t, map[string]string{"outcome": "namespace_concurrency_limited"}, snap["test"][0].Tags) } -func TestNexusInterceptRequeset_NamespaceRateLimited_ResultsInResourceExhausted(t *testing.T) { +func TestNexusInterceptRequest_NamespaceRateLimited_ResultsInResourceExhausted(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var err error @@ -198,7 +225,7 @@ func TestNexusInterceptRequeset_NamespaceRateLimited_ResultsInResourceExhausted( require.Equal(t, map[string]string{"outcome": "namespace_rate_limited"}, snap["test"][0].Tags) } -func TestNexusInterceptRequeset_GlobalRateLimited_ResultsInResourceExhausted(t *testing.T) { +func TestNexusInterceptRequest_GlobalRateLimited_ResultsInResourceExhausted(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var err error @@ -221,3 +248,52 @@ func TestNexusInterceptRequeset_GlobalRateLimited_ResultsInResourceExhausted(t * require.Equal(t, 1, len(snap["test"])) require.Equal(t, map[string]string{"outcome": "global_rate_limited"}, snap["test"][0].Tags) } + +func TestNexusInterceptRequest_ForwardingDisabled_ResultsInUnavailable(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + var err error + oc := newOperationContext(contextOptions{ + namespaceState: enumspb.NAMESPACE_STATE_REGISTERED, + namespacePassive: true, + quota: 1, + namespaceRateLimitAllow: true, + rateLimitAllow: true, + redirectAllow: false, + }) + err = oc.interceptRequest(ctx, &matchingservice.DispatchNexusTaskRequest{}, nexus.Header{}) + var handlerError *nexus.HandlerError + require.ErrorAs(t, err, &handlerError) + require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerError.Type) + mh := oc.metricsHandler.(*metricstest.CaptureHandler) //nolint:revive + capture := mh.StartCapture() + oc.metricsHandler.Counter("test").Record(1) + mh.StopCapture(capture) + snap := capture.Snapshot() + require.Equal(t, 1, len(snap["test"])) + require.Equal(t, map[string]string{"outcome": "namespace_inactive_forwarding_disabled"}, snap["test"][0].Tags) +} + +func TestNexusInterceptRequest_ForwardingEnabled_ResultsInNotActiveError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + var err error + oc := newOperationContext(contextOptions{ + namespaceState: enumspb.NAMESPACE_STATE_REGISTERED, + namespacePassive: true, + quota: 1, + namespaceRateLimitAllow: true, + rateLimitAllow: true, + redirectAllow: true, + }) + err = oc.interceptRequest(ctx, &matchingservice.DispatchNexusTaskRequest{}, nexus.Header{}) + var notActiveErr *serviceerror.NamespaceNotActive + require.ErrorAs(t, err, ¬ActiveErr) + mh := oc.metricsHandler.(*metricstest.CaptureHandler) //nolint:revive + capture := mh.StartCapture() + oc.metricsHandler.Counter("test").Record(1) + mh.StopCapture(capture) + snap := capture.Snapshot() + require.Equal(t, 1, len(snap["test"])) + require.Equal(t, map[string]string{"outcome": "request_forwarded"}, snap["test"][0].Tags) +} diff --git a/service/frontend/nexus_http_handler.go b/service/frontend/nexus_http_handler.go index 4d7c194b21a..c03fb91af6f 100644 --- a/service/frontend/nexus_http_handler.go +++ b/service/frontend/nexus_http_handler.go @@ -40,6 +40,7 @@ import ( "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/common/authorization" + "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -69,9 +70,12 @@ func NewNexusHTTPHandler( serviceConfig *Config, matchingClient matchingservice.MatchingServiceClient, metricsHandler metrics.Handler, + clusterMetadata cluster.Metadata, + clientCache *cluster.FrontendHTTPClientCache, namespaceRegistry namespace.Registry, endpointRegistry *commonnexus.EndpointRegistry, authInterceptor *authorization.Interceptor, + redirectionInterceptor *interceptor.Redirection, namespaceValidationInterceptor *interceptor.NamespaceValidatorInterceptor, namespaceRateLimitInterceptor *interceptor.NamespaceRateLimitInterceptor, namespaceConcurrencyLimitIntercptor *interceptor.ConcurrentRequestLimitInterceptor, @@ -90,11 +94,15 @@ func NewNexusHTTPHandler( preprocessErrorCounter: metricsHandler.Counter(metrics.NexusRequestPreProcessErrors.Name()).Record, nexusHandler: nexus.NewHTTPHandler(nexus.HandlerOptions{ Handler: &nexusHandler{ - logger: logger, - metricsHandler: metricsHandler, - namespaceRegistry: namespaceRegistry, - matchingClient: matchingClient, - auth: authInterceptor, + logger: logger, + metricsHandler: metricsHandler, + clusterMetadata: clusterMetadata, + namespaceRegistry: namespaceRegistry, + matchingClient: matchingClient, + auth: authInterceptor, + redirectionInterceptor: redirectionInterceptor, + forwardingEnabledForNamespace: serviceConfig.EnableNamespaceNotActiveAutoForwarding, + forwardingClients: clientCache, }, GetResultTimeout: serviceConfig.KeepAliveMaxConnectionIdle(), Logger: log.NewSlogLogger(logger), diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index 784d2f31fea..d836bff6819 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -152,8 +152,7 @@ func (s *ClientFunctionalSuite) TestNexusStartOperation_Outcomes() { assertion: func(t *testing.T, res *nexus.ClientStartOperationResult[string], err error) { var unexpectedError *nexus.UnexpectedResponseError require.ErrorAs(t, err, &unexpectedError) - // TODO: nexus should export this - require.Equal(t, 520, unexpectedError.Response.StatusCode) + require.Equal(t, nexus.StatusDownstreamError, unexpectedError.Response.StatusCode) require.Equal(t, "deliberate internal failure", unexpectedError.Failure.Message) }, }, @@ -172,8 +171,7 @@ func (s *ClientFunctionalSuite) TestNexusStartOperation_Outcomes() { assertion: func(t *testing.T, res *nexus.ClientStartOperationResult[string], err error) { var unexpectedError *nexus.UnexpectedResponseError require.ErrorAs(t, err, &unexpectedError) - // TODO: nexus should export this - require.Equal(t, 521, unexpectedError.Response.StatusCode) + require.Equal(t, nexus.StatusDownstreamTimeout, unexpectedError.Response.StatusCode) require.Equal(t, "downstream timeout", unexpectedError.Failure.Message) }, }, @@ -532,8 +530,7 @@ func (s *ClientFunctionalSuite) TestNexusCancelOperation_Outcomes() { assertion: func(t *testing.T, err error) { var unexpectedError *nexus.UnexpectedResponseError require.ErrorAs(t, err, &unexpectedError) - // TODO: nexus should export this - require.Equal(t, 520, unexpectedError.Response.StatusCode) + require.Equal(t, nexus.StatusDownstreamError, unexpectedError.Response.StatusCode) require.Equal(t, "deliberate internal failure", unexpectedError.Failure.Message) }, }, @@ -552,8 +549,7 @@ func (s *ClientFunctionalSuite) TestNexusCancelOperation_Outcomes() { assertion: func(t *testing.T, err error) { var unexpectedError *nexus.UnexpectedResponseError require.ErrorAs(t, err, &unexpectedError) - // TODO: nexus should export this - require.Equal(t, 521, unexpectedError.Response.StatusCode) + require.Equal(t, nexus.StatusDownstreamTimeout, unexpectedError.Response.StatusCode) require.Equal(t, "downstream timeout", unexpectedError.Failure.Message) }, }, diff --git a/tests/onebox.go b/tests/onebox.go index f7bbf8ea09c..76f9c5d9a2d 100644 --- a/tests/onebox.go +++ b/tests/onebox.go @@ -400,6 +400,10 @@ func (c *temporalImpl) GetMatchingClient() matchingservice.MatchingServiceClient return c.matchingClient } +func (c *temporalImpl) GetFrontendNamespaceRegistry() namespace.Registry { + return c.frontendNamespaceRegistry +} + func (c *temporalImpl) startFrontend( hostsByService map[primitives.ServiceName]static.Hosts, startWG *sync.WaitGroup, diff --git a/tests/xdc/base.go b/tests/xdc/base.go index 5c895ee87b6..540bace0f66 100644 --- a/tests/xdc/base.go +++ b/tests/xdc/base.go @@ -38,9 +38,10 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" replicationpb "go.temporal.io/api/replication/v1" - "go.temporal.io/server/common/testing/historyrequire" "gopkg.in/yaml.v3" + "go.temporal.io/server/common/testing/historyrequire" + "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/common/cluster" @@ -124,6 +125,7 @@ func (s *xdcBaseSuite) setupSuite(clusterNames []string, opts ...tests.Option) { HTTPAddress: fmt.Sprintf("http://127.0.0.1:%d144", 7+i), } clusterConfigs[i].ServiceFxOptions = params.ServiceOptions + clusterConfigs[i].EnableMetricsCapture = true } c, err := s.testClusterFactory.NewCluster(s.T(), clusterConfigs[0], log.With(s.logger, tag.ClusterName(s.clusterNames[0]))) diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go new file mode 100644 index 00000000000..5744fa3a086 --- /dev/null +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -0,0 +1,384 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package xdc + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "net/http" + "testing" + "time" + + "github.com/google/uuid" + "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + enumspb "go.temporal.io/api/enums/v1" + nexuspb "go.temporal.io/api/nexus/v1" + taskqueuepb "go.temporal.io/api/taskqueue/v1" + "go.temporal.io/api/workflowservice/v1" + "google.golang.org/protobuf/types/known/durationpb" + + "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/metrics/metricstest" + "go.temporal.io/server/common/namespace" + cnexus "go.temporal.io/server/common/nexus" + "go.temporal.io/server/tests" +) + +var op = nexus.NewOperationReference[string, string]("my-operation") + +type NexusRequestForwardingSuite struct { + xdcBaseSuite +} + +func TestNexusRequestForwardingTestSuite(t *testing.T) { + flag.Parse() + suite.Run(t, new(NexusRequestForwardingSuite)) +} + +func (s *NexusRequestForwardingSuite) SetupSuite() { + s.dynamicConfigOverrides = map[dynamicconfig.Key]any{ + // Make sure we don't hit the rate limiter in tests + dynamicconfig.FrontendGlobalNamespaceNamespaceReplicationInducingAPIsRPS.Key(): 1000, + dynamicconfig.EnableNexus.Key(): true, + dynamicconfig.RefreshNexusEndpointsMinWait.Key(): 1 * time.Millisecond, + } + s.setupSuite([]string{"nexus_request_forwarding_active", "nexus_request_forwarding_standby"}) +} + +func (s *NexusRequestForwardingSuite) SetupTest() { + s.setupTest() +} + +func (s *NexusRequestForwardingSuite) TearDownSuite() { + s.tearDownSuite() +} + +// Only tests dispatch by namespace+task_queue. +// TODO: Add test cases for dispatch by endpoint ID once endpoints support replication. +func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToActive() { + ns := s.createNexusRequestForwardingNamespace() + + testCases := []struct { + name string + taskQueue string + header nexus.Header + handler func(*workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) + assertion func(*testing.T, *nexus.ClientStartOperationResult[string], error, map[string][]*metricstest.CapturedRecording, map[string][]*metricstest.CapturedRecording) + }{ + { + name: "success", + taskQueue: fmt.Sprintf("%v-%v", "test-task-queue", uuid.New()), + handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) { + s.Equal("true", res.Request.Header["xdc-redirection-api"]) + return &nexuspb.Response{ + Variant: &nexuspb.Response_StartOperation{ + StartOperation: &nexuspb.StartOperationResponse{ + Variant: &nexuspb.StartOperationResponse_SyncSuccess{ + SyncSuccess: &nexuspb.StartOperationResponse_Sync{ + Payload: res.Request.GetStartOperation().GetPayload()}}}}, + }, nil + }, + assertion: func(t *testing.T, result *nexus.ClientStartOperationResult[string], retErr error, activeSnap map[string][]*metricstest.CapturedRecording, passiveSnap map[string][]*metricstest.CapturedRecording) { + require.NoError(t, retErr) + require.Equal(t, "input", result.Successful) + requireExpectedMetricsCaptured(t, activeSnap, ns, "StartOperation", "sync_success") + requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartOperation", "request_forwarded") + }, + }, + { + name: "operation error", + taskQueue: fmt.Sprintf("%v-%v", "test-task-queue", uuid.New()), + handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) { + s.Equal("true", res.Request.Header["xdc-redirection-api"]) + return &nexuspb.Response{ + Variant: &nexuspb.Response_StartOperation{ + StartOperation: &nexuspb.StartOperationResponse{ + Variant: &nexuspb.StartOperationResponse_OperationError{ + OperationError: &nexuspb.UnsuccessfulOperationError{ + OperationState: string(nexus.OperationStateFailed), + Failure: &nexuspb.Failure{ + Message: "deliberate test failure", + Metadata: map[string]string{"k": "v"}, + Details: []byte(`"details"`), + }}}}}, + }, nil + }, + assertion: func(t *testing.T, result *nexus.ClientStartOperationResult[string], retErr error, activeSnap map[string][]*metricstest.CapturedRecording, passiveSnap map[string][]*metricstest.CapturedRecording) { + var operationError *nexus.UnsuccessfulOperationError + require.ErrorAs(t, retErr, &operationError) + require.Equal(t, nexus.OperationStateFailed, operationError.State) + require.Equal(t, "deliberate test failure", operationError.Failure.Message) + require.Equal(t, map[string]string{"k": "v"}, operationError.Failure.Metadata) + var details string + err := json.Unmarshal(operationError.Failure.Details, &details) + require.NoError(t, err) + require.Equal(t, "details", details) + requireExpectedMetricsCaptured(t, activeSnap, ns, "StartOperation", "operation_error") + requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartOperation", "forwarded_request_error") + }, + }, + { + name: "handler error", + taskQueue: fmt.Sprintf("%v-%v", "test-task-queue", uuid.New()), + handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) { + s.Equal("true", res.Request.Header["xdc-redirection-api"]) + return nil, &nexuspb.HandlerError{ + ErrorType: string(nexus.HandlerErrorTypeInternal), + Failure: &nexuspb.Failure{Message: "deliberate internal failure"}, + } + }, + assertion: func(t *testing.T, result *nexus.ClientStartOperationResult[string], retErr error, activeSnap map[string][]*metricstest.CapturedRecording, passiveSnap map[string][]*metricstest.CapturedRecording) { + var unexpectedError *nexus.UnexpectedResponseError + require.ErrorAs(t, retErr, &unexpectedError) + require.Equal(t, nexus.StatusDownstreamError, unexpectedError.Response.StatusCode) + require.Equal(t, "deliberate internal failure", unexpectedError.Failure.Message) + requireExpectedMetricsCaptured(t, activeSnap, ns, "StartOperation", "handler_error") + requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartOperation", "forwarded_request_error") + }, + }, + { + name: "redirect disabled by header", + taskQueue: fmt.Sprintf("%v-%v", "test-task-queue", uuid.New()), + header: nexus.Header{"xdc-redirection": "false"}, + handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) { + s.FailNow("nexus task handler invoked when redirection should be disabled") + return nil, &nexuspb.HandlerError{ + ErrorType: string(nexus.HandlerErrorTypeInternal), + Failure: &nexuspb.Failure{Message: "redirection not allowed"}, + } + }, + assertion: func(t *testing.T, result *nexus.ClientStartOperationResult[string], retErr error, activeSnap map[string][]*metricstest.CapturedRecording, passiveSnap map[string][]*metricstest.CapturedRecording) { + var unexpectedError *nexus.UnexpectedResponseError + require.ErrorAs(t, retErr, &unexpectedError) + require.Equal(t, http.StatusServiceUnavailable, unexpectedError.Response.StatusCode) + require.Equal(t, "cluster inactive", unexpectedError.Failure.Message) + requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartOperation", "namespace_inactive_forwarding_disabled") + }, + }, + } + + for _, tc := range testCases { + tc := tc + s.T().Run(tc.name, func(t *testing.T) { + dispatchURL := fmt.Sprintf("http://%s/%s", s.cluster2.GetHost().FrontendHTTPAddress(), cnexus.RouteDispatchNexusTaskByNamespaceAndTaskQueue.Path(cnexus.NamespaceAndTaskQueue{Namespace: ns, TaskQueue: tc.taskQueue})) + nexusClient, err := nexus.NewClient(nexus.ClientOptions{BaseURL: dispatchURL, Service: "test-service"}) + s.NoError(err) + + activeMetricsHandler, ok := s.cluster1.GetHost().GetMetricsHandler().(*metricstest.CaptureHandler) + s.True(ok) + activeCapture := activeMetricsHandler.StartCapture() + defer activeMetricsHandler.StopCapture(activeCapture) + + passiveMetricsHandler, ok := s.cluster2.GetHost().GetMetricsHandler().(*metricstest.CaptureHandler) + s.True(ok) + passiveCapture := passiveMetricsHandler.StartCapture() + defer passiveMetricsHandler.StopCapture(passiveCapture) + + ctx, cancel := context.WithCancel(tests.NewContext()) + defer cancel() + + go s.nexusTaskPoller(ctx, s.cluster1.GetFrontendClient(), ns, tc.taskQueue, tc.handler) + + startResult, err := nexus.StartOperation(ctx, nexusClient, op, "input", nexus.StartOperationOptions{ + CallbackURL: "http://localhost/callback", + RequestID: "request-id", + Header: tc.header, + }) + tc.assertion(t, startResult, err, activeCapture.Snapshot(), passiveCapture.Snapshot()) + }) + } +} + +// Only tests dispatch by namespace+task_queue. +// TODO: Add test cases for dispatch by endpoint ID once endpoints support replication. +func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToActive() { + ns := s.createNexusRequestForwardingNamespace() + + testCases := []struct { + name string + taskQueue string + header nexus.Header + handler func(*workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) + assertion func(*testing.T, error, map[string][]*metricstest.CapturedRecording, map[string][]*metricstest.CapturedRecording) + }{ + { + name: "success", + taskQueue: fmt.Sprintf("%v-%v", "test-task-queue", uuid.New()), + handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) { + s.Equal("true", res.Request.Header["xdc-redirection-api"]) + return &nexuspb.Response{ + Variant: &nexuspb.Response_CancelOperation{ + CancelOperation: &nexuspb.CancelOperationResponse{}, + }, + }, nil + }, + assertion: func(t *testing.T, retErr error, activeSnap map[string][]*metricstest.CapturedRecording, passiveSnap map[string][]*metricstest.CapturedRecording) { + require.NoError(t, retErr) + requireExpectedMetricsCaptured(t, activeSnap, ns, "CancelOperation", "success") + requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelOperation", "request_forwarded") + }, + }, + { + name: "handler error", + taskQueue: fmt.Sprintf("%v-%v", "test-task-queue", uuid.New()), + handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) { + s.Equal("true", res.Request.Header["xdc-redirection-api"]) + return nil, &nexuspb.HandlerError{ + ErrorType: string(nexus.HandlerErrorTypeInternal), + Failure: &nexuspb.Failure{Message: "deliberate internal failure"}, + } + }, + assertion: func(t *testing.T, retErr error, activeSnap map[string][]*metricstest.CapturedRecording, passiveSnap map[string][]*metricstest.CapturedRecording) { + var unexpectedError *nexus.UnexpectedResponseError + require.ErrorAs(t, retErr, &unexpectedError) + require.Equal(t, nexus.StatusDownstreamError, unexpectedError.Response.StatusCode) + require.Equal(t, "deliberate internal failure", unexpectedError.Failure.Message) + requireExpectedMetricsCaptured(t, activeSnap, ns, "CancelOperation", "handler_error") + requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelOperation", "forwarded_request_error") + }, + }, + { + name: "redirect disabled by header", + taskQueue: fmt.Sprintf("%v-%v", "test-task-queue", uuid.New()), + header: nexus.Header{"xdc-redirection": "false"}, + handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError) { + s.FailNow("nexus task handler invoked when redirection should be disabled") + return nil, &nexuspb.HandlerError{ + ErrorType: string(nexus.HandlerErrorTypeInternal), + Failure: &nexuspb.Failure{Message: "redirection should be disabled"}, + } + }, + assertion: func(t *testing.T, retErr error, activeSnap map[string][]*metricstest.CapturedRecording, passiveSnap map[string][]*metricstest.CapturedRecording) { + var unexpectedError *nexus.UnexpectedResponseError + require.ErrorAs(t, retErr, &unexpectedError) + require.Equal(t, http.StatusServiceUnavailable, unexpectedError.Response.StatusCode) + require.Equal(t, "cluster inactive", unexpectedError.Failure.Message) + requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelOperation", "namespace_inactive_forwarding_disabled") + }, + }, + } + + for _, tc := range testCases { + tc := tc + s.T().Run(tc.name, func(t *testing.T) { + dispatchURL := fmt.Sprintf("http://%s/%s", s.cluster2.GetHost().FrontendHTTPAddress(), cnexus.RouteDispatchNexusTaskByNamespaceAndTaskQueue.Path(cnexus.NamespaceAndTaskQueue{Namespace: ns, TaskQueue: tc.taskQueue})) + nexusClient, err := nexus.NewClient(nexus.ClientOptions{BaseURL: dispatchURL, Service: "test-service"}) + s.NoError(err) + + activeMetricsHandler, ok := s.cluster1.GetHost().GetMetricsHandler().(*metricstest.CaptureHandler) + s.True(ok) + activeCapture := activeMetricsHandler.StartCapture() + defer activeMetricsHandler.StopCapture(activeCapture) + + passiveMetricsHandler, ok := s.cluster2.GetHost().GetMetricsHandler().(*metricstest.CaptureHandler) + s.True(ok) + passiveCapture := passiveMetricsHandler.StartCapture() + defer passiveMetricsHandler.StopCapture(passiveCapture) + + ctx, cancel := context.WithCancel(tests.NewContext()) + defer cancel() + + go s.nexusTaskPoller(ctx, s.cluster1.GetFrontendClient(), ns, tc.taskQueue, tc.handler) + + handle, err := nexusClient.NewHandle("operation", "id") + require.NoError(t, err) + err = handle.Cancel(ctx, nexus.CancelOperationOptions{Header: tc.header}) + tc.assertion(t, err, activeCapture.Snapshot(), passiveCapture.Snapshot()) + }) + } +} + +func (s *NexusRequestForwardingSuite) nexusTaskPoller(ctx context.Context, frontendClient tests.FrontendClient, ns string, taskQueue string, handler func(*workflowservice.PollNexusTaskQueueResponse) (*nexuspb.Response, *nexuspb.HandlerError)) { + res, err := frontendClient.PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{ + Namespace: ns, + Identity: uuid.NewString(), + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + Kind: enumspb.TASK_QUEUE_KIND_NORMAL, + }, + }) + if ctx.Err() != nil { + // Test doesn't expect poll to get unblocked. + return + } + s.NoError(err) + + response, handlerErr := handler(res) + + if handlerErr != nil { + _, err = frontendClient.RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ + Namespace: ns, + Identity: uuid.NewString(), + TaskToken: res.TaskToken, + Error: handlerErr, + }) + } else if response != nil { + _, err = frontendClient.RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: ns, + Identity: uuid.NewString(), + TaskToken: res.TaskToken, + Response: response, + }) + } + + s.NoError(err) +} + +func (s *NexusRequestForwardingSuite) createNexusRequestForwardingNamespace() string { + ctx := tests.NewContext() + ns := fmt.Sprintf("%v-%v", "test-namespace", uuid.New()) + + regReq := &workflowservice.RegisterNamespaceRequest{ + Namespace: ns, + IsGlobalNamespace: true, + Clusters: s.clusterReplicationConfig(), + ActiveClusterName: s.clusterNames[0], + WorkflowExecutionRetentionPeriod: durationpb.New(7 * time.Hour * 24), + } + _, err := s.cluster1.GetFrontendClient().RegisterNamespace(ctx, regReq) + s.NoError(err) + + s.EventuallyWithT(func(t *assert.CollectT) { + // Wait for namespace record to be replicated and loaded into memory. + _, err := s.cluster2.GetHost().GetFrontendNamespaceRegistry().GetNamespace(namespace.Name(ns)) + assert.NoError(t, err) + }, 15*time.Second, 500*time.Millisecond) + + return ns +} + +func requireExpectedMetricsCaptured(t *testing.T, snap map[string][]*metricstest.CapturedRecording, ns string, method string, expectedOutcome string) { + require.Equal(t, 1, len(snap["nexus_requests"])) + require.Subset(t, snap["nexus_requests"][0].Tags, map[string]string{"namespace": ns, "method": method, "outcome": expectedOutcome}) + require.Equal(t, int64(1), snap["nexus_requests"][0].Value) + require.Equal(t, metrics.MetricUnit(""), snap["nexus_requests"][0].Unit) + require.Equal(t, 1, len(snap["nexus_latency"])) + require.Subset(t, snap["nexus_latency"][0].Tags, map[string]string{"namespace": ns, "method": method, "outcome": expectedOutcome}) + require.Equal(t, metrics.MetricUnit(metrics.Milliseconds), snap["nexus_latency"][0].Unit) +}