From 0fa7f9278ce986bdb3d361d667f501dac1e83e6f Mon Sep 17 00:00:00 2001 From: Tim Vernum Date: Fri, 12 Jan 2024 15:48:49 +1100 Subject: [PATCH] Convert RestWrapper to an explicit Interceptor Adds a new `RestInterceptor` interface and converts `RestServerActionPlugin.getRestHandlerInterceptor` to return this new type instead of a wrapping function. This has the following benefits: - Less object creation, there is 1 instance of the interceptor class (see `SecurityRestFilter`) rather than an instance per handler - More control over the sequence of steps in processing a request. The explicit interceptor separates it from the deprecation handler or any validation that might be needed, and the controller can be intentional about the order in which these operations are applied. --- .../elasticsearch/test/CustomRestPlugin.java | 27 +++-- .../elasticsearch/action/ActionModule.java | 4 +- .../interceptor/RestServerActionPlugin.java | 6 +- .../elasticsearch/rest/RestController.java | 57 +++++++-- .../elasticsearch/rest/RestInterceptor.java | 27 +++++ .../action/ActionModuleTests.java | 10 +- .../rest/RestControllerTests.java | 17 +-- .../core/LocalStateCompositeXPackPlugin.java | 6 +- .../xpack/security/Security.java | 6 +- .../security/rest/SecurityRestFilter.java | 54 ++++----- .../rest/SecurityRestFilterTests.java | 108 ++++++++---------- 11 files changed, 179 insertions(+), 143 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/rest/RestInterceptor.java diff --git a/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java b/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java index 37c79fe2abb0b..4fbdfa65d40ba 100644 --- a/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java +++ b/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java @@ -8,6 +8,7 @@ package co.elastic.elasticsearch.test; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -18,12 +19,11 @@ import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.usage.UsageService; -import java.util.function.UnaryOperator; - public class CustomRestPlugin extends Plugin implements RestServerActionPlugin { private static final Logger logger = LogManager.getLogger(CustomRestPlugin.class); @@ -35,34 +35,33 @@ private static void echoHeader(String name, RestRequest request, ThreadContext t } } - public static class CustomInterceptor implements RestHandler { + public static class CustomInterceptor implements RestInterceptor { private final ThreadContext threadContext; - private final RestHandler delegate; - public CustomInterceptor(ThreadContext threadContext, RestHandler delegate) { + public CustomInterceptor(ThreadContext threadContext) { this.threadContext = threadContext; - this.delegate = delegate; } @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + public void intercept(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) + throws Exception { logger.info("intercept request {} {}", request.method(), request.uri()); echoHeader("x-test-interceptor", request, threadContext); - delegate.handleRequest(request, channel, client); + listener.onResponse(Boolean.TRUE); } } public static class CustomController extends RestController { public CustomController( - UnaryOperator handlerWrapper, + RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, Tracer tracer ) { - super(handlerWrapper, client, circuitBreakerService, usageService, tracer); + super(interceptor, client, circuitBreakerService, usageService, tracer); } @Override @@ -74,19 +73,19 @@ public void dispatchRequest(RestRequest request, RestChannel channel, ThreadCont } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { - return handler -> new CustomInterceptor(threadContext, handler); + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { + return new CustomInterceptor(threadContext); } @Override public RestController getRestController( - UnaryOperator handlerWrapper, + RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, Tracer tracer ) { - return new CustomController(handlerWrapper, client, circuitBreakerService, usageService, tracer); + return new CustomController(interceptor, client, circuitBreakerService, usageService, tracer); } } diff --git a/server/src/main/java/org/elasticsearch/action/ActionModule.java b/server/src/main/java/org/elasticsearch/action/ActionModule.java index dd70dc65b853b..3e6ec1bde7448 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionModule.java +++ b/server/src/main/java/org/elasticsearch/action/ActionModule.java @@ -264,6 +264,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHeaderDefinition; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestUtils; import org.elasticsearch.rest.action.RestFieldCapabilitiesAction; import org.elasticsearch.rest.action.admin.cluster.RestAddVotingConfigExclusionAction; @@ -425,7 +426,6 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; -import java.util.function.UnaryOperator; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -501,7 +501,7 @@ public ActionModule( new RestHeaderDefinition(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, false) ) ).collect(Collectors.toSet()); - UnaryOperator restInterceptor = getRestServerComponent( + final RestInterceptor restInterceptor = getRestServerComponent( "REST interceptor", actionPlugins, restPlugin -> restPlugin.getRestHandlerInterceptor(threadPool.getThreadContext()) diff --git a/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java b/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java index 35badffe0b3aa..44653dcf8b5fe 100644 --- a/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java @@ -14,7 +14,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.rest.RestController; -import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.usage.UsageService; @@ -46,7 +46,7 @@ public interface RestServerActionPlugin extends ActionPlugin { * * Note: Only one installed plugin may implement a rest interceptor. */ - UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext); + RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext); /** * Returns a replacement {@link RestController} to be used in the server. @@ -54,7 +54,7 @@ public interface RestServerActionPlugin extends ActionPlugin { */ @Nullable default RestController getRestController( - @Nullable UnaryOperator handlerWrapper, + @Nullable RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index 6a5d6f99df64b..56a70e3bb2ab4 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -13,6 +13,8 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; @@ -57,7 +59,6 @@ import java.util.TreeMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; -import java.util.function.UnaryOperator; import static org.elasticsearch.indices.SystemIndices.EXTERNAL_SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY; import static org.elasticsearch.indices.SystemIndices.SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY; @@ -95,7 +96,7 @@ public class RestController implements HttpServerTransport.Dispatcher { private final PathTrie handlers = new PathTrie<>(RestUtils.REST_DECODER); - private final UnaryOperator handlerWrapper; + private final RestInterceptor interceptor; private final NodeClient client; @@ -107,7 +108,7 @@ public class RestController implements HttpServerTransport.Dispatcher { private final ServerlessApiProtections apiProtections; public RestController( - UnaryOperator handlerWrapper, + RestInterceptor restInterceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, @@ -115,10 +116,10 @@ public RestController( ) { this.usageService = usageService; this.tracer = tracer; - if (handlerWrapper == null) { - handlerWrapper = h -> h; // passthrough if no wrapper set + if (restInterceptor == null) { + restInterceptor = (request, channel, targetHandler, listener) -> listener.onResponse(Boolean.TRUE); } - this.handlerWrapper = handlerWrapper; + this.interceptor = restInterceptor; this.client = client; this.circuitBreakerService = circuitBreakerService; registerHandlerNoWrap(RestRequest.Method.GET, "/favicon.ico", RestApiVersion.current(), new RestFavIconHandler()); @@ -264,7 +265,7 @@ protected void registerHandler(RestRequest.Method method, String path, RestApiVe if (handler instanceof BaseRestHandler) { usageService.addRestHandler((BaseRestHandler) handler); } - registerHandlerNoWrap(method, path, version, handlerWrapper.apply(handler)); + registerHandlerNoWrap(method, path, version, handler); } private void registerHandlerNoWrap(RestRequest.Method method, String path, RestApiVersion version, RestHandler handler) { @@ -325,7 +326,7 @@ public void dispatchRequest(RestRequest request, RestChannel channel, ThreadCont tryAllHandlers(request, channel, threadContext); } catch (Exception e) { try { - channel.sendResponse(new RestResponse(channel, e)); + sendFailure(channel, e); } catch (Exception inner) { inner.addSuppressed(e); logger.error(() -> "failed to send failure response for uri [" + request.uri() + "]", inner); @@ -348,7 +349,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th // unless it's a http headers validation error, we consider any exceptions encountered so far during request processing // to be a problem of invalid/malformed request (hence the RestStatus#BAD_REQEST (400) HTTP response code) if (e instanceof HttpHeadersValidationException) { - channel.sendResponse(new RestResponse(channel, (Exception) e.getCause())); + sendFailure(channel, (Exception) e.getCause()); } else { channel.sendResponse(new RestResponse(channel, BAD_REQUEST, e)); } @@ -438,12 +439,44 @@ private void dispatchRequest( } else { threadContext.putHeader(SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY, Boolean.TRUE.toString()); } - handler.handleRequest(request, responseChannel, client); + final var finalChannel = responseChannel; + this.interceptor.intercept(request, responseChannel, handler.getConcreteRestHandler(), new ActionListener<>() { + @Override + public void onResponse(Boolean processRequest) { + if (processRequest) { + try { + validateRequest(request, handler, client); + handler.handleRequest(request, finalChannel, client); + } catch (Exception e) { + onFailure(e); + } + } + } + + @Override + public void onFailure(Exception e) { + try { + sendFailure(finalChannel, e); + } catch (IOException ex) { + logger.info("Failed to send error [{}] to HTTP client", ex.toString()); + } + } + }); } catch (Exception e) { - responseChannel.sendResponse(new RestResponse(responseChannel, e)); + sendFailure(responseChannel, e); } } + /** + * Validates that the request should be allowed. Throws an exception if the request should be rejected. + */ + @SuppressWarnings("unused") + protected void validateRequest(RestRequest request, RestHandler handler, NodeClient client) throws ElasticsearchStatusException {} + + private static void sendFailure(RestChannel responseChannel, Exception e) throws IOException { + responseChannel.sendResponse(new RestResponse(responseChannel, e)); + } + /** * in order to prevent CSRF we have to reject all media types that are from a browser safelist * see https://fetch.spec.whatwg.org/#cors-safelisted-request-header @@ -691,7 +724,7 @@ public static void handleBadRequest(String uri, RestRequest.Method method, RestC public static void handleServerlessRequestToProtectedResource(String uri, RestRequest.Method method, RestChannel channel) throws IOException { String msg = "uri [" + uri + "] with method [" + method + "] exists but is not available when running in serverless mode"; - channel.sendResponse(new RestResponse(channel, new ApiNotAvailableException(msg))); + sendFailure(channel, new ApiNotAvailableException(msg)); } /** diff --git a/server/src/main/java/org/elasticsearch/rest/RestInterceptor.java b/server/src/main/java/org/elasticsearch/rest/RestInterceptor.java new file mode 100644 index 0000000000000..dd0d444073040 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/rest/RestInterceptor.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.rest; + +import org.elasticsearch.action.ActionListener; + +/** + * Wraps the execution of a {@link RestHandler} + */ +@FunctionalInterface +public interface RestInterceptor { + + /** + * @param listener The interceptor responds with {@code True} if the handler should be called, + * or {@code False} if the request has been entirely handled by the interceptor. + * In the case of {@link ActionListener#onFailure(Exception)}, the target handler + * will not be called, the request will be treated as unhandled, and the regular + * rest exception handling will be performed + */ + void intercept(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) throws Exception; +} diff --git a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java index a076537bb7351..291db437fcc25 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.admin.cluster.RestNodesInfoAction; import org.elasticsearch.tasks.Task; @@ -45,7 +46,6 @@ import java.util.Arrays; import java.util.List; import java.util.function.Supplier; -import java.util.function.UnaryOperator; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; @@ -362,9 +362,9 @@ class SecPlugin implements ActionPlugin, RestServerActionPlugin { } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { if (installInterceptor) { - return UnaryOperator.identity(); + return (request, channel, targetHandler, listener) -> listener.onResponse(true); } else { return null; } @@ -372,14 +372,14 @@ public UnaryOperator getRestHandlerInterceptor(ThreadContext thread @Override public RestController getRestController( - UnaryOperator handlerWrapper, + RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, Tracer tracer ) { if (installController) { - return new RestController(handlerWrapper, client, circuitBreakerService, usageService, tracer); + return new RestController(interceptor, client, circuitBreakerService, usageService, tracer); } else { return null; } diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index 00c65437579ec..37300f1c19b1c 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -286,22 +286,25 @@ public void testRegisterSecondMethodWithDifferentNamedWildcard() { assertThat(exception.getMessage(), equalTo("Trying to use conflicting wildcard names for same path: wildcard1 and wildcard2")); } - public void testRestHandlerWrapper() throws Exception { + public void testRestInterceptor() throws Exception { AtomicBoolean handlerCalled = new AtomicBoolean(false); AtomicBoolean wrapperCalled = new AtomicBoolean(false); + final boolean callHandler = randomBoolean(); final RestHandler handler = (RestRequest request, RestChannel channel, NodeClient client) -> handlerCalled.set(true); final HttpServerTransport httpServerTransport = new TestHttpServerTransport(); - final RestController restController = new RestController(h -> { - assertSame(handler, h); - return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true); - }, client, circuitBreakerService, usageService, tracer); + final RestInterceptor interceptor = (request, channel, targetHandler, listener) -> { + assertSame(handler, targetHandler); + wrapperCalled.set(true); + listener.onResponse(callHandler); + }; + final RestController restController = new RestController(interceptor, client, circuitBreakerService, usageService, tracer); restController.registerHandler(new Route(GET, "/wrapped"), handler); RestRequest request = testRestRequest("/wrapped", "{}", XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST); restController.dispatchRequest(request, channel, client.threadPool().getThreadContext()); httpServerTransport.start(); - assertTrue(wrapperCalled.get()); - assertFalse(handlerCalled.get()); + assertThat(wrapperCalled.get(), is(true)); + assertThat(handlerCalled.get(), is(callHandler)); } public void testDispatchRequestAddsAndFreesBytesOnSuccess() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java index a383004c12878..1e1b68ef3f9b7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java @@ -83,6 +83,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHeaderDefinition; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.snapshots.Snapshot; @@ -381,10 +382,9 @@ public List getBootstrapChecks() { } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { - + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { // There can be only one. - List> items = filterPlugins(ActionPlugin.class).stream() + List items = filterPlugins(ActionPlugin.class).stream() .filter(RestServerActionPlugin.class::isInstance) .map(RestServerActionPlugin.class::cast) .map(p -> p.getRestHandlerInterceptor(threadContext)) diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index a9af4b4ba104a..bea7403f9ff40 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -85,6 +85,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHeaderDefinition; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.script.ScriptService; @@ -1847,13 +1848,12 @@ protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadCo } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { - return handler -> new SecurityRestFilter( + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { + return new SecurityRestFilter( enabled, threadContext, secondayAuthc.get(), auditTrailService.get(), - handler, operatorPrivilegesService.get() ); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java index f7d5ada9b9538..6c3c25a951744 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java @@ -8,18 +8,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.util.Supplier; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.rest.FilterRestHandler; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest.Method; import org.elasticsearch.rest.RestRequestFilter; -import org.elasticsearch.rest.RestResponse; import org.elasticsearch.xpack.security.audit.AuditTrailService; import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator; import org.elasticsearch.xpack.security.authz.restriction.WorkflowService; @@ -27,7 +24,7 @@ import static org.elasticsearch.core.Strings.format; -public class SecurityRestFilter extends FilterRestHandler implements RestHandler { +public class SecurityRestFilter implements RestInterceptor { private static final Logger logger = LogManager.getLogger(SecurityRestFilter.class); @@ -42,10 +39,8 @@ public SecurityRestFilter( ThreadContext threadContext, SecondaryAuthenticator secondaryAuthenticator, AuditTrailService auditTrailService, - RestHandler restHandler, OperatorPrivileges.OperatorPrivilegesService operatorPrivilegesService ) { - super(restHandler); this.enabled = enabled; this.threadContext = threadContext; this.secondaryAuthenticator = secondaryAuthenticator; @@ -57,57 +52,52 @@ public SecurityRestFilter( } @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + public void intercept(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) + throws Exception { // requests with the OPTIONS method should be handled elsewhere, and not by calling {@code RestHandler#handleRequest} // authn is bypassed for HTTP requests with the OPTIONS method, so this sanity check prevents dispatching unauthenticated requests if (request.method() == Method.OPTIONS) { handleException( request, - channel, - new ElasticsearchSecurityException("Cannot dispatch OPTIONS request, as they are not authenticated") + new ElasticsearchSecurityException("Cannot dispatch OPTIONS request, as they are not authenticated"), + listener ); return; } if (enabled == false) { - doHandleRequest(request, channel, client); + listener.onResponse(Boolean.TRUE); return; } - final RestRequest wrappedRequest = maybeWrapRestRequest(request); + final RestRequest wrappedRequest = maybeWrapRestRequest(request, targetHandler); auditTrailService.get().authenticationSuccess(wrappedRequest); secondaryAuthenticator.authenticateAndAttachToContext(wrappedRequest, ActionListener.wrap(secondaryAuthentication -> { if (secondaryAuthentication != null) { logger.trace("Found secondary authentication {} in REST request [{}]", secondaryAuthentication, request.uri()); } - WorkflowService.resolveWorkflowAndStoreInThreadContext(getConcreteRestHandler(), threadContext); + WorkflowService.resolveWorkflowAndStoreInThreadContext(targetHandler, threadContext); - doHandleRequest(request, channel, client); - }, e -> handleException(request, channel, e))); + doHandleRequest(request, channel, targetHandler, listener); + }, e -> handleException(request, e, listener))); } - private void doHandleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + private void doHandleRequest(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) { threadContext.sanitizeHeaders(); // operator privileges can short circuit to return a non-successful response - if (operatorPrivilegesService.checkRest(getConcreteRestHandler(), request, channel, threadContext)) { - try { - getDelegate().handleRequest(request, channel, client); - } catch (Exception e) { - logger.debug(() -> format("Request handling failed for REST request [%s]", request.uri()), e); - throw e; - } + if (operatorPrivilegesService.checkRest(targetHandler, request, channel, threadContext)) { + listener.onResponse(Boolean.TRUE); + } else { + // The service sends its own response if it returns `false`. + // That's kind of ugly, and it would be better if we throw an exception and let the rest controller serialize it as normal + listener.onResponse(Boolean.FALSE); } } - protected void handleException(RestRequest request, RestChannel channel, Exception e) { + protected void handleException(RestRequest request, Exception e, ActionListener listener) { logger.debug(() -> format("failed for REST request [%s]", request.uri()), e); threadContext.sanitizeHeaders(); - try { - channel.sendResponse(new RestResponse(channel, e)); - } catch (Exception inner) { - inner.addSuppressed(e); - logger.error((Supplier) () -> "failed to send failure response for uri [" + request.uri() + "]", inner); - } + listener.onFailure(e); } // for testing @@ -115,8 +105,8 @@ OperatorPrivileges.OperatorPrivilegesService getOperatorPrivilegesService() { return operatorPrivilegesService; } - private RestRequest maybeWrapRestRequest(RestRequest restRequest) { - if (getConcreteRestHandler() instanceof RestRequestFilter rrf) { + private RestRequest maybeWrapRestRequest(RestRequest restRequest, RestHandler targetHandler) { + if (targetHandler instanceof RestRequestFilter rrf) { return rrf.getFilteredRequest(restRequest); } return restRequest; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java index d3b46f5847636..0b0f8f8ddaae5 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; @@ -23,10 +24,9 @@ import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequestFilter; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.TestMatchers; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.xcontent.DeprecationHandler; @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.security.authz.restriction.WorkflowServiceTests.TestBaseRestHandler; import org.elasticsearch.xpack.security.operator.OperatorPrivileges; import org.junit.Before; -import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import java.util.Base64; @@ -72,8 +71,6 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -97,14 +94,7 @@ public void init() throws Exception { } private SecurityRestFilter getFilter(OperatorPrivileges.OperatorPrivilegesService privilegesService) { - return new SecurityRestFilter( - true, - threadContext, - secondaryAuthenticator, - new AuditTrailService(null, null), - restHandler, - privilegesService - ); + return new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), privilegesService); } public void testProcess() throws Exception { @@ -119,8 +109,9 @@ public void testProcess() throws Exception { callback.onResponse(authentication); return Void.TYPE; }).when(authcService).authenticate(eq(httpRequest), anyActionListener()); - filter.handleRequest(request, channel, null); - verify(restHandler).handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); verifyNoMoreInteractions(channel); } @@ -150,19 +141,21 @@ public void testProcessSecondaryAuthentication() throws Exception { }).when(authcService).authenticate(eq(httpRequest), eq(false), anyActionListener()); SecurityContext securityContext = new SecurityContext(Settings.EMPTY, threadContext); - AtomicReference secondaryAuthRef = new AtomicReference<>(); - doAnswer(i -> { - secondaryAuthRef.set(securityContext.getSecondaryAuthentication()); - return null; - }).when(restHandler).handleRequest(request, channel, null); final String credentials = randomAlphaOfLengthBetween(4, 8) + ":" + randomAlphaOfLengthBetween(4, 12); threadContext.putHeader( SecondaryAuthenticator.SECONDARY_AUTH_HEADER_NAME, "Basic " + Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharset.UTF_8)) ); - filter.handleRequest(request, channel, null); - verify(restHandler).handleRequest(request, channel, null); + + AtomicReference secondaryAuthRef = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap(proceed -> { + assertThat(proceed, is(Boolean.TRUE)); + secondaryAuthRef.set(securityContext.getSecondaryAuthentication()); + }, ex -> { throw new RuntimeException(ex); }); + + filter.intercept(request, channel, restHandler, listener); + verifyNoMoreInteractions(channel); assertThat(secondaryAuthRef.get(), notNullValue()); @@ -170,11 +163,13 @@ public void testProcessSecondaryAuthentication() throws Exception { } public void testProcessWithSecurityDisabled() throws Exception { - filter = new SecurityRestFilter(false, threadContext, secondaryAuthenticator, mock(AuditTrailService.class), restHandler, null); + filter = new SecurityRestFilter(false, threadContext, secondaryAuthenticator, mock(AuditTrailService.class), null); assertEquals(NOOP_OPERATOR_PRIVILEGES_SERVICE, filter.getOperatorPrivilegesService()); RestRequest request = mock(RestRequest.class); - filter.handleRequest(request, channel, null); - verify(restHandler).handleRequest(request, channel, null); + + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); verifyNoMoreInteractions(channel, authcService); } @@ -182,14 +177,14 @@ public void testProcessOptionsMethod() throws Exception { FakeRestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(RestRequest.Method.OPTIONS).build(); when(channel.request()).thenReturn(request); when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder()); - filter.handleRequest(request, channel, null); + + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + final ElasticsearchSecurityException ex = expectThrows(ElasticsearchSecurityException.class, future::actionGet); + assertThat(ex, TestMatchers.throwableWithMessage(containsString("Cannot dispatch OPTIONS request, as they are not authenticated"))); + verifyNoMoreInteractions(restHandler); verifyNoMoreInteractions(authcService); - ArgumentCaptor responseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); - verify(channel).sendResponse(responseArgumentCaptor.capture()); - RestResponse restResponse = responseArgumentCaptor.getValue(); - assertThat(restResponse.status(), is(RestStatus.INTERNAL_SERVER_ERROR)); - assertThat(restResponse.content().utf8ToString(), containsString("Cannot dispatch OPTIONS request, as they are not authenticated")); } public void testProcessFiltersBodyCorrectly() throws Exception { @@ -198,12 +193,9 @@ public void testProcessFiltersBodyCorrectly() throws Exception { XContentType.JSON ).build(); when(channel.request()).thenReturn(restRequest); - SetOnce handlerRequest = new SetOnce<>(); restHandler = new FilteredRestHandler() { @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { - handlerRequest.set(request); - } + public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {} @Override public Set getFilteredFields() { @@ -222,28 +214,12 @@ public Set getFilteredFields() { threadContext, secondaryAuthenticator, new AuditTrailService(auditTrail, licenseState), - restHandler, NOOP_OPERATOR_PRIVILEGES_SERVICE ); - filter.handleRequest(restRequest, channel, null); - - assertEquals(restRequest, handlerRequest.get()); - assertEquals(restRequest.content(), handlerRequest.get().content()); - Map original; - try ( - var parser = XContentType.JSON.xContent() - .createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - handlerRequest.get().content().streamInput() - ) - ) { - original = parser.map(); - } - assertEquals(2, original.size()); - assertEquals(SecuritySettingsSourceField.TEST_PASSWORD, original.get("password")); - assertEquals("bar", original.get("foo")); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(restRequest, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); assertNotEquals(restRequest, auditTrailRequest.get()); assertNotEquals(restRequest.content(), auditTrailRequest.get().content()); @@ -284,7 +260,9 @@ public void testSanitizeHeaders() throws Exception { Set foundKeys = threadContext.getHeaders().keySet(); assertThat(foundKeys, hasItem(UsernamePasswordToken.BASIC_AUTH_HEADER)); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); foundKeys = threadContext.getHeaders().keySet(); assertThat(foundKeys, not(hasItem(UsernamePasswordToken.BASIC_AUTH_HEADER))); @@ -296,10 +274,12 @@ public void testProcessWithWorkflow() throws Exception { restHandler = new TestBaseRestHandler(randomFrom(workflow.allowedRestHandlers())); final WorkflowService workflowService = new WorkflowService(); - filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), restHandler, null); + filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), null); RestRequest request = mock(RestRequest.class); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); assertThat(WorkflowService.readWorkflowFromThreadContext(threadContext), equalTo(workflow.name())); } @@ -315,10 +295,12 @@ public void testProcessWithoutWorkflow() throws Exception { } final WorkflowService workflowService = new WorkflowService(); - filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), restHandler, null); + filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), null); RestRequest request = mock(RestRequest.class); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); assertThat(WorkflowService.readWorkflowFromThreadContext(threadContext), nullValue()); } @@ -354,11 +336,13 @@ public boolean checkRest( public void maybeInterceptRequest(ThreadContext threadContext, TransportRequest request) {} }); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + if (isOperator) { - verify(restHandler).handleRequest(request, channel, null); + assertThat(future.get(), is(Boolean.TRUE)); } else { - verify(restHandler, never()).handleRequest(request, channel, null); + assertThat(future.get(), is(Boolean.FALSE)); } } }