From 36c3f4442476eadf597562ae51c0706314caf21c Mon Sep 17 00:00:00 2001 From: Kees Koffeman Date: Sat, 6 Jun 2015 11:33:42 +0200 Subject: [PATCH] Save token after successful authentication when an ClientTokenServices is provided. Fixes gh-498 --- ...2ClientAuthenticationProcessingFilter.java | 34 ++++++++++++----- ...ntAuthenticationProcessingFilterTests.java | 38 +++++++++++++------ 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilter.java b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilter.java index 6fd1a6e27..586ce54f5 100644 --- a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilter.java +++ b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilter.java @@ -16,20 +16,15 @@ package org.springframework.security.oauth2.client.filter; -import java.io.IOException; - -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2RestOperations; import org.springframework.security.oauth2.client.http.AccessTokenRequiredException; +import org.springframework.security.oauth2.client.token.ClientTokenServices; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.exceptions.InvalidTokenException; import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; @@ -40,6 +35,12 @@ import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.util.Assert; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; + /** * An OAuth2 client filter that can be used to acquire an OAuth2 access token from an authorization server, and load an * authentication object into the SecurityContext @@ -53,9 +54,20 @@ public class OAuth2ClientAuthenticationProcessingFilter extends AbstractAuthenti private ResourceServerTokenServices tokenServices; + private ClientTokenServices clientTokenServices; + private AuthenticationDetailsSource authenticationDetailsSource = new OAuth2AuthenticationDetailsSource(); - /** + /** + * Reference to a ClientTokenServices that can save an OAuth2AccessToken after successful authentication. + * + * @param clientTokenServices + */ + public void setClientTokenServices(ClientTokenServices clientTokenServices) { + this.clientTokenServices = clientTokenServices; + } + + /** * Reference to a CheckTokenServices that can validate an OAuth2AccessToken * * @param tokenServices @@ -114,8 +126,10 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authResult) throws IOException, ServletException { super.successfulAuthentication(request, response, chain, authResult); - // Nearly a no-op, but if there is a ClientTokenServices then the token will now be stored - restTemplate.getAccessToken(); + + if (clientTokenServices != null) { + clientTokenServices.saveAccessToken(restTemplate.getResource(), SecurityContextHolder.getContext().getAuthentication(), restTemplate.getAccessToken()); + } } @Override diff --git a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilterTests.java b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilterTests.java index 1b84f709d..315df32a9 100644 --- a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilterTests.java +++ b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientAuthenticationProcessingFilterTests.java @@ -15,16 +15,6 @@ */ package org.springframework.security.oauth2.client.filter; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - -import java.io.IOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; - -import javax.servlet.ServletException; - import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -33,8 +23,10 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2RestOperations; import org.springframework.security.oauth2.client.http.AccessTokenRequiredException; +import org.springframework.security.oauth2.client.token.ClientTokenServices; import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken; import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; import org.springframework.security.oauth2.provider.OAuth2Authentication; @@ -43,6 +35,15 @@ import org.springframework.security.oauth2.provider.authentication.OAuth2AuthenticationDetails; import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices; +import javax.servlet.ServletException; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + public class OAuth2ClientAuthenticationProcessingFilterTests { private OAuth2ClientAuthenticationProcessingFilter filter = new OAuth2ClientAuthenticationProcessingFilter( @@ -53,6 +54,8 @@ public class OAuth2ClientAuthenticationProcessingFilterTests { private OAuth2RestOperations restTemplate = Mockito.mock(OAuth2RestOperations.class); private OAuth2Authentication authentication; + + private ClientTokenServices clientTokenServices = Mockito.mock(ClientTokenServices.class); @Rule public ExpectedException expected = ExpectedException.none(); @@ -96,9 +99,22 @@ public void testSuccessfulAuthentication() throws Exception { OAuth2Request storedOAuth2Request = RequestTokenFactory.createOAuth2Request("client", false, scopes); this.authentication = new OAuth2Authentication(storedOAuth2Request, null); filter.successfulAuthentication(new MockHttpServletRequest(), new MockHttpServletResponse(), null, authentication); - Mockito.verify(restTemplate, Mockito.times(1)).getAccessToken(); + Mockito.verify(clientTokenServices, Mockito.times(0)).saveAccessToken(restTemplate.getResource(), SecurityContextHolder.getContext().getAuthentication(),restTemplate.getAccessToken()); } + @Test + public void testSuccessfulAuthenticationWithClientTokenServices() throws Exception { + filter.setRestTemplate(restTemplate); + filter.setClientTokenServices(clientTokenServices); + Set scopes = new HashSet(); + scopes.addAll(Arrays.asList("read", "write")); + OAuth2Request storedOAuth2Request = RequestTokenFactory.createOAuth2Request("client", false, scopes); + this.authentication = new OAuth2Authentication(storedOAuth2Request, null); + filter.successfulAuthentication(new MockHttpServletRequest(), new MockHttpServletResponse(), null, authentication); + Mockito.verify(restTemplate, Mockito.times(1)).getAccessToken(); + Mockito.verify(clientTokenServices, Mockito.times(1)).saveAccessToken(restTemplate.getResource(), SecurityContextHolder.getContext().getAuthentication(),restTemplate.getAccessToken()); + } + @Test public void testDeniedToken() throws Exception { filter.setRestTemplate(restTemplate);