Skip to content

Commit

Permalink
Improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
vpavic committed Dec 19, 2019
1 parent c25101c commit 6ee2da3
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 13 deletions.
Expand Up @@ -44,7 +44,7 @@ public CompletionStage<Void> handle(HttpExchange httpExchange) {
BearerToken bearerToken = this.bearerTokenExtractor.apply(httpExchange);
if (bearerToken == null) {
CompletableFuture<Void> result = new CompletableFuture<>();
result.completeExceptionally(new BearerTokenException(BearerTokenError.INVALID_REQUEST));
result.completeExceptionally(new BearerTokenException());
return result;
}
return this.authorizationContextResolver.apply(bearerToken).handle((authorizationContext, throwable) -> {
Expand Down
Expand Up @@ -4,28 +4,32 @@

public class BearerTokenError {

public static final BearerTokenError INVALID_REQUEST = new BearerTokenError(400, "invalid_request");
public static final BearerTokenError INVALID_REQUEST = BearerTokenError.of("invalid_request", 400);

public static final BearerTokenError INVALID_TOKEN = new BearerTokenError(401, "invalid_token");
public static final BearerTokenError INVALID_TOKEN = BearerTokenError.of("invalid_token", 401);

public static final BearerTokenError INSUFFICIENT_SCOPE = new BearerTokenError(403, "insufficient_scope");

private final int status;
public static final BearerTokenError INSUFFICIENT_SCOPE = BearerTokenError.of("insufficient_scope", 403);

private final String code;

public BearerTokenError(int status, String code) {
private final int status;

private BearerTokenError(String code, int status) {
Objects.requireNonNull(code, "code must not be null");
this.status = status;
this.code = code;
this.status = status;
}

public int getStatus() {
return this.status;
public static BearerTokenError of(String code, int status) {
return new BearerTokenError(code, status);
}

public String getCode() {
return this.code;
}

public int getStatus() {
return this.status;
}

}
Expand Up @@ -6,6 +6,15 @@ public class BearerTokenException extends RuntimeException {

private final BearerTokenError error;

public BearerTokenException() {
this.error = null;
}

public BearerTokenException(String message) {
super(message);
this.error = null;
}

public BearerTokenException(BearerTokenError error, String message) {
super(message);
Objects.requireNonNull(error, "error must not be null");
Expand All @@ -16,6 +25,10 @@ public BearerTokenException(BearerTokenError error) {
this(error, error.getCode());
}

public int getStatus() {
return (this.error != null) ? this.error.getStatus() : 401;
}

public BearerTokenError getError() {
return this.error;
}
Expand Down
@@ -0,0 +1,48 @@
package io.github.vpavic.bearerauth;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

public final class WwwAuthenticateBuilder {

private final BearerTokenException bearerTokenException;

private String realm;

private WwwAuthenticateBuilder(BearerTokenException bearerTokenException) {
Objects.requireNonNull(bearerTokenException, "bearerTokenException must not be null");
this.bearerTokenException = bearerTokenException;
}

public static WwwAuthenticateBuilder from(BearerTokenException bearerTokenException) {
return new WwwAuthenticateBuilder(bearerTokenException);
}

public WwwAuthenticateBuilder withRealm(String realm) {
Objects.requireNonNull(realm, "realm must not be null");
this.realm = realm;
return this;
}

public String build() {
String wwwAuthenticate = "Bearer";
List<String> attributes = new ArrayList<>();
if (this.realm != null) {
attributes.add(buildAttribute("realm", this.realm));
}
BearerTokenError error = this.bearerTokenException.getError();
if (error != null) {
attributes.add(buildAttribute("error", error.getCode()));
}
if (!attributes.isEmpty()) {
wwwAuthenticate += " " + String.join(", ", attributes);
}
return wwwAuthenticate;
}

private static String buildAttribute(String name, String value) {
return name + "=\"" + value + "\"";
}

}
@@ -0,0 +1,5 @@
package io.github.vpavic.bearerauth;

class WwwAuthenticateBuilderTests {

}
Expand Up @@ -3,8 +3,10 @@
import io.github.vpavic.bearerauth.AuthorizationContext;
import io.github.vpavic.bearerauth.BearerAuthenticationHandler;
import io.github.vpavic.bearerauth.BearerToken;
import io.github.vpavic.bearerauth.BearerTokenException;
import io.github.vpavic.bearerauth.HttpExchange;
import io.github.vpavic.bearerauth.MapAuthorizationContextResolver;
import io.github.vpavic.bearerauth.WwwAuthenticateBuilder;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
Expand Down Expand Up @@ -39,11 +41,23 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response
throws IOException, ServletException {
try {
this.bearerAuthenticationHandler.handle(new ServletHttpExchange(request)).toCompletableFuture().get();
chain.doFilter(request, response);
}
catch (ExecutionException | InterruptedException ex) {
catch (ExecutionException ex) {
Throwable cause = ex.getCause();
if (cause instanceof BearerTokenException) {
BearerTokenException bearerTokenException = (BearerTokenException) cause;
String wwwAuthenticate = WwwAuthenticateBuilder.from(bearerTokenException).build();
response.addHeader("WWW-Authenticate", wwwAuthenticate);
response.sendError(bearerTokenException.getStatus());
}
else {
throw new ServletException(ex);
}
}
catch (InterruptedException ex) {
throw new ServletException(ex);
}
chain.doFilter(request, response);
}

private static class ServletHttpExchange implements HttpExchange {
Expand Down
Expand Up @@ -3,8 +3,13 @@
import io.github.vpavic.bearerauth.AuthorizationContext;
import io.github.vpavic.bearerauth.BearerAuthenticationHandler;
import io.github.vpavic.bearerauth.BearerToken;
import io.github.vpavic.bearerauth.BearerTokenException;
import io.github.vpavic.bearerauth.HttpExchange;
import io.github.vpavic.bearerauth.MapAuthorizationContextResolver;
import io.github.vpavic.bearerauth.WwwAuthenticateBuilder;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
Expand Down Expand Up @@ -34,7 +39,14 @@ public WebFluxBearerAuthenticationFilter() {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return Mono.fromCompletionStage(this.bearerAuthenticationHandler.handle(new WebFluxHttpExchange(exchange)))
.then(chain.filter(exchange));
.then(chain.filter(exchange))
.onErrorResume(BearerTokenException.class, ex -> {
String wwwAuthenticate = WwwAuthenticateBuilder.from(ex).build();
ServerHttpResponse response = exchange.getResponse();
response.getHeaders().set(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticate);
response.setStatusCode(HttpStatus.resolve(ex.getStatus()));
return Mono.empty();
});
}

private static class WebFluxHttpExchange implements HttpExchange {
Expand Down

0 comments on commit 6ee2da3

Please sign in to comment.