Skip to content
Permalink
3 contributors

Users who have contributed to this file

@dsyer @kazuki43zoo @jgrandja
252 lines (207 sloc) 9.02 KB
package org.springframework.security.oauth2.client.token;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.token.auth.ClientAuthenticationHandler;
import org.springframework.security.oauth2.client.token.auth.DefaultClientAuthenticationHandler;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.http.converter.FormOAuth2AccessTokenMessageConverter;
import org.springframework.security.oauth2.http.converter.FormOAuth2ExceptionHttpMessageConverter;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.HttpMessageConverterExtractor;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Base support logic for obtaining access tokens.
 * <p>
 * Subclasses build the token request form and headers and call
 * {@link #retrieveToken(AccessTokenRequest, OAuth2ProtectedResourceDetails, MultiValueMap, HttpHeaders)}.
 * This class then:
 * <ul>
 * <li>applies client authentication and the configured {@link RequestEnhancer};</li>
 * <li>sends the request through a lazily created {@link RestOperations};</li>
 * <li>converts the response into an {@link OAuth2AccessToken}, or into an
 * {@link OAuth2AccessDeniedException} on failure.</li>
 * </ul>
 * <p>
 * The interceptors and request factory are copied into the {@link RestTemplate} only when it is first
 * created. Replacing them afterwards via {@link #setInterceptors(List)} or
 * {@link #setRequestFactory(ClientHttpRequestFactory)} does not affect the already-created template.
 *
 * @author Ryan Heaton
 * @author Dave Syer
 */
public abstract class OAuth2AccessTokenSupport {

	protected final Log logger = LogFactory.getLog(getClass());

	private static final FormHttpMessageConverter FORM_MESSAGE_CONVERTER = new FormHttpMessageConverter();

	/** Token-request form parameters whose values are masked when the outgoing form is logged. */
	private static final List<String> SENSITIVE_FORM_KEYS = Arrays.asList("client_secret", "password",
			"refresh_token", "code");

	/**
	 * Created lazily by {@link #getRestTemplate()}.
	 * It is volatile so that the double-checked initialization publishes it safely.
	 */
	private volatile RestOperations restTemplate;

	private List<HttpMessageConverter<?>> messageConverters;

	private ClientAuthenticationHandler authenticationHandler = new DefaultClientAuthenticationHandler();

	private ResponseErrorHandler responseErrorHandler = new AccessTokenErrorHandler();

	private List<ClientHttpRequestInterceptor> interceptors = new ArrayList<ClientHttpRequestInterceptor>();

	private RequestEnhancer tokenRequestEnhancer = new DefaultRequestEnhancer();

	private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory() {
		@Override
		protected void prepareConnection(HttpURLConnection connection, String httpMethod) throws IOException {
			super.prepareConnection(connection, httpMethod);
			// Token requests neither follow redirects nor use cached responses.
			connection.setInstanceFollowRedirects(false);
			connection.setUseCaches(false);
		}
	};

	/**
	 * Sets the request interceptors that this accessor should use.
	 * <p>
	 * They are applied only when the rest template is first created (see the class documentation).
	 *
	 * @param interceptors the interceptors to install on the rest template
	 */
	public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
		this.interceptors = interceptors;
	}

	/**
	 * Sets a custom enhancer for the access token request.
	 * <p>
	 * It is invoked with the form and headers after client authentication has been applied.
	 *
	 * @param tokenRequestEnhancer the enhancer to use
	 */
	public void setTokenRequestEnhancer(RequestEnhancer tokenRequestEnhancer) {
		this.tokenRequestEnhancer = tokenRequestEnhancer;
	}

	/**
	 * Returns the {@link RestOperations} used for token requests, creating it on first use.
	 * <p>
	 * The template is configured with {@link #getResponseErrorHandler()}, the current request factory and
	 * the current interceptors. If no message converters have been set yet, the default converters of a new
	 * {@link RestTemplate} are installed through {@link #setMessageConverters(List)}.
	 *
	 * @return the shared rest template, never {@code null}
	 */
	protected RestOperations getRestTemplate() {
		RestOperations result = this.restTemplate;
		if (result == null) {
			synchronized (this) {
				result = this.restTemplate;
				if (result == null) {
					RestTemplate template = new RestTemplate();
					template.setErrorHandler(getResponseErrorHandler());
					template.setRequestFactory(requestFactory);
					template.setInterceptors(interceptors);
					if (messageConverters == null) {
						// Initialize the converters before publishing the template through the volatile field,
						// so any thread that sees the template also sees the converters.
						setMessageConverters(new RestTemplate().getMessageConverters());
					}
					this.restTemplate = template;
					result = template;
				}
			}
		}
		return result;
	}

	/**
	 * Sets the handler that adds client authentication to token requests.
	 *
	 * @param authenticationHandler the handler to use
	 */
	public void setAuthenticationHandler(ClientAuthenticationHandler authenticationHandler) {
		this.authenticationHandler = authenticationHandler;
	}

	/**
	 * Sets the message converters used to read token and error responses.
	 * <p>
	 * The list is copied. A form-encoded access token converter and a form-encoded OAuth2 exception converter
	 * are appended to the copy.
	 *
	 * @param messageConverters the base converters (must not be {@code null})
	 */
	public void setMessageConverters(List<HttpMessageConverter<?>> messageConverters) {
		this.messageConverters = new ArrayList<HttpMessageConverter<?>>(messageConverters);
		this.messageConverters.add(new FormOAuth2AccessTokenMessageConverter());
		this.messageConverters.add(new FormOAuth2ExceptionHttpMessageConverter());
	}

	/**
	 * Obtains an access token from the resource's token endpoint.
	 * <p>
	 * Before the request is sent, the client authentication handler and the {@link RequestEnhancer} are
	 * invoked with {@code form} and {@code headers}, and may modify them. If the response carries a
	 * {@code Set-Cookie} header, its first value is stored on {@code request}.
	 *
	 * @param request the current access token request; receives any cookie set by the token endpoint
	 * @param resource the protected resource details, including the token endpoint URI
	 * @param form the token request parameters
	 * @param headers the token request headers
	 * @return the access token extracted from the response
	 * @throws OAuth2AccessDeniedException if the endpoint returns an OAuth2 error or the HTTP exchange fails
	 */
	protected OAuth2AccessToken retrieveToken(final AccessTokenRequest request,
			OAuth2ProtectedResourceDetails resource, MultiValueMap<String, String> form, HttpHeaders headers)
			throws OAuth2AccessDeniedException {
		try {
			// Prepare headers and form before going into rest template call in case the URI is affected by the result
			authenticationHandler.authenticateTokenRequest(resource, form, headers);
			// Opportunity to customize form and headers
			tokenRequestEnhancer.enhance(request, resource, form, headers);
			final ResponseExtractor<OAuth2AccessToken> delegate = getResponseExtractor();
			ResponseExtractor<OAuth2AccessToken> extractor = new ResponseExtractor<OAuth2AccessToken>() {
				@Override
				public OAuth2AccessToken extractData(ClientHttpResponse response) throws IOException {
					// Hand the first Set-Cookie value back to the caller's request object.
					if (response.getHeaders().containsKey("Set-Cookie")) {
						request.setCookie(response.getHeaders().getFirst("Set-Cookie"));
					}
					return delegate.extractData(response);
				}
			};
			// The form doubles as the URI variables for the query template built by getAccessTokenUri (GET).
			return getRestTemplate().execute(getAccessTokenUri(resource, form), getHttpMethod(),
					getRequestCallback(resource, form, headers), extractor, form.toSingleValueMap());
		}
		catch (OAuth2Exception oe) {
			throw new OAuth2AccessDeniedException("Access token denied.", resource, oe);
		}
		catch (RestClientException rce) {
			throw new OAuth2AccessDeniedException("Error requesting access token.", resource, rce);
		}
	}

	/**
	 * Returns the HTTP method used for token requests; {@link HttpMethod#POST} by default.
	 * <p>
	 * If a subclass returns {@link HttpMethod#GET}, {@link #getAccessTokenUri} adds the form keys to the
	 * query string. The default request callback still writes the form as the request body.
	 *
	 * @return the HTTP method for the token request
	 */
	protected HttpMethod getHttpMethod() {
		return HttpMethod.POST;
	}

	/**
	 * Builds the URI (template) for the token request.
	 * <p>
	 * For GET requests, each form key is appended as a query parameter {@code key={key}}. The
	 * {@code {key}} placeholders are URI template variables, filled from the form's first values by
	 * {@link #retrieveToken}.
	 *
	 * @param resource supplies the base token endpoint URI
	 * @param form the token request parameters
	 * @return the token endpoint URI, possibly with a templated query string
	 */
	protected String getAccessTokenUri(OAuth2ProtectedResourceDetails resource, MultiValueMap<String, String> form) {
		String accessTokenUri = resource.getAccessTokenUri();
		if (logger.isDebugEnabled()) {
			logger.debug("Retrieving token from " + accessTokenUri);
		}
		StringBuilder builder = new StringBuilder(accessTokenUri);
		if (getHttpMethod() == HttpMethod.GET) {
			// Append to an existing query string if the base URI already has one.
			String separator = accessTokenUri.contains("?") ? "&" : "?";
			for (String key : form.keySet()) {
				builder.append(separator).append(key).append("={").append(key).append('}');
				separator = "&";
			}
		}
		return builder.toString();
	}

	/**
	 * Returns the error handler installed on the rest template when it is created.
	 *
	 * @return the response error handler
	 */
	protected ResponseErrorHandler getResponseErrorHandler() {
		return responseErrorHandler;
	}

	/**
	 * Set the request factory that this template uses for obtaining {@link ClientHttpRequest HttpRequests}.
	 * <p>
	 * It is applied only when the rest template is first created (see the class documentation).
	 *
	 * @param requestFactory the request factory, must not be {@code null}
	 */
	public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
		Assert.notNull(requestFactory, "'requestFactory' must not be null");
		this.requestFactory = requestFactory;
	}

	/**
	 * Returns an extractor that reads the response body as an {@link OAuth2AccessToken} using the configured
	 * message converters.
	 *
	 * @return the response extractor
	 */
	protected ResponseExtractor<OAuth2AccessToken> getResponseExtractor() {
		getRestTemplate(); // force initialization (also initializes the message converters)
		return new HttpMessageConverterExtractor<OAuth2AccessToken>(OAuth2AccessToken.class, this.messageConverters);
	}

	/**
	 * Returns the callback that writes the headers and form-encoded body of the token request.
	 *
	 * @param resource the protected resource details (unused by the default implementation)
	 * @param form the token request parameters
	 * @param headers the token request headers
	 * @return the request callback
	 */
	protected RequestCallback getRequestCallback(OAuth2ProtectedResourceDetails resource,
			MultiValueMap<String, String> form, HttpHeaders headers) {
		return new OAuth2AuthTokenCallback(form, headers);
	}

	/**
	 * Renders the form for debug logging, in the same {@code {key=[values], ...}} shape as the map's
	 * {@code toString()}, with the values of {@link #SENSITIVE_FORM_KEYS} replaced by a placeholder.
	 */
	private static String describeForm(MultiValueMap<String, String> form) {
		if (form == null) {
			return "null";
		}
		StringBuilder builder = new StringBuilder("{");
		String separator = "";
		for (String key : form.keySet()) {
			builder.append(separator).append(key).append('=');
			builder.append(SENSITIVE_FORM_KEYS.contains(key) ? "[PROTECTED]" : String.valueOf(form.get(key)));
			separator = ", ";
		}
		return builder.append('}').toString();
	}

	/**
	 * Request callback implementation that writes the given object to the request stream.
	 * <p>
	 * It copies the supplied headers, sets the Accept header to JSON and form-urlencoded, and writes the
	 * form as {@code application/x-www-form-urlencoded}.
	 */
	private class OAuth2AuthTokenCallback implements RequestCallback {

		private final MultiValueMap<String, String> form;

		private final HttpHeaders headers;

		private OAuth2AuthTokenCallback(MultiValueMap<String, String> form, HttpHeaders headers) {
			this.form = form;
			this.headers = headers;
		}

		@Override
		public void doWithRequest(ClientHttpRequest request) throws IOException {
			request.getHeaders().putAll(this.headers);
			request.getHeaders().setAccept(
					Arrays.asList(MediaType.APPLICATION_JSON, MediaType.APPLICATION_FORM_URLENCODED));
			if (logger.isDebugEnabled()) {
				// Never log credentials (client secret, password, refresh token, authorization code).
				logger.debug("Encoding and sending form: " + describeForm(form));
			}
			FORM_MESSAGE_CONVERTER.write(this.form, MediaType.APPLICATION_FORM_URLENCODED, request);
		}
	}

	/**
	 * Error handler that converts an error response into an {@link OAuth2Exception} when one of the
	 * configured message converters can read it. Otherwise it falls back to the default Spring handling.
	 */
	private class AccessTokenErrorHandler extends DefaultResponseErrorHandler {

		@Override
		public void handleError(ClientHttpResponse response) throws IOException {
			MediaType contentType = response.getHeaders().getContentType();
			for (HttpMessageConverter<?> converter : messageConverters) {
				if (converter.canRead(OAuth2Exception.class, contentType)) {
					@SuppressWarnings("unchecked") // guarded by canRead(OAuth2Exception.class, ...) above
					HttpMessageConverter<OAuth2Exception> exceptionConverter = (HttpMessageConverter<OAuth2Exception>) converter;
					OAuth2Exception ex;
					try {
						ex = exceptionConverter.read(OAuth2Exception.class, response);
					}
					catch (Exception ignored) {
						// Best effort: try the next converter. NOTE(review): the response body may already
						// have been (partly) consumed by the failed read, so later converters and
						// super.handleError may see an incomplete body — confirm whether buffering is needed.
						continue;
					}
					throw ex;
				}
			}
			super.handleError(response);
		}
	}
}
You can’t perform that action at this time.