diff --git a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java index aea6eca5029..d122376b4f4 100644 --- a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java @@ -16,12 +16,16 @@ package org.springframework.security.web.context; +import java.io.IOException; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import javax.servlet.Filter; +import javax.servlet.ServletException; import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; @@ -31,6 +35,7 @@ import org.junit.After; import org.junit.Test; +import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; @@ -38,10 +43,14 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Transient; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetails; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -162,6 +171,48 @@ public void saveContextCallsSetAttributeIfContextIsModifiedDirectlyDuringRequest verify(session).setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, ctx); } + @Test + public void saveContextWhenSaveNewContextThenOriginalContextThenOriginalContextSaved() throws Exception { + HttpSessionSecurityContextRepository repository = new HttpSessionSecurityContextRepository(); + SecurityContextPersistenceFilter securityContextPersistenceFilter = new SecurityContextPersistenceFilter( + repository); + + UserDetails original = User.withUsername("user").password("password").roles("USER").build(); + SecurityContext originalContext = createSecurityContext(original); + UserDetails impersonate = User.withUserDetails(original).username("impersonate").build(); + SecurityContext impersonateContext = createSecurityContext(impersonate); + + MockHttpServletRequest mockRequest = new MockHttpServletRequest(); + MockHttpServletResponse mockResponse = new MockHttpServletResponse(); + + Filter saveImpersonateContext = (request, response, chain) -> { + SecurityContextHolder.setContext(impersonateContext); + // ensure the response is committed to trigger save + response.flushBuffer(); + chain.doFilter(request, response); + }; + Filter saveOriginalContext = (request, response, chain) -> { + SecurityContextHolder.setContext(originalContext); + chain.doFilter(request, response); + }; + HttpServlet servlet = new HttpServlet() { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + resp.getWriter().write("Hi"); + } + }; + + SecurityContextHolder.setContext(originalContext); + MockFilterChain chain = new MockFilterChain(servlet, saveImpersonateContext, saveOriginalContext); + + securityContextPersistenceFilter.doFilter(mockRequest, mockResponse, chain); + + assertThat( + mockRequest.getSession().getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)) + .isEqualTo(originalContext); + } + @Test public void nonSecurityContextInSessionIsIgnored() { HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); @@ -577,6 +628,13 @@ public void saveContextWhenTransientAuthenticationWithCustomAnnotationThenSkippe assertThat(session).isNull(); } + private SecurityContext createSecurityContext(UserDetails userDetails) { + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(userDetails, + userDetails.getPassword(), userDetails.getAuthorities()); + SecurityContext securityContext = new SecurityContextImpl(token); + return securityContext; + } + @Transient private static class SomeTransientAuthentication extends AbstractAuthenticationToken {