Skip to content

Commit

Permalink
Refactor Async Search security into a common class
Browse files Browse the repository at this point in the history
This commit moves code from AsyncTaskIndexService and
DeleteAsyncResultsService into a new AsyncSearchSecurity so that all
security code is centralised.,
  • Loading branch information
tvernum committed Mar 22, 2024
1 parent a960bef commit 0d724d8
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 209 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.async;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.Strings;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.privilege.ClusterPrivilegeResolver;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;

public class AsyncSearchSecurity {

private static final FetchSourceContext FETCH_HEADERS_FIELD_CONTEXT = FetchSourceContext.of(
true,
new String[] { AsyncTaskIndexService.HEADERS_FIELD },
Strings.EMPTY_ARRAY
);

private final String indexName;
private final SecurityContext securityContext;
private final Client client;
private final OriginSettingClient clientWithOrigin;

public AsyncSearchSecurity(String indexName, SecurityContext securityContext, Client client, String origin) {
this.securityContext = securityContext;
this.client = client;
this.clientWithOrigin = new OriginSettingClient(client, origin);
this.indexName = indexName;
}

public void currentUserHasCancelTaskPrivilege(Consumer<Boolean> consumer) {
final Authentication current = securityContext.getAuthentication();
if (current != null) {
HasPrivilegesRequest req = new HasPrivilegesRequest();
req.username(current.getEffectiveSubject().getUser().principal());
req.clusterPrivileges(ClusterPrivilegeResolver.CANCEL_TASK.name());
req.indexPrivileges(new RoleDescriptor.IndicesPrivileges[] {});
req.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[] {});
try {
client.execute(
HasPrivilegesAction.INSTANCE,
req,
ActionListener.wrap(resp -> consumer.accept(resp.isCompleteMatch()), exc -> consumer.accept(false))
);
} catch (Exception exc) {
consumer.accept(false);
}
} else {
consumer.accept(false);
}
}

public boolean currentUserHasAccessToTask(AsyncTask asyncTask) throws IOException {
Objects.requireNonNull(asyncTask, "Task cannot be null");
return currentUserHasAccessToTaskWithHeaders(asyncTask.getOriginHeaders());
}

public boolean currentUserHasAccessToTaskWithHeaders(Map<String, String> headers) throws IOException {
return securityContext.canIAccessResourcesCreatedWithHeaders(headers);
}

/**
* Checks if the current user can access the async search result of the original user.
*/
void ensureAuthenticatedUserCanDeleteFromIndex(AsyncExecutionId executionId, ActionListener<Void> listener) {
getTaskHeadersFromIndex(executionId, listener.map(headers -> {
if (currentUserHasAccessToTaskWithHeaders(headers)) {
return null;
} else {
throw new ResourceNotFoundException(executionId.getEncoded());
}
}));
}

private void getTaskHeadersFromIndex(AsyncExecutionId executionId, ActionListener<Map<String, String>> listener) {
GetRequest internalGet = new GetRequest(indexName).preference(executionId.getEncoded())
.id(executionId.getDocId())
.fetchSourceContext(FETCH_HEADERS_FIELD_CONTEXT);

clientWithOrigin.get(internalGet, ActionListener.wrap(get -> {
if (get.isExists() == false) {
listener.onFailure(new ResourceNotFoundException(executionId.getEncoded()));
return;
}
// Check authentication for the user
@SuppressWarnings("unchecked")
Map<String, String> headers = (Map<String, String>) get.getSource().get(AsyncTaskIndexService.HEADERS_FIELD);
listener.onResponse(headers);
}, exc -> listener.onFailure(new ResourceNotFoundException(executionId.getEncoded()))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.BytesReference;
Expand All @@ -49,13 +48,11 @@
import org.elasticsearch.index.engine.DocumentMissingException;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.XPackPlugin;
Expand Down Expand Up @@ -154,9 +151,10 @@ public static SystemIndexDescriptor getSystemIndexDescriptor() {
}

private final String index;
private final ThreadContext threadContext;
private final Client client;
final AsyncSearchSecurity security;
private final Client clientWithOrigin;
private final SecurityContext securityContext;
private final NamedWriteableRegistry registry;
private final Writeable.Reader<R> reader;
private final BigArrays bigArrays;
Expand All @@ -175,8 +173,14 @@ public AsyncTaskIndexService(
BigArrays bigArrays
) {
this.index = index;
this.securityContext = new SecurityContext(clusterService.getSettings(), threadContext);
this.threadContext = threadContext;
this.client = client;
this.security = new AsyncSearchSecurity(
index,
new SecurityContext(clusterService.getSettings(), client.threadPool().getThreadContext()),
client,
origin
);
this.clientWithOrigin = new OriginSettingClient(client, origin);
this.registry = registry;
this.reader = reader;
Expand All @@ -202,11 +206,8 @@ public Client getClient() {
return client;
}

/**
* Returns the authentication information, or null if the current context has no authentication info.
**/
public SecurityContext getSecurityContext() {
return securityContext;
public AsyncSearchSecurity getSecurity() {
return security;
}

/**
Expand Down Expand Up @@ -257,8 +258,7 @@ private void indexResponse(
try {
var buffer = allocateBuffer(limitToMaxResponseSize);
listener = ActionListener.runBefore(listener, buffer::close);
final XContentBuilder source = XContentFactory.jsonBuilder(buffer)
.startObject()
final XContentBuilder source = jsonBuilder(buffer).startObject()
.field(HEADERS_FIELD, headers)
.field(EXPIRATION_TIME_FIELD, response.getExpirationTime());
if (responseHeaders != null) {
Expand All @@ -285,7 +285,7 @@ private void updateResponse(
ReleasableBytesStreamOutput buffer = null;
try {
buffer = allocateBuffer(isFailure == false);
final XContentBuilder source = XContentFactory.jsonBuilder(buffer).startObject().field(RESPONSE_HEADERS_FIELD, responseHeaders);
final XContentBuilder source = jsonBuilder(buffer).startObject().field(RESPONSE_HEADERS_FIELD, responseHeaders);
addResultFieldAndFinish(response, source);
clientWithOrigin.update(
new UpdateRequest().index(index).id(docId).doc(buffer.bytes(), source.contentType()).retryOnConflict(5),
Expand Down Expand Up @@ -399,7 +399,7 @@ public <T extends AsyncTask> T getTaskAndCheckAuthentication(
return null;
}
// Check authentication for the user
if (false == securityContext.canIAccessResourcesCreatedWithHeaders(asyncTask.getOriginHeaders())) {
if (false == security.currentUserHasAccessToTask(asyncTask)) {
throw new ResourceNotFoundException(asyncExecutionId.getEncoded() + " not found");
}
return asyncTask;
Expand Down Expand Up @@ -472,7 +472,7 @@ private R parseResponseFromIndex(
@SuppressWarnings("unchecked")
final Map<String, String> headers = (Map<String, String>) XContentParserUtils.parseFieldsValue(parser);
// check the authentication of the current user against the user that initiated the async task
if (checkAuthentication && false == securityContext.canIAccessResourcesCreatedWithHeaders(headers)) {
if (checkAuthentication && false == security.currentUserHasAccessToTaskWithHeaders(headers)) {
throw new ResourceNotFoundException(asyncExecutionId.getEncoded());
}
}
Expand All @@ -482,7 +482,7 @@ private R parseResponseFromIndex(
parser
);
if (restoreResponseHeaders) {
restoreResponseHeadersContext(securityContext.getThreadContext(), responseHeaders);
restoreResponseHeadersContext(threadContext, responseHeaders);
}
}
default -> XContentParserUtils.parseFieldsValue(parser); // consume and discard unknown fields
Expand Down Expand Up @@ -540,36 +540,6 @@ public <T extends AsyncTask, SR extends SearchStatusResponse> void retrieveStatu
}
}

private static final FetchSourceContext FETCH_HEADERS_FIELD_CONTEXT = FetchSourceContext.of(
true,
new String[] { HEADERS_FIELD },
Strings.EMPTY_ARRAY
);

/**
* Checks if the current user can access the async search result of the original user.
**/
void ensureAuthenticatedUserCanDeleteFromIndex(AsyncExecutionId executionId, ActionListener<Void> listener) {
GetRequest internalGet = new GetRequest(index).preference(executionId.getEncoded())
.id(executionId.getDocId())
.fetchSourceContext(FETCH_HEADERS_FIELD_CONTEXT);

clientWithOrigin.get(internalGet, ActionListener.wrap(get -> {
if (get.isExists() == false) {
listener.onFailure(new ResourceNotFoundException(executionId.getEncoded()));
return;
}
// Check authentication for the user
@SuppressWarnings("unchecked")
Map<String, String> headers = (Map<String, String>) get.getSource().get(HEADERS_FIELD);
if (securityContext.canIAccessResourcesCreatedWithHeaders(headers)) {
listener.onResponse(null);
} else {
listener.onFailure(new ResourceNotFoundException(executionId.getEncoded()));
}
}, exc -> listener.onFailure(new ResourceNotFoundException(executionId.getEncoded()))));
}

/**
* Decode the provided base-64 bytes into a {@link AsyncSearchResponse}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.privilege.ClusterPrivilegeResolver;

import java.util.function.Consumer;

Expand All @@ -29,8 +24,10 @@
*/
public class DeleteAsyncResultsService {
private static final Logger logger = LogManager.getLogger(DeleteAsyncResultsService.class);
private final TaskManager taskManager;

private final AsyncTaskIndexService<? extends AsyncResponse<?>> store;
private final AsyncSearchSecurity security;
private final TaskManager taskManager;

/**
* Creates async results service
Expand All @@ -39,8 +36,9 @@ public class DeleteAsyncResultsService {
* @param taskManager task manager
*/
public DeleteAsyncResultsService(AsyncTaskIndexService<? extends AsyncResponse<?>> store, TaskManager taskManager) {
this.taskManager = taskManager;
this.store = store;
this.security = store.getSecurity();
this.taskManager = taskManager;
}

public void deleteResponse(DeleteAsyncResultRequest request, ActionListener<AcknowledgedResponse> listener) {
Expand All @@ -52,26 +50,7 @@ public void deleteResponse(DeleteAsyncResultRequest request, ActionListener<Ackn
* delete async search submitted by another user.
*/
private void hasCancelTaskPrivilegeAsync(Consumer<Boolean> consumer) {
final Authentication current = store.getSecurityContext().getAuthentication();
if (current != null) {
HasPrivilegesRequest req = new HasPrivilegesRequest();
req.username(current.getEffectiveSubject().getUser().principal());
req.clusterPrivileges(ClusterPrivilegeResolver.CANCEL_TASK.name());
req.indexPrivileges(new RoleDescriptor.IndicesPrivileges[] {});
req.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[] {});
try {
store.getClient()
.execute(
HasPrivilegesAction.INSTANCE,
req,
ActionListener.wrap(resp -> consumer.accept(resp.isCompleteMatch()), exc -> consumer.accept(false))
);
} catch (Exception exc) {
consumer.accept(false);
}
} else {
consumer.accept(false);
}
security.currentUserHasCancelTaskPrivilege(consumer);
}

private void deleteResponseAsync(
Expand All @@ -91,7 +70,7 @@ private void deleteResponseAsync(
if (hasCancelTaskPrivilege) {
deleteResponseFromIndex(searchId, false, listener);
} else {
store.ensureAuthenticatedUserCanDeleteFromIndex(
store.security.ensureAuthenticatedUserCanDeleteFromIndex(
searchId,
listener.delegateFailureAndWrap((l, res) -> deleteResponseFromIndex(searchId, false, l))
);
Expand Down
Loading

0 comments on commit 0d724d8

Please sign in to comment.