Skip to content

Commit

Permalink
Remove temporary HttpSessionSecurityContextRepository
Browse files Browse the repository at this point in the history
Issue gh-482
  • Loading branch information
jgrandja committed May 24, 2022
1 parent d8421d5 commit c4406cd
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 148 deletions.
Expand Up @@ -16,30 +16,17 @@
package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;

import java.net.URI;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.AsyncContext;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

import com.nimbusds.jose.jwk.source.JWKSource;

import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.Transient;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
Expand All @@ -51,11 +38,7 @@
import org.springframework.security.oauth2.server.authorization.web.ProviderContextFilter;
import org.springframework.security.web.authentication.HttpStatusEntryPoint;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SaveContextOnUpdateOrErrorResponseWrapper;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
Expand Down Expand Up @@ -254,127 +237,6 @@ public void init(B builder) {
getRequestMatcher(OAuth2TokenRevocationEndpointConfigurer.class))
);
}

// gh-482
initSecurityContextRepository(builder);
}

private void initSecurityContextRepository(B builder) {
// TODO This is a temporary fix and should be removed after upgrading to Spring Security 5.7.0 GA.
//
// See:
// Prevent Save @Transient Authentication with existing HttpSession
// https://github.com/spring-projects/spring-security/pull/9993

final SecurityContextRepository securityContextRepository = builder.getSharedObject(SecurityContextRepository.class);
if (!(securityContextRepository instanceof HttpSessionSecurityContextRepository)) {
return;
}

SecurityContextRepository securityContextRepositoryTransientNotSaved = new SecurityContextRepository() {

private final RequestMatcher clientAuthenticationRequestMatcher = initClientAuthenticationRequestMatcher();
private final RequestMatcher jwtAuthenticationRequestMatcher = initJwtAuthenticationRequestMatcher();

@Override
public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
final HttpServletRequest unwrappedRequest = requestResponseHolder.getRequest();
final HttpServletResponse unwrappedResponse = requestResponseHolder.getResponse();

SecurityContext securityContext = securityContextRepository.loadContext(requestResponseHolder);

if (this.clientAuthenticationRequestMatcher.matches(unwrappedRequest) ||
this.jwtAuthenticationRequestMatcher.matches(unwrappedRequest)) {

final SaveContextOnUpdateOrErrorResponseWrapper transientAuthenticationResponseWrapper =
new SaveContextOnUpdateOrErrorResponseWrapper(unwrappedResponse, false) {

@Override
protected void saveContext(SecurityContext context) {
// @Transient Authentication should not be saved
if (context.getAuthentication() != null) {
Assert.state(isTransientAuthentication(context.getAuthentication()), "Expected @Transient Authentication");
}
}

};
// Override the default HttpSessionSecurityContextRepository.SaveToSessionResponseWrapper
requestResponseHolder.setResponse(transientAuthenticationResponseWrapper);

final HttpServletRequestWrapper transientAuthenticationRequestWrapper =
new HttpServletRequestWrapper(unwrappedRequest) {

@Override
public AsyncContext startAsync() {
transientAuthenticationResponseWrapper.disableSaveOnResponseCommitted();
return super.startAsync();
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
throws IllegalStateException {
transientAuthenticationResponseWrapper.disableSaveOnResponseCommitted();
return super.startAsync(servletRequest, servletResponse);
}

};
// Override the default HttpSessionSecurityContextRepository.SaveToSessionRequestWrapper
requestResponseHolder.setRequest(transientAuthenticationRequestWrapper);
}

return securityContext;
}

@Override
public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) {
Authentication authentication = context.getAuthentication();
if (authentication == null || isTransientAuthentication(authentication)) {
return;
}
securityContextRepository.saveContext(context, request, response);
}

@Override
public boolean containsContext(HttpServletRequest request) {
return securityContextRepository.containsContext(request);
}

private boolean isTransientAuthentication(Authentication authentication) {
return AnnotationUtils.getAnnotation(authentication.getClass(), Transient.class) != null;
}

private RequestMatcher initClientAuthenticationRequestMatcher() {
// OAuth2ClientAuthenticationToken is @Transient and is accepted by
// OAuth2TokenEndpointFilter, OAuth2TokenIntrospectionEndpointFilter and OAuth2TokenRevocationEndpointFilter

List<RequestMatcher> requestMatchers = new ArrayList<>();
requestMatchers.add(getRequestMatcher(OAuth2TokenEndpointConfigurer.class));
requestMatchers.add(getRequestMatcher(OAuth2TokenIntrospectionEndpointConfigurer.class));
requestMatchers.add(getRequestMatcher(OAuth2TokenRevocationEndpointConfigurer.class));
return new OrRequestMatcher(requestMatchers);
}

private RequestMatcher initJwtAuthenticationRequestMatcher() {
// JwtAuthenticationToken is @Transient and is accepted by
// OidcUserInfoEndpointFilter and OidcClientRegistrationEndpointFilter

List<RequestMatcher> requestMatchers = new ArrayList<>();
requestMatchers.add(
getConfigurer(OidcConfigurer.class)
.getConfigurer(OidcUserInfoEndpointConfigurer.class).getRequestMatcher()
);
OidcClientRegistrationEndpointConfigurer clientRegistrationEndpointConfigurer =
getConfigurer(OidcConfigurer.class)
.getConfigurer(OidcClientRegistrationEndpointConfigurer.class);
if (clientRegistrationEndpointConfigurer != null) {
requestMatchers.add(clientRegistrationEndpointConfigurer.getRequestMatcher());
}
return new OrRequestMatcher(requestMatchers);
}

};

builder.setSharedObject(SecurityContextRepository.class, securityContextRepositoryTransientNotSaved);
}

@Override
Expand Down
Expand Up @@ -127,7 +127,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -669,7 +668,7 @@ public void requestWhenClientObtainsAccessTokenThenClientAuthenticationNotPersis
String authorizationCode = extractParameterFromRedirectUri(mvcResult.getResponse().getRedirectedUrl(), "code");
OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);

this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER))
Expand All @@ -680,9 +679,12 @@ public void requestWhenClientObtainsAccessTokenThenClientAuthenticationNotPersis
.andExpect(jsonPath("$.token_type").isNotEmpty())
.andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").doesNotExist())
.andExpect(jsonPath("$.scope").isNotEmpty());
.andExpect(jsonPath("$.scope").isNotEmpty())
.andReturn();

verify(securityContextRepository, never()).saveContext(any(), any(), any());
org.springframework.security.core.context.SecurityContext securityContext =
securityContextRepository.loadContext(mvcResult.getRequest()).get();
assertThat(securityContext.getAuthentication()).isNull();
}

private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
Expand Down
Expand Up @@ -69,13 +69,12 @@
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultMatcher;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.springframework.test.web.servlet.ResultMatcher.matchAll;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
Expand Down Expand Up @@ -171,13 +170,16 @@ public void requestWhenUserInfoRequestThenBearerTokenAuthenticationNotPersisted(

OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
// @formatter:off
this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
.andExpect(status().is2xxSuccessful())
.andExpect(userInfoResponse());
.andExpect(userInfoResponse())
.andReturn();
// @formatter:on

verify(securityContextRepository, never()).saveContext(any(), any(), any());
org.springframework.security.core.context.SecurityContext securityContext =
securityContextRepository.loadContext(mvcResult.getRequest()).get();
assertThat(securityContext.getAuthentication()).isNull();
}

private static ResultMatcher userInfoResponse() {
Expand Down

0 comments on commit c4406cd

Please sign in to comment.