Skip to content
This repository was archived by the owner on May 31, 2022. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,15 @@

package org.springframework.security.oauth2.client.filter;

import java.io.IOException;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2RestOperations;
import org.springframework.security.oauth2.client.http.AccessTokenRequiredException;
import org.springframework.security.oauth2.client.token.ClientTokenServices;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
Expand All @@ -40,6 +35,12 @@
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.util.Assert;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
* An OAuth2 client filter that can be used to acquire an OAuth2 access token from an authorization server, and load an
* authentication object into the SecurityContext
Expand All @@ -53,9 +54,20 @@ public class OAuth2ClientAuthenticationProcessingFilter extends AbstractAuthenti

private ResourceServerTokenServices tokenServices;

private ClientTokenServices clientTokenServices;

private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new OAuth2AuthenticationDetailsSource();

/**
/**
* Reference to a ClientTokenServices that can save an OAuth2AccessToken after successful authentication.
*
* @param clientTokenServices
*/
public void setClientTokenServices(ClientTokenServices clientTokenServices) {
this.clientTokenServices = clientTokenServices;
}

/**
* Reference to a CheckTokenServices that can validate an OAuth2AccessToken
*
* @param tokenServices
Expand Down Expand Up @@ -114,8 +126,10 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response,
FilterChain chain, Authentication authResult) throws IOException, ServletException {
super.successfulAuthentication(request, response, chain, authResult);
// Nearly a no-op, but if there is a ClientTokenServices then the token will now be stored
restTemplate.getAccessToken();

if (clientTokenServices != null) {
clientTokenServices.saveAccessToken(restTemplate.getResource(), SecurityContextHolder.getContext().getAuthentication(), restTemplate.getAccessToken());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@
*/
package org.springframework.security.oauth2.client.filter;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import javax.servlet.ServletException;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand All @@ -33,8 +23,10 @@
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2RestOperations;
import org.springframework.security.oauth2.client.http.AccessTokenRequiredException;
import org.springframework.security.oauth2.client.token.ClientTokenServices;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
Expand All @@ -43,6 +35,15 @@
import org.springframework.security.oauth2.provider.authentication.OAuth2AuthenticationDetails;
import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices;

import javax.servlet.ServletException;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

public class OAuth2ClientAuthenticationProcessingFilterTests {

private OAuth2ClientAuthenticationProcessingFilter filter = new OAuth2ClientAuthenticationProcessingFilter(
Expand All @@ -53,6 +54,8 @@ public class OAuth2ClientAuthenticationProcessingFilterTests {
private OAuth2RestOperations restTemplate = Mockito.mock(OAuth2RestOperations.class);

private OAuth2Authentication authentication;

private ClientTokenServices clientTokenServices = Mockito.mock(ClientTokenServices.class);

@Rule
public ExpectedException expected = ExpectedException.none();
Expand Down Expand Up @@ -96,9 +99,22 @@ public void testSuccessfulAuthentication() throws Exception {
OAuth2Request storedOAuth2Request = RequestTokenFactory.createOAuth2Request("client", false, scopes);
this.authentication = new OAuth2Authentication(storedOAuth2Request, null);
filter.successfulAuthentication(new MockHttpServletRequest(), new MockHttpServletResponse(), null, authentication);
Mockito.verify(restTemplate, Mockito.times(1)).getAccessToken();
Mockito.verify(clientTokenServices, Mockito.times(0)).saveAccessToken(restTemplate.getResource(), SecurityContextHolder.getContext().getAuthentication(),restTemplate.getAccessToken());
}

@Test
public void testSuccessfulAuthenticationWithClientTokenServices() throws Exception {
filter.setRestTemplate(restTemplate);
filter.setClientTokenServices(clientTokenServices);
Set<String> scopes = new HashSet<String>();
scopes.addAll(Arrays.asList("read", "write"));
OAuth2Request storedOAuth2Request = RequestTokenFactory.createOAuth2Request("client", false, scopes);
this.authentication = new OAuth2Authentication(storedOAuth2Request, null);
filter.successfulAuthentication(new MockHttpServletRequest(), new MockHttpServletResponse(), null, authentication);
Mockito.verify(restTemplate, Mockito.times(1)).getAccessToken();
Mockito.verify(clientTokenServices, Mockito.times(1)).saveAccessToken(restTemplate.getResource(), SecurityContextHolder.getContext().getAuthentication(),restTemplate.getAccessToken());
}

@Test
public void testDeniedToken() throws Exception {
filter.setRestTemplate(restTemplate);
Expand Down