diff --git a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpoint.java b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpoint.java index 0639f893c..57b59a7c6 100644 --- a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpoint.java +++ b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpoint.java @@ -1,15 +1,18 @@ -/******************************************************************************* - * Cloud Foundry - * Copyright (c) [2009-2014] Pivotal Software, Inc. All Rights Reserved. +/* + * Copyright 2009-2019 the original author or authors. * - * This product is licensed to you under the Apache License, Version 2.0 (the "License"). - * You may not use this product except in compliance with the License. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * This product includes a number of subcomponents with - * separate copyright notices and license terms. Your use of these - * subcomponents is subject to the terms and conditions of the - * subcomponent's license, as noted in the LICENSE file. - *******************************************************************************/ + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.security.oauth2.provider.endpoint; import org.apache.commons.logging.Log; @@ -42,7 +45,7 @@ public class CheckTokenEndpoint { private ResourceServerTokenServices resourceServerTokenServices; - private AccessTokenConverter accessTokenConverter = new DefaultAccessTokenConverter(); + private AccessTokenConverter accessTokenConverter = new CheckTokenAccessTokenConverter(); protected final Log logger = LogFactory.getLog(getClass()); @@ -81,12 +84,7 @@ public void setAccessTokenConverter(AccessTokenConverter accessTokenConverter) { OAuth2Authentication authentication = resourceServerTokenServices.loadAuthentication(token.getValue()); - Map response = (Map)accessTokenConverter.convertAccessToken(token, authentication); - - // gh-1070 - response.put("active", true); // Always true if token exists and not expired - - return response; + return accessTokenConverter.convertAccessToken(token, authentication); } @ExceptionHandler(InvalidTokenException.class) @@ -106,4 +104,35 @@ public int getHttpErrorCode() { return exceptionTranslator.translate(e400); } + static class CheckTokenAccessTokenConverter implements AccessTokenConverter { + private final AccessTokenConverter accessTokenConverter; + + CheckTokenAccessTokenConverter() { + this(new DefaultAccessTokenConverter()); + } + + CheckTokenAccessTokenConverter(AccessTokenConverter accessTokenConverter) { + this.accessTokenConverter = accessTokenConverter; + } + + @Override + public Map convertAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) { + Map claims = (Map) this.accessTokenConverter.convertAccessToken(token, authentication); + + // gh-1070 + claims.put("active", true); // Always true if token exists and not expired + + return claims; + } + + @Override + public OAuth2AccessToken extractAccessToken(String value, Map map) { + return this.accessTokenConverter.extractAccessToken(value, map); + } + + @Override + public OAuth2Authentication extractAuthentication(Map map) { + return this.accessTokenConverter.extractAuthentication(map); + } + } } diff --git a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpointTest.java b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpointTest.java index 87fc2ffb7..5c7dfd093 100644 --- a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpointTest.java +++ b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/provider/endpoint/CheckTokenEndpointTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2017 the original author or authors. + * Copyright 2012-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,8 +25,8 @@ import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.*; +import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -36,29 +36,36 @@ */ public class CheckTokenEndpointTest { private CheckTokenEndpoint checkTokenEndpoint; + private AccessTokenConverter accessTokenConverter; @Before public void setUp() { - ResourceServerTokenServices resourceServerTokenServices = mock(ResourceServerTokenServices.class); OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); - OAuth2Authentication authentication = mock(OAuth2Authentication.class); - when(resourceServerTokenServices.readAccessToken(anyString())).thenReturn(accessToken); when(accessToken.isExpired()).thenReturn(false); - when(accessToken.getValue()).thenReturn("access-token-1234"); - when(resourceServerTokenServices.loadAuthentication(accessToken.getValue())).thenReturn(authentication); + ResourceServerTokenServices resourceServerTokenServices = mock(ResourceServerTokenServices.class); + when(resourceServerTokenServices.readAccessToken(anyString())).thenReturn(accessToken); this.checkTokenEndpoint = new CheckTokenEndpoint(resourceServerTokenServices); - - AccessTokenConverter accessTokenConverter = mock(AccessTokenConverter.class); - when(accessTokenConverter.convertAccessToken(accessToken, authentication)).thenReturn(new HashMap()); - this.checkTokenEndpoint.setAccessTokenConverter(accessTokenConverter); + this.accessTokenConverter = mock(AccessTokenConverter.class); + when(this.accessTokenConverter.convertAccessToken(any(OAuth2AccessToken.class), any(OAuth2Authentication.class))).thenReturn(new HashMap()); + this.checkTokenEndpoint.setAccessTokenConverter(new CheckTokenEndpoint.CheckTokenAccessTokenConverter(this.accessTokenConverter)); } // gh-1070 @Test - public void checkTokenWhenTokenValidThenReturnActiveAttribute() throws Exception { + public void checkTokenWhenDefaultAccessTokenConverterThenActiveAttributeReturned() throws Exception { Map response = this.checkTokenEndpoint.checkToken("access-token-1234"); Object active = response.get("active"); assertNotNull("active is null", active); assertEquals("active not true", Boolean.TRUE, active); } + + // gh-1591 + @Test + public void checkTokenWhenCustomAccessTokenConverterThenActiveAttributeNotReturned() throws Exception { + this.accessTokenConverter = mock(AccessTokenConverter.class); + when(this.accessTokenConverter.convertAccessToken(any(OAuth2AccessToken.class), any(OAuth2Authentication.class))).thenReturn(new HashMap()); + this.checkTokenEndpoint.setAccessTokenConverter(this.accessTokenConverter); + Map response = this.checkTokenEndpoint.checkToken("access-token-1234"); + assertNull("active is not null", response.get("active")); + } } \ No newline at end of file