Skip to content

Commit

Permalink
Fix several structural and security issues with the OIDC implementation
Browse files Browse the repository at this point in the history
This includes:
* Properly validating ID and access tokens
* A single point of entry for initiating OAuth2 flows
* Add an OAuth2 initiation endpoint for query clients which allows use of a nonce cookie in token exchange.
* Breaking a hard dependency between OAuth2TokenExchange and OAuth2Service
* Use a hashed UUID/UUID pair in token exchange to prevent a comprimised browser from obtaining the access token
* Support trusting additional audiences in access tokens
  • Loading branch information
Nik Hodgkinson authored and dain committed Aug 19, 2021
1 parent 32354fe commit 4e9d3ab
Show file tree
Hide file tree
Showing 20 changed files with 475 additions and 325 deletions.
Expand Up @@ -26,6 +26,7 @@ public class OAuth2AuthenticationSupportModule
protected void setup(Binder binder)
{
binder.bind(OAuth2TokenExchange.class).in(Scopes.SINGLETON);
binder.bind(OAuth2TokenHandler.class).to(OAuth2TokenExchange.class).in(Scopes.SINGLETON);
jaxrsBinder(binder).bind(OAuth2TokenExchangeResource.class);
install(new OAuth2ServiceModule());
}
Expand Down
Expand Up @@ -26,7 +26,7 @@
import java.util.UUID;

import static io.trino.server.security.UserMapping.createUserMapping;
import static io.trino.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static io.trino.server.security.oauth2.OAuth2TokenExchangeResource.getInitiateUri;
import static io.trino.server.security.oauth2.OAuth2TokenExchangeResource.getTokenUri;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand All @@ -48,18 +48,23 @@ public OAuth2Authenticator(OAuth2Service service, OAuth2Config config)
@Override
protected Optional<Principal> extractPrincipalFromToken(String token)
{
return service.convertTokenToClaims(token)
.map(claims -> claims.get(principalField))
.map(String.class::cast)
.map(BasicPrincipal::new);
try {
return service.convertTokenToClaims(token)
.map(claims -> claims.get(principalField))
.map(String.class::cast)
.map(BasicPrincipal::new);
}
catch (ChallengeFailedException e) {
return Optional.empty();
}
}

@Override
protected AuthenticationException needAuthentication(ContainerRequestContext request, String message)
{
UUID authId = UUID.randomUUID();
URI redirectUri = service.startRestChallenge(request.getUriInfo().getBaseUri().resolve(CALLBACK_ENDPOINT), authId);
URI initiateUri = request.getUriInfo().getBaseUri().resolve(getInitiateUri(authId));
URI tokenUri = request.getUriInfo().getBaseUri().resolve(getTokenUri(authId));
return new AuthenticationException(message, format("Bearer x_redirect_server=\"%s\", x_token_server=\"%s\"", redirectUri, tokenUri));
return new AuthenticationException(message, format("Bearer x_redirect_server=\"%s\", x_token_server=\"%s\"", initiateUri, tokenUri));
}
}
Expand Up @@ -15,9 +15,6 @@

import io.airlift.log.Logger;
import io.trino.server.security.ResourceSecurity;
import io.trino.server.security.oauth2.OAuth2Service.OAuthResult;
import io.trino.server.ui.OAuth2WebUiInstalled;
import io.trino.server.ui.OAuthWebUiCookie;

import javax.inject.Inject;
import javax.ws.rs.CookieParam;
Expand All @@ -28,20 +25,14 @@
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Cookie;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.ResponseBuilder;
import javax.ws.rs.core.UriInfo;

import java.net.URI;
import java.util.Optional;
import java.util.UUID;

import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC;
import static io.trino.server.security.oauth2.NonceCookie.NONCE_COOKIE;
import static io.trino.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOCATION;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static javax.ws.rs.core.MediaType.TEXT_HTML;
import static javax.ws.rs.core.Response.Status.BAD_REQUEST;

@Path(CALLBACK_ENDPOINT)
public class OAuth2CallbackResource
Expand All @@ -51,15 +42,11 @@ public class OAuth2CallbackResource
public static final String CALLBACK_ENDPOINT = "/oauth2/callback";

private final OAuth2Service service;
private final Optional<OAuth2TokenExchange> tokenExchange;
private final boolean webUiOAuthEnabled;

@Inject
public OAuth2CallbackResource(OAuth2Service service, Optional<OAuth2TokenExchange> tokenExchange, Optional<OAuth2WebUiInstalled> webUiOAuthEnabled)
public OAuth2CallbackResource(OAuth2Service service)
{
this.service = requireNonNull(service, "service is null");
this.tokenExchange = requireNonNull(tokenExchange, "tokenExchange is null");
this.webUiOAuthEnabled = requireNonNull(webUiOAuthEnabled, "webUiOAuthEnabled is null").isPresent();
}

@ResourceSecurity(PUBLIC)
Expand All @@ -74,70 +61,21 @@ public Response callback(
@CookieParam(NONCE_COOKIE) Cookie nonce,
@Context UriInfo uriInfo)
{
Optional<UUID> authId;
try {
authId = service.getAuthId(state);
}
catch (ChallengeFailedException e) {
LOG.debug(e, "Authentication response could not be verified: state=%s", state);
return Response.ok()
.entity(service.getInternalFailureHtml("Authentication response could not be verified"))
.build();
}

// Note: the Web UI may be disabled, so REST requests can not redirect to a success or error page inside of the Web UI

if (error != null) {
LOG.debug("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state);

if (tokenExchange.isPresent() && authId.isPresent()) {
tokenExchange.get().setTokenExchangeError(
authId.get(),
format("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state));
}
return Response.ok()
.entity(service.getCallbackErrorHtml(error))
.build();
return service.handleOAuth2Error(state, error, errorDescription, errorUri);
}

OAuthResult result;
try {
result = service.finishChallenge(
authId,
code,
uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT),
NonceCookie.read(nonce));
requireNonNull(state, "state is null");
requireNonNull(code, "code is null");
return service.finishOAuth2Challenge(state, code, uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT), NonceCookie.read(nonce));
}
catch (ChallengeFailedException | RuntimeException e) {
catch (RuntimeException e) {
LOG.debug(e, "Authentication response could not be verified: state=%s", state);
if (tokenExchange.isPresent() && authId.isPresent()) {
tokenExchange.get().setTokenExchangeError(authId.get(), format("Authentication response could not be verified: state=%s", state));
}
return Response.ok()
return Response.status(BAD_REQUEST)
.cookie(NonceCookie.delete())
.entity(service.getInternalFailureHtml("Authentication response could not be verified"))
.build();
}

if (authId.isEmpty()) {
return Response
.seeOther(URI.create(UI_LOCATION))
.cookie(OAuthWebUiCookie.create(result.getAccessToken(), result.getTokenExpiration()), NonceCookie.delete())
.build();
}

if (tokenExchange.isEmpty()) {
LOG.debug("Token exchange is not active: state=%s", state);
return Response.ok()
.entity(service.getInternalFailureHtml("Client token exchange is not enabled"))
.build();
}

tokenExchange.get().setAccessToken(authId.get(), result.getAccessToken());

ResponseBuilder builder = Response.ok(service.getSuccessHtml());
if (webUiOAuthEnabled) {
builder.cookie(OAuthWebUiCookie.create(result.getAccessToken(), result.getTokenExpiration()));
}
return builder.build();
}
}
Expand Up @@ -14,17 +14,21 @@
package io.trino.server.security.oauth2;

import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.configuration.Config;
import io.airlift.configuration.ConfigDescription;
import io.airlift.configuration.ConfigSecuritySensitive;
import io.airlift.configuration.LegacyConfig;
import io.airlift.configuration.validation.FileExists;
import io.airlift.units.Duration;
import io.airlift.units.MinDuration;

import javax.validation.constraints.NotNull;

import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
Expand All @@ -35,14 +39,16 @@
public class OAuth2Config
{
private Optional<String> stateKey = Optional.empty();
private String issuer;
private Optional<String> accessTokenIssuer = Optional.empty();
private String authUrl;
private String tokenUrl;
private String jwksUrl;
private String clientId;
private String clientSecret;
private Optional<String> audience = Optional.empty();
private Set<String> scopes = ImmutableSet.of(OPENID_SCOPE);
private String principalField = "sub";
private List<String> additionalAudiences = Collections.emptyList();
private Duration challengeTimeout = new Duration(15, TimeUnit.MINUTES);
private Optional<String> userMappingPattern = Optional.empty();
private Optional<File> userMappingFile = Optional.empty();
Expand All @@ -60,6 +66,34 @@ public OAuth2Config setStateKey(String stateKey)
return this;
}

@NotNull
public String getIssuer()
{
return issuer;
}

@Config("http-server.authentication.oauth2.issuer")
@ConfigDescription("The required issuer of a token")
public OAuth2Config setIssuer(String issuer)
{
this.issuer = issuer;
return this;
}

@NotNull
public Optional<String> getAccessTokenIssuer()
{
return accessTokenIssuer;
}

@Config("http-server.authentication.oauth2.access-token-issuer")
@ConfigDescription("The required issuer for access tokens")
public OAuth2Config setAccessTokenIssuer(String accessTokenIssuer)
{
this.accessTokenIssuer = Optional.ofNullable(accessTokenIssuer);
return this;
}

@NotNull
public String getAuthUrl()
{
Expand Down Expand Up @@ -131,16 +165,18 @@ public OAuth2Config setClientSecret(String clientSecret)
return this;
}

public Optional<String> getAudience()
@NotNull
public List<String> getAdditionalAudiences()
{
return audience;
return additionalAudiences;
}

@Config("http-server.authentication.oauth2.audience")
@ConfigDescription("The required audience of a token")
public OAuth2Config setAudience(String audience)
@LegacyConfig("http-server.authentication.oauth2.audience")
@Config("http-server.authentication.oauth2.additional-audiences")
@ConfigDescription("Additional audiences to trust in addition to the Client ID")
public OAuth2Config setAdditionalAudiences(List<String> additionalAudiences)
{
this.audience = Optional.ofNullable(audience);
this.additionalAudiences = ImmutableList.copyOf(additionalAudiences);
return this;
}

Expand All @@ -165,6 +201,7 @@ public String getPrincipalField()
}

@Config("http-server.authentication.oauth2.principal-field")
@ConfigDescription("The claim to use as the principal")
public OAuth2Config setPrincipalField(String principalField)
{
this.principalField = principalField;
Expand Down

0 comments on commit 4e9d3ab

Please sign in to comment.