Skip to content

Commit

Permalink
ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId
Browse files Browse the repository at this point in the history
Issue: gh-4921
  • Loading branch information
rwinch committed Sep 7, 2018
1 parent 28537fa commit 158b8aa
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 18 deletions.
Expand Up @@ -27,12 +27,15 @@
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2ClientException;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
Expand Down Expand Up @@ -75,18 +78,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
*/
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",

private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
AuthorityUtils.createAuthorityList("ROLE_USER"));

private Clock clock = Clock.systemUTC();

private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);

private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
new WebClientReactiveClientCredentialsTokenResponseClient();

private ReactiveClientRegistrationRepository clientRegistrationRepository;

private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;

public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}

public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
}

Expand Down Expand Up @@ -164,6 +174,17 @@ public static Consumer<Map<String, Object>> clientRegistrationId(String clientRe
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
}

/**
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
* client_credentials grant.
* @param clientCredentialsTokenResponseClient the client to use
*/
public void setClientCredentialsTokenResponseClient(
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
}

/**
* An access token will be considered expired by comparing its expiration to now +
* this skewed Duration. The default is 1 minute.
Expand Down Expand Up @@ -208,7 +229,39 @@ private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(Client
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
ServerWebExchange exchange, Authentication principal) {
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
}

private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}

private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
}

private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
return currentAuthentication()
.flatMap(principal -> {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
clientRegistration, (principal != null ?
principal.getName() :
"anonymousUser"),
tokenResponse.getAccessToken());

return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
.thenReturn(authorizedClient);
});
}

private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
Expand Down
Expand Up @@ -37,6 +37,7 @@
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
Expand Down Expand Up @@ -71,7 +72,10 @@
@RunWith(MockitoJUnitRunner.class)
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;

@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;

private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();

Expand Down Expand Up @@ -125,7 +129,7 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {

@Test
public void filterWhenRefreshRequiredThenRefresh() {
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
Expand All @@ -140,7 +144,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);

OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
Expand All @@ -154,7 +158,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();

verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
Expand All @@ -174,7 +178,7 @@ public void filterWhenRefreshRequiredThenRefresh() {

@Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
Expand All @@ -189,7 +193,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);

OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
Expand All @@ -201,7 +205,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
this.function.filter(request, this.exchange)
.block();

verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
Expand All @@ -221,7 +225,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()

@Test
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
Expand All @@ -243,7 +247,7 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {

@Test
public void filterWhenNotExpiredThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);

OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
Expand All @@ -266,12 +270,13 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() {

@Test
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);

OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken);
when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
.build();
Expand Down
Expand Up @@ -18,7 +18,9 @@

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.web.reactive.function.client.WebClient;

/**
Expand All @@ -29,9 +31,10 @@
public class WebClientConfig {

@Bean
WebClient webClient() {
WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
return WebClient.builder()
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction())
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository))
.build();
}
}

0 comments on commit 158b8aa

Please sign in to comment.