Skip to content

Commit

Permalink
Use exceptions for OperatorOnlyRegistry.checkRest
Browse files Browse the repository at this point in the history
This commit changs `OperatorOnlyRegistry.checkRest` to handle failures
via an exception rather than a return value and the use of a channel

This fits better into the way that the `SecurityRestFilter` works
(since elastic#104291) with a dedicated `RestInterceptor` interface
  • Loading branch information
tvernum committed Mar 14, 2024
1 parent 99719f2 commit 09fd0ce
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.license.PutLicenseAction;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.transport.TransportRequest;
Expand Down Expand Up @@ -79,8 +78,8 @@ public OperatorPrivilegesViolation check(String action, TransportRequest request
}

@Override
public OperatorPrivilegesViolation checkRest(RestHandler restHandler, RestRequest restRequest, RestChannel restChannel) {
return null; // no restrictions
public void checkRest(RestHandler restHandler, RestRequest restRequest) {
// no restrictions
}

private OperatorPrivilegesViolation checkClusterUpdateSettings(ClusterUpdateSettingsRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

package org.elasticsearch.xpack.security.operator;

import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.transport.TransportRequest;
Expand All @@ -22,18 +23,22 @@ public interface OperatorOnlyRegistry {
OperatorPrivilegesViolation check(String action, TransportRequest request);

/**
* Checks to see if a given {@link RestHandler} is subject to operator-only restrictions for the REST API. Any REST API may be
* fully or partially restricted. A fully restricted REST API mandates that the implementation call restChannel.sendResponse(...) and
* return a {@link OperatorPrivilegesViolation}. A partially restricted REST API mandates that the {@link RestRequest} is marked as
* restricted so that the downstream handler can behave appropriately. For example, to restrict the REST response the implementation
* Checks to see if a given {@link RestHandler} is subject to operator-only restrictions for the REST API.
*
* Any REST API may be fully or partially restricted.
* A fully restricted REST API mandates that the implementation of this method throw an
* {@link org.elasticsearch.ElasticsearchStatusException} with an appropriate status code and error message.
*
* A partially restricted REST API mandates that the {@link RestRequest} is marked as restricted so that the downstream handler can
* behave appropriately.
* For example, to restrict the REST response the implementation
* should call {@link RestRequest#markPathRestricted(String)} so that the downstream handler can properly restrict the response
* before returning to the client. Note - a partial restriction should return null.
* before returning to the client. Note - a partial restriction should not throw an exception.
*
* @param restHandler The {@link RestHandler} to check for any restrictions
* @param restRequest The {@link RestRequest} to check for any restrictions and mark any partially restricted REST API's
* @param restChannel The {@link RestChannel} to enforce fully restricted REST API's
* @return {@link OperatorPrivilegesViolation} iff the request was fully restricted and the response has been sent back to the client.
* else returns null.
* @throws ElasticsearchStatusException if the request should be denied in its entirety (fully restricted)
*/
OperatorPrivilegesViolation checkRest(RestHandler restHandler, RestRequest restRequest, RestChannel restChannel);
void checkRest(RestHandler restHandler, RestRequest restRequest) throws ElasticsearchException;

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest;
import org.elasticsearch.common.settings.Setting;
Expand Down Expand Up @@ -156,28 +157,40 @@ public boolean checkRest(RestHandler restHandler, RestRequest restRequest, RestC
if (false == isOperator(threadContext)) {
// Only check whether request is operator-only when user is NOT an operator
if (logger.isTraceEnabled()) {
Authentication authentication = threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY);
final User user = authentication.getEffectiveSubject().getUser();
final User user = getUser(threadContext);
logger.trace("Checking for any operator-only REST violations for user [{}] and uri [{}]", user, restRequest.uri());
}
OperatorPrivilegesViolation violation = operatorOnlyRegistry.checkRest(restHandler, restRequest, restChannel);
if (violation != null) {

try {
operatorOnlyRegistry.checkRest(restHandler, restRequest);
} catch (ElasticsearchException e) {
if (logger.isDebugEnabled()) {
Authentication authentication = threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY);
final User user = authentication.getEffectiveSubject().getUser();
logger.debug(
"Found the following operator-only violation [{}] for user [{}] and uri [{}]",
violation.message(),
user,
e.getMessage(),
getUser(threadContext),
restRequest.uri()
);
}
return false;
throw e;
} catch (Exception e) {
logger.info(
"Unexpected exception [{}] while processing operator privileges for user [{}] and uri [{}]",
e.getMessage(),
getUser(threadContext),
restRequest.uri()
);
throw e;
}
}
return true;
}

private static User getUser(ThreadContext threadContext) {
Authentication authentication = threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY);
return authentication.getEffectiveSubject().getUser();
}

public void maybeInterceptRequest(ThreadContext threadContext, TransportRequest request) {
if (request instanceof RestoreSnapshotRequest) {
logger.debug("Intercepting [{}] for operator privileges", request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import org.elasticsearch.plugins.ExtensiblePlugin;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.internal.RestExtension;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.script.ScriptService;
Expand Down Expand Up @@ -177,7 +176,7 @@ public OperatorPrivilegesViolation check(String action, TransportRequest request
}

@Override
public OperatorPrivilegesViolation checkRest(RestHandler restHandler, RestRequest restRequest, RestChannel restChannel) {
public void checkRest(RestHandler restHandler, RestRequest restRequest) {
throw new RuntimeException("boom");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest;
import org.elasticsearch.common.logging.Loggers;
Expand All @@ -32,13 +33,16 @@
import org.junit.Before;
import org.mockito.Mockito;

import static org.elasticsearch.test.TestMatchers.throwableWithMessage;
import static org.elasticsearch.xpack.security.operator.OperatorPrivileges.NOOP_OPERATOR_PRIVILEGES_SERVICE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -278,8 +282,13 @@ public void testCheckRest() {
ThreadContext threadContext = new ThreadContext(settings);

// not an operator
when(operatorOnlyRegistry.checkRest(restHandler, restRequest, restChannel)).thenReturn(() -> "violation!");
assertFalse(operatorPrivilegesService.checkRest(restHandler, restRequest, restChannel, threadContext));
doThrow(new ElasticsearchSecurityException("violation!")).when(operatorOnlyRegistry).checkRest(restHandler, restRequest);
final ElasticsearchException ex = expectThrows(
ElasticsearchException.class,
() -> operatorPrivilegesService.checkRest(restHandler, restRequest, restChannel, threadContext)
);
assertThat(ex, instanceOf(ElasticsearchSecurityException.class));
assertThat(ex, throwableWithMessage("violation!"));
Mockito.clearInvocations(operatorOnlyRegistry);

// is an operator
Expand Down

0 comments on commit 09fd0ce

Please sign in to comment.