diff --git a/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java b/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java index 37c79fe2abb0b..4fbdfa65d40ba 100644 --- a/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java +++ b/qa/custom-rest-controller/src/javaRestTest/java/co/elastic/elasticsearch/test/CustomRestPlugin.java @@ -8,6 +8,7 @@ package co.elastic.elasticsearch.test; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -18,12 +19,11 @@ import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.usage.UsageService; -import java.util.function.UnaryOperator; - public class CustomRestPlugin extends Plugin implements RestServerActionPlugin { private static final Logger logger = LogManager.getLogger(CustomRestPlugin.class); @@ -35,34 +35,33 @@ private static void echoHeader(String name, RestRequest request, ThreadContext t } } - public static class CustomInterceptor implements RestHandler { + public static class CustomInterceptor implements RestInterceptor { private final ThreadContext threadContext; - private final RestHandler delegate; - public CustomInterceptor(ThreadContext threadContext, RestHandler delegate) { + public CustomInterceptor(ThreadContext threadContext) { this.threadContext = threadContext; - this.delegate = delegate; } @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + public void intercept(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) + throws Exception { logger.info("intercept request {} {}", request.method(), request.uri()); echoHeader("x-test-interceptor", request, threadContext); - delegate.handleRequest(request, channel, client); + listener.onResponse(Boolean.TRUE); } } public static class CustomController extends RestController { public CustomController( - UnaryOperator handlerWrapper, + RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, Tracer tracer ) { - super(handlerWrapper, client, circuitBreakerService, usageService, tracer); + super(interceptor, client, circuitBreakerService, usageService, tracer); } @Override @@ -74,19 +73,19 @@ public void dispatchRequest(RestRequest request, RestChannel channel, ThreadCont } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { - return handler -> new CustomInterceptor(threadContext, handler); + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { + return new CustomInterceptor(threadContext); } @Override public RestController getRestController( - UnaryOperator handlerWrapper, + RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, Tracer tracer ) { - return new CustomController(handlerWrapper, client, circuitBreakerService, usageService, tracer); + return new CustomController(interceptor, client, circuitBreakerService, usageService, tracer); } } diff --git a/server/src/main/java/org/elasticsearch/action/ActionModule.java b/server/src/main/java/org/elasticsearch/action/ActionModule.java index dd70dc65b853b..3e6ec1bde7448 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionModule.java +++ b/server/src/main/java/org/elasticsearch/action/ActionModule.java @@ -264,6 +264,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHeaderDefinition; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestUtils; import org.elasticsearch.rest.action.RestFieldCapabilitiesAction; import org.elasticsearch.rest.action.admin.cluster.RestAddVotingConfigExclusionAction; @@ -425,7 +426,6 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; -import java.util.function.UnaryOperator; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -501,7 +501,7 @@ public ActionModule( new RestHeaderDefinition(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, false) ) ).collect(Collectors.toSet()); - UnaryOperator restInterceptor = getRestServerComponent( + final RestInterceptor restInterceptor = getRestServerComponent( "REST interceptor", actionPlugins, restPlugin -> restPlugin.getRestHandlerInterceptor(threadPool.getThreadContext()) diff --git a/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java b/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java index 35badffe0b3aa..44653dcf8b5fe 100644 --- a/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/interceptor/RestServerActionPlugin.java @@ -14,7 +14,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.rest.RestController; -import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.usage.UsageService; @@ -46,7 +46,7 @@ public interface RestServerActionPlugin extends ActionPlugin { * * Note: Only one installed plugin may implement a rest interceptor. */ - UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext); + RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext); /** * Returns a replacement {@link RestController} to be used in the server. @@ -54,7 +54,7 @@ public interface RestServerActionPlugin extends ActionPlugin { */ @Nullable default RestController getRestController( - @Nullable UnaryOperator handlerWrapper, + @Nullable RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index 6a5d6f99df64b..56a70e3bb2ab4 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -13,6 +13,8 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; @@ -57,7 +59,6 @@ import java.util.TreeMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; -import java.util.function.UnaryOperator; import static org.elasticsearch.indices.SystemIndices.EXTERNAL_SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY; import static org.elasticsearch.indices.SystemIndices.SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY; @@ -95,7 +96,7 @@ public class RestController implements HttpServerTransport.Dispatcher { private final PathTrie handlers = new PathTrie<>(RestUtils.REST_DECODER); - private final UnaryOperator handlerWrapper; + private final RestInterceptor interceptor; private final NodeClient client; @@ -107,7 +108,7 @@ public class RestController implements HttpServerTransport.Dispatcher { private final ServerlessApiProtections apiProtections; public RestController( - UnaryOperator handlerWrapper, + RestInterceptor restInterceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, @@ -115,10 +116,10 @@ public RestController( ) { this.usageService = usageService; this.tracer = tracer; - if (handlerWrapper == null) { - handlerWrapper = h -> h; // passthrough if no wrapper set + if (restInterceptor == null) { + restInterceptor = (request, channel, targetHandler, listener) -> listener.onResponse(Boolean.TRUE); } - this.handlerWrapper = handlerWrapper; + this.interceptor = restInterceptor; this.client = client; this.circuitBreakerService = circuitBreakerService; registerHandlerNoWrap(RestRequest.Method.GET, "/favicon.ico", RestApiVersion.current(), new RestFavIconHandler()); @@ -264,7 +265,7 @@ protected void registerHandler(RestRequest.Method method, String path, RestApiVe if (handler instanceof BaseRestHandler) { usageService.addRestHandler((BaseRestHandler) handler); } - registerHandlerNoWrap(method, path, version, handlerWrapper.apply(handler)); + registerHandlerNoWrap(method, path, version, handler); } private void registerHandlerNoWrap(RestRequest.Method method, String path, RestApiVersion version, RestHandler handler) { @@ -325,7 +326,7 @@ public void dispatchRequest(RestRequest request, RestChannel channel, ThreadCont tryAllHandlers(request, channel, threadContext); } catch (Exception e) { try { - channel.sendResponse(new RestResponse(channel, e)); + sendFailure(channel, e); } catch (Exception inner) { inner.addSuppressed(e); logger.error(() -> "failed to send failure response for uri [" + request.uri() + "]", inner); @@ -348,7 +349,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th // unless it's a http headers validation error, we consider any exceptions encountered so far during request processing // to be a problem of invalid/malformed request (hence the RestStatus#BAD_REQEST (400) HTTP response code) if (e instanceof HttpHeadersValidationException) { - channel.sendResponse(new RestResponse(channel, (Exception) e.getCause())); + sendFailure(channel, (Exception) e.getCause()); } else { channel.sendResponse(new RestResponse(channel, BAD_REQUEST, e)); } @@ -438,12 +439,44 @@ private void dispatchRequest( } else { threadContext.putHeader(SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY, Boolean.TRUE.toString()); } - handler.handleRequest(request, responseChannel, client); + final var finalChannel = responseChannel; + this.interceptor.intercept(request, responseChannel, handler.getConcreteRestHandler(), new ActionListener<>() { + @Override + public void onResponse(Boolean processRequest) { + if (processRequest) { + try { + validateRequest(request, handler, client); + handler.handleRequest(request, finalChannel, client); + } catch (Exception e) { + onFailure(e); + } + } + } + + @Override + public void onFailure(Exception e) { + try { + sendFailure(finalChannel, e); + } catch (IOException ex) { + logger.info("Failed to send error [{}] to HTTP client", ex.toString()); + } + } + }); } catch (Exception e) { - responseChannel.sendResponse(new RestResponse(responseChannel, e)); + sendFailure(responseChannel, e); } } + /** + * Validates that the request should be allowed. Throws an exception if the request should be rejected. + */ + @SuppressWarnings("unused") + protected void validateRequest(RestRequest request, RestHandler handler, NodeClient client) throws ElasticsearchStatusException {} + + private static void sendFailure(RestChannel responseChannel, Exception e) throws IOException { + responseChannel.sendResponse(new RestResponse(responseChannel, e)); + } + /** * in order to prevent CSRF we have to reject all media types that are from a browser safelist * see https://fetch.spec.whatwg.org/#cors-safelisted-request-header @@ -691,7 +724,7 @@ public static void handleBadRequest(String uri, RestRequest.Method method, RestC public static void handleServerlessRequestToProtectedResource(String uri, RestRequest.Method method, RestChannel channel) throws IOException { String msg = "uri [" + uri + "] with method [" + method + "] exists but is not available when running in serverless mode"; - channel.sendResponse(new RestResponse(channel, new ApiNotAvailableException(msg))); + sendFailure(channel, new ApiNotAvailableException(msg)); } /** diff --git a/server/src/main/java/org/elasticsearch/rest/RestInterceptor.java b/server/src/main/java/org/elasticsearch/rest/RestInterceptor.java new file mode 100644 index 0000000000000..dd0d444073040 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/rest/RestInterceptor.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.rest; + +import org.elasticsearch.action.ActionListener; + +/** + * Wraps the execution of a {@link RestHandler} + */ +@FunctionalInterface +public interface RestInterceptor { + + /** + * @param listener The interceptor responds with {@code True} if the handler should be called, + * or {@code False} if the request has been entirely handled by the interceptor. + * In the case of {@link ActionListener#onFailure(Exception)}, the target handler + * will not be called, the request will be treated as unhandled, and the regular + * rest exception handling will be performed + */ + void intercept(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) throws Exception; +} diff --git a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java index a076537bb7351..291db437fcc25 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.admin.cluster.RestNodesInfoAction; import org.elasticsearch.tasks.Task; @@ -45,7 +46,6 @@ import java.util.Arrays; import java.util.List; import java.util.function.Supplier; -import java.util.function.UnaryOperator; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; @@ -362,9 +362,9 @@ class SecPlugin implements ActionPlugin, RestServerActionPlugin { } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { if (installInterceptor) { - return UnaryOperator.identity(); + return (request, channel, targetHandler, listener) -> listener.onResponse(true); } else { return null; } @@ -372,14 +372,14 @@ public UnaryOperator getRestHandlerInterceptor(ThreadContext thread @Override public RestController getRestController( - UnaryOperator handlerWrapper, + RestInterceptor interceptor, NodeClient client, CircuitBreakerService circuitBreakerService, UsageService usageService, Tracer tracer ) { if (installController) { - return new RestController(handlerWrapper, client, circuitBreakerService, usageService, tracer); + return new RestController(interceptor, client, circuitBreakerService, usageService, tracer); } else { return null; } diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index 00c65437579ec..37300f1c19b1c 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -286,22 +286,25 @@ public void testRegisterSecondMethodWithDifferentNamedWildcard() { assertThat(exception.getMessage(), equalTo("Trying to use conflicting wildcard names for same path: wildcard1 and wildcard2")); } - public void testRestHandlerWrapper() throws Exception { + public void testRestInterceptor() throws Exception { AtomicBoolean handlerCalled = new AtomicBoolean(false); AtomicBoolean wrapperCalled = new AtomicBoolean(false); + final boolean callHandler = randomBoolean(); final RestHandler handler = (RestRequest request, RestChannel channel, NodeClient client) -> handlerCalled.set(true); final HttpServerTransport httpServerTransport = new TestHttpServerTransport(); - final RestController restController = new RestController(h -> { - assertSame(handler, h); - return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true); - }, client, circuitBreakerService, usageService, tracer); + final RestInterceptor interceptor = (request, channel, targetHandler, listener) -> { + assertSame(handler, targetHandler); + wrapperCalled.set(true); + listener.onResponse(callHandler); + }; + final RestController restController = new RestController(interceptor, client, circuitBreakerService, usageService, tracer); restController.registerHandler(new Route(GET, "/wrapped"), handler); RestRequest request = testRestRequest("/wrapped", "{}", XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST); restController.dispatchRequest(request, channel, client.threadPool().getThreadContext()); httpServerTransport.start(); - assertTrue(wrapperCalled.get()); - assertFalse(handlerCalled.get()); + assertThat(wrapperCalled.get(), is(true)); + assertThat(handlerCalled.get(), is(callHandler)); } public void testDispatchRequestAddsAndFreesBytesOnSuccess() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java index a383004c12878..1e1b68ef3f9b7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java @@ -83,6 +83,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHeaderDefinition; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.snapshots.Snapshot; @@ -381,10 +382,9 @@ public List getBootstrapChecks() { } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { - + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { // There can be only one. - List> items = filterPlugins(ActionPlugin.class).stream() + List items = filterPlugins(ActionPlugin.class).stream() .filter(RestServerActionPlugin.class::isInstance) .map(RestServerActionPlugin.class::cast) .map(p -> p.getRestHandlerInterceptor(threadContext)) diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index a9af4b4ba104a..bea7403f9ff40 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -85,6 +85,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHeaderDefinition; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.script.ScriptService; @@ -1847,13 +1848,12 @@ protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadCo } @Override - public UnaryOperator getRestHandlerInterceptor(ThreadContext threadContext) { - return handler -> new SecurityRestFilter( + public RestInterceptor getRestHandlerInterceptor(ThreadContext threadContext) { + return new SecurityRestFilter( enabled, threadContext, secondayAuthc.get(), auditTrailService.get(), - handler, operatorPrivilegesService.get() ); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java index f7d5ada9b9538..6c3c25a951744 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java @@ -8,18 +8,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.util.Supplier; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.rest.FilterRestHandler; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestInterceptor; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest.Method; import org.elasticsearch.rest.RestRequestFilter; -import org.elasticsearch.rest.RestResponse; import org.elasticsearch.xpack.security.audit.AuditTrailService; import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator; import org.elasticsearch.xpack.security.authz.restriction.WorkflowService; @@ -27,7 +24,7 @@ import static org.elasticsearch.core.Strings.format; -public class SecurityRestFilter extends FilterRestHandler implements RestHandler { +public class SecurityRestFilter implements RestInterceptor { private static final Logger logger = LogManager.getLogger(SecurityRestFilter.class); @@ -42,10 +39,8 @@ public SecurityRestFilter( ThreadContext threadContext, SecondaryAuthenticator secondaryAuthenticator, AuditTrailService auditTrailService, - RestHandler restHandler, OperatorPrivileges.OperatorPrivilegesService operatorPrivilegesService ) { - super(restHandler); this.enabled = enabled; this.threadContext = threadContext; this.secondaryAuthenticator = secondaryAuthenticator; @@ -57,57 +52,52 @@ public SecurityRestFilter( } @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + public void intercept(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) + throws Exception { // requests with the OPTIONS method should be handled elsewhere, and not by calling {@code RestHandler#handleRequest} // authn is bypassed for HTTP requests with the OPTIONS method, so this sanity check prevents dispatching unauthenticated requests if (request.method() == Method.OPTIONS) { handleException( request, - channel, - new ElasticsearchSecurityException("Cannot dispatch OPTIONS request, as they are not authenticated") + new ElasticsearchSecurityException("Cannot dispatch OPTIONS request, as they are not authenticated"), + listener ); return; } if (enabled == false) { - doHandleRequest(request, channel, client); + listener.onResponse(Boolean.TRUE); return; } - final RestRequest wrappedRequest = maybeWrapRestRequest(request); + final RestRequest wrappedRequest = maybeWrapRestRequest(request, targetHandler); auditTrailService.get().authenticationSuccess(wrappedRequest); secondaryAuthenticator.authenticateAndAttachToContext(wrappedRequest, ActionListener.wrap(secondaryAuthentication -> { if (secondaryAuthentication != null) { logger.trace("Found secondary authentication {} in REST request [{}]", secondaryAuthentication, request.uri()); } - WorkflowService.resolveWorkflowAndStoreInThreadContext(getConcreteRestHandler(), threadContext); + WorkflowService.resolveWorkflowAndStoreInThreadContext(targetHandler, threadContext); - doHandleRequest(request, channel, client); - }, e -> handleException(request, channel, e))); + doHandleRequest(request, channel, targetHandler, listener); + }, e -> handleException(request, e, listener))); } - private void doHandleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + private void doHandleRequest(RestRequest request, RestChannel channel, RestHandler targetHandler, ActionListener listener) { threadContext.sanitizeHeaders(); // operator privileges can short circuit to return a non-successful response - if (operatorPrivilegesService.checkRest(getConcreteRestHandler(), request, channel, threadContext)) { - try { - getDelegate().handleRequest(request, channel, client); - } catch (Exception e) { - logger.debug(() -> format("Request handling failed for REST request [%s]", request.uri()), e); - throw e; - } + if (operatorPrivilegesService.checkRest(targetHandler, request, channel, threadContext)) { + listener.onResponse(Boolean.TRUE); + } else { + // The service sends its own response if it returns `false`. + // That's kind of ugly, and it would be better if we throw an exception and let the rest controller serialize it as normal + listener.onResponse(Boolean.FALSE); } } - protected void handleException(RestRequest request, RestChannel channel, Exception e) { + protected void handleException(RestRequest request, Exception e, ActionListener listener) { logger.debug(() -> format("failed for REST request [%s]", request.uri()), e); threadContext.sanitizeHeaders(); - try { - channel.sendResponse(new RestResponse(channel, e)); - } catch (Exception inner) { - inner.addSuppressed(e); - logger.error((Supplier) () -> "failed to send failure response for uri [" + request.uri() + "]", inner); - } + listener.onFailure(e); } // for testing @@ -115,8 +105,8 @@ OperatorPrivileges.OperatorPrivilegesService getOperatorPrivilegesService() { return operatorPrivilegesService; } - private RestRequest maybeWrapRestRequest(RestRequest restRequest) { - if (getConcreteRestHandler() instanceof RestRequestFilter rrf) { + private RestRequest maybeWrapRestRequest(RestRequest restRequest, RestHandler targetHandler) { + if (targetHandler instanceof RestRequestFilter rrf) { return rrf.getFilteredRequest(restRequest); } return restRequest; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java index d3b46f5847636..0b0f8f8ddaae5 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; @@ -23,10 +24,9 @@ import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequestFilter; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.TestMatchers; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.xcontent.DeprecationHandler; @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.security.authz.restriction.WorkflowServiceTests.TestBaseRestHandler; import org.elasticsearch.xpack.security.operator.OperatorPrivileges; import org.junit.Before; -import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import java.util.Base64; @@ -72,8 +71,6 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -97,14 +94,7 @@ public void init() throws Exception { } private SecurityRestFilter getFilter(OperatorPrivileges.OperatorPrivilegesService privilegesService) { - return new SecurityRestFilter( - true, - threadContext, - secondaryAuthenticator, - new AuditTrailService(null, null), - restHandler, - privilegesService - ); + return new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), privilegesService); } public void testProcess() throws Exception { @@ -119,8 +109,9 @@ public void testProcess() throws Exception { callback.onResponse(authentication); return Void.TYPE; }).when(authcService).authenticate(eq(httpRequest), anyActionListener()); - filter.handleRequest(request, channel, null); - verify(restHandler).handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); verifyNoMoreInteractions(channel); } @@ -150,19 +141,21 @@ public void testProcessSecondaryAuthentication() throws Exception { }).when(authcService).authenticate(eq(httpRequest), eq(false), anyActionListener()); SecurityContext securityContext = new SecurityContext(Settings.EMPTY, threadContext); - AtomicReference secondaryAuthRef = new AtomicReference<>(); - doAnswer(i -> { - secondaryAuthRef.set(securityContext.getSecondaryAuthentication()); - return null; - }).when(restHandler).handleRequest(request, channel, null); final String credentials = randomAlphaOfLengthBetween(4, 8) + ":" + randomAlphaOfLengthBetween(4, 12); threadContext.putHeader( SecondaryAuthenticator.SECONDARY_AUTH_HEADER_NAME, "Basic " + Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharset.UTF_8)) ); - filter.handleRequest(request, channel, null); - verify(restHandler).handleRequest(request, channel, null); + + AtomicReference secondaryAuthRef = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap(proceed -> { + assertThat(proceed, is(Boolean.TRUE)); + secondaryAuthRef.set(securityContext.getSecondaryAuthentication()); + }, ex -> { throw new RuntimeException(ex); }); + + filter.intercept(request, channel, restHandler, listener); + verifyNoMoreInteractions(channel); assertThat(secondaryAuthRef.get(), notNullValue()); @@ -170,11 +163,13 @@ public void testProcessSecondaryAuthentication() throws Exception { } public void testProcessWithSecurityDisabled() throws Exception { - filter = new SecurityRestFilter(false, threadContext, secondaryAuthenticator, mock(AuditTrailService.class), restHandler, null); + filter = new SecurityRestFilter(false, threadContext, secondaryAuthenticator, mock(AuditTrailService.class), null); assertEquals(NOOP_OPERATOR_PRIVILEGES_SERVICE, filter.getOperatorPrivilegesService()); RestRequest request = mock(RestRequest.class); - filter.handleRequest(request, channel, null); - verify(restHandler).handleRequest(request, channel, null); + + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); verifyNoMoreInteractions(channel, authcService); } @@ -182,14 +177,14 @@ public void testProcessOptionsMethod() throws Exception { FakeRestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(RestRequest.Method.OPTIONS).build(); when(channel.request()).thenReturn(request); when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder()); - filter.handleRequest(request, channel, null); + + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + final ElasticsearchSecurityException ex = expectThrows(ElasticsearchSecurityException.class, future::actionGet); + assertThat(ex, TestMatchers.throwableWithMessage(containsString("Cannot dispatch OPTIONS request, as they are not authenticated"))); + verifyNoMoreInteractions(restHandler); verifyNoMoreInteractions(authcService); - ArgumentCaptor responseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); - verify(channel).sendResponse(responseArgumentCaptor.capture()); - RestResponse restResponse = responseArgumentCaptor.getValue(); - assertThat(restResponse.status(), is(RestStatus.INTERNAL_SERVER_ERROR)); - assertThat(restResponse.content().utf8ToString(), containsString("Cannot dispatch OPTIONS request, as they are not authenticated")); } public void testProcessFiltersBodyCorrectly() throws Exception { @@ -198,12 +193,9 @@ public void testProcessFiltersBodyCorrectly() throws Exception { XContentType.JSON ).build(); when(channel.request()).thenReturn(restRequest); - SetOnce handlerRequest = new SetOnce<>(); restHandler = new FilteredRestHandler() { @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { - handlerRequest.set(request); - } + public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {} @Override public Set getFilteredFields() { @@ -222,28 +214,12 @@ public Set getFilteredFields() { threadContext, secondaryAuthenticator, new AuditTrailService(auditTrail, licenseState), - restHandler, NOOP_OPERATOR_PRIVILEGES_SERVICE ); - filter.handleRequest(restRequest, channel, null); - - assertEquals(restRequest, handlerRequest.get()); - assertEquals(restRequest.content(), handlerRequest.get().content()); - Map original; - try ( - var parser = XContentType.JSON.xContent() - .createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - handlerRequest.get().content().streamInput() - ) - ) { - original = parser.map(); - } - assertEquals(2, original.size()); - assertEquals(SecuritySettingsSourceField.TEST_PASSWORD, original.get("password")); - assertEquals("bar", original.get("foo")); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(restRequest, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); assertNotEquals(restRequest, auditTrailRequest.get()); assertNotEquals(restRequest.content(), auditTrailRequest.get().content()); @@ -284,7 +260,9 @@ public void testSanitizeHeaders() throws Exception { Set foundKeys = threadContext.getHeaders().keySet(); assertThat(foundKeys, hasItem(UsernamePasswordToken.BASIC_AUTH_HEADER)); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); foundKeys = threadContext.getHeaders().keySet(); assertThat(foundKeys, not(hasItem(UsernamePasswordToken.BASIC_AUTH_HEADER))); @@ -296,10 +274,12 @@ public void testProcessWithWorkflow() throws Exception { restHandler = new TestBaseRestHandler(randomFrom(workflow.allowedRestHandlers())); final WorkflowService workflowService = new WorkflowService(); - filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), restHandler, null); + filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), null); RestRequest request = mock(RestRequest.class); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); assertThat(WorkflowService.readWorkflowFromThreadContext(threadContext), equalTo(workflow.name())); } @@ -315,10 +295,12 @@ public void testProcessWithoutWorkflow() throws Exception { } final WorkflowService workflowService = new WorkflowService(); - filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), restHandler, null); + filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), null); RestRequest request = mock(RestRequest.class); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + assertThat(future.get(), is(Boolean.TRUE)); assertThat(WorkflowService.readWorkflowFromThreadContext(threadContext), nullValue()); } @@ -354,11 +336,13 @@ public boolean checkRest( public void maybeInterceptRequest(ThreadContext threadContext, TransportRequest request) {} }); - filter.handleRequest(request, channel, null); + PlainActionFuture future = new PlainActionFuture<>(); + filter.intercept(request, channel, restHandler, future); + if (isOperator) { - verify(restHandler).handleRequest(request, channel, null); + assertThat(future.get(), is(Boolean.TRUE)); } else { - verify(restHandler, never()).handleRequest(request, channel, null); + assertThat(future.get(), is(Boolean.FALSE)); } } }