Skip to content

Commit

Permalink
Clear Repository on Logout
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusdacoregio committed Apr 17, 2023
1 parent 37d8846 commit 2d52fb8
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,8 @@
import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
Expand Down Expand Up @@ -325,6 +327,7 @@ public List<LogoutHandler> getLogoutHandlers() {
* @return the {@link LogoutFilter} to use.
*/
private LogoutFilter createLogoutFilter(H http) {
this.contextLogoutHandler.setSecurityContextRepository(getSecurityContextRepository(http));
this.logoutHandlers.add(this.contextLogoutHandler);
this.logoutHandlers.add(postProcess(new LogoutSuccessEventPublishingLogoutHandler()));
LogoutHandler[] handlers = this.logoutHandlers.toArray(new LogoutHandler[0]);
Expand All @@ -334,6 +337,14 @@ private LogoutFilter createLogoutFilter(H http) {
return result;
}

private SecurityContextRepository getSecurityContextRepository(H http) {
SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class);
if (securityContextRepository == null) {
securityContextRepository = new HttpSessionSecurityContextRepository();
}
return securityContextRepository;
}

private RequestMatcher getLogoutRequestMatcher(H http) {
if (this.logoutRequestMatcher != null) {
return this.logoutRequestMatcher;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,24 +16,33 @@

package org.springframework.security.config.annotation.web.configurers;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.http.HttpHeaders;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.security.web.authentication.logout.LogoutFilter;
import org.springframework.security.web.authentication.logout.LogoutSuccessHandler;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
Expand All @@ -42,6 +51,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
Expand Down Expand Up @@ -302,6 +312,80 @@ public void logoutWhenDisabledThenLogoutUrlNotFound() throws Exception {
this.mvc.perform(post("/logout").with(csrf())).andExpect(status().isNotFound());
}

@Test
public void logoutWhenCustomSecurityContextRepositoryThenUses() throws Exception {
CustomSecurityContextRepositoryConfig.repository = mock(SecurityContextRepository.class);
this.spring.register(CustomSecurityContextRepositoryConfig.class).autowire();
// @formatter:off
MockHttpServletRequestBuilder logoutRequest = post("/logout")
.with(csrf())
.with(user("user"))
.header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML_VALUE);
this.mvc.perform(logoutRequest)
.andExpect(status().isFound())
.andExpect(redirectedUrl("/login?logout"));
// @formatter:on
int invocationCount = 2; // 1 from user() post processor and 1 from
// SecurityContextLogoutHandler
verify(CustomSecurityContextRepositoryConfig.repository, times(invocationCount)).saveContext(any(),
any(HttpServletRequest.class), any(HttpServletResponse.class));
}

@Test
public void logoutWhenNoSecurityContextRepositoryThenHttpSessionSecurityContextRepository() throws Exception {
this.spring.register(InvalidateHttpSessionFalseConfig.class).autowire();
MockHttpSession session = mock(MockHttpSession.class);
// @formatter:off
MockHttpServletRequestBuilder logoutRequest = post("/logout")
.with(csrf())
.with(user("user"))
.session(session)
.header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML_VALUE);
this.mvc.perform(logoutRequest)
.andExpect(status().isFound())
.andExpect(redirectedUrl("/login?logout"))
.andReturn();
// @formatter:on
verify(session).removeAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY);
}

@Configuration
@EnableWebSecurity
static class InvalidateHttpSessionFalseConfig {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.logout((logout) -> logout.invalidateHttpSession(false))
.securityContext((context) -> context.requireExplicitSave(true));
return http.build();
// @formatter:on
}

}

@Configuration
@EnableWebSecurity
static class CustomSecurityContextRepositoryConfig {

static SecurityContextRepository repository;

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.logout(Customizer.withDefaults())
.securityContext((context) -> context
.requireExplicitSave(true)
.securityContextRepository(repository)
);
return http.build();
// @formatter:on
}

}

@EnableWebSecurity
static class NullLogoutSuccessHandlerConfig extends WebSecurityConfigurerAdapter {

Expand Down
1 change: 1 addition & 0 deletions docs/modules/ROOT/pages/servlet/authentication/logout.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The default is that accessing the URL `/logout` will log the user out by:
- Invalidating the HTTP Session
- Cleaning up any RememberMe authentication that was configured
- Clearing the `SecurityContextHolder`
- Clearing the `SecurityContextRepository`
- Redirect to `/login?logout`

Similar to configuring login capabilities, however, you also have various options to further customize your logout requirements:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert;

/**
Expand All @@ -50,6 +52,8 @@ public class SecurityContextLogoutHandler implements LogoutHandler {

private boolean clearAuthentication = true;

private SecurityContextRepository securityContextRepository = new HttpSessionSecurityContextRepository();

/**
* Requires the request to be passed in.
* @param request from which to obtain a HTTP session (cannot be null)
Expand All @@ -73,6 +77,8 @@ public void logout(HttpServletRequest request, HttpServletResponse response, Aut
if (this.clearAuthentication) {
context.setAuthentication(null);
}
SecurityContext emptyContext = SecurityContextHolder.createEmptyContext();
this.securityContextRepository.saveContext(emptyContext, request, response);
}

public boolean isInvalidateHttpSession() {
Expand Down Expand Up @@ -100,4 +106,14 @@ public void setClearAuthentication(boolean clearAuthentication) {
this.clearAuthentication = clearAuthentication;
}

/**
* Sets the {@link SecurityContextRepository} to use. Default is
* {@link HttpSessionSecurityContextRepository}.
* @param securityContextRepository the {@link SecurityContextRepository} to use.
*/
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
this.securityContextRepository = securityContextRepository;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -137,13 +137,46 @@ public void saveContext(SecurityContext context, HttpServletRequest request, Htt
SaveContextOnUpdateOrErrorResponseWrapper responseWrapper = WebUtils.getNativeResponse(response,
SaveContextOnUpdateOrErrorResponseWrapper.class);
if (responseWrapper == null) {
boolean httpSessionExists = request.getSession(false) != null;
SecurityContext initialContext = SecurityContextHolder.createEmptyContext();
responseWrapper = new SaveToSessionResponseWrapper(response, request, httpSessionExists, initialContext);
saveContextInHttpSession(context, request);
return;
}
responseWrapper.saveContext(context);
}

private void saveContextInHttpSession(SecurityContext context, HttpServletRequest request) {
if (isTransient(context) || isTransient(context.getAuthentication())) {
return;
}
SecurityContext emptyContext = generateNewContext();
if (emptyContext.equals(context)) {
HttpSession session = request.getSession(false);
removeContextFromSession(context, session);
}
else {
boolean createSession = this.allowSessionCreation;
HttpSession session = request.getSession(createSession);
setContextInSession(context, session);
}
}

private void setContextInSession(SecurityContext context, HttpSession session) {
if (session != null) {
session.setAttribute(this.springSecurityContextKey, context);
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Stored %s to HttpSession [%s]", context, session));
}
}
}

private void removeContextFromSession(SecurityContext context, HttpSession session) {
if (session != null) {
session.removeAttribute(this.springSecurityContextKey);
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Removed %s from HttpSession [%s]", context, session));
}
}
}

@Override
public boolean containsContext(HttpServletRequest request) {
HttpSession session = request.getSession(false);
Expand Down Expand Up @@ -369,11 +402,8 @@ protected void saveContext(SecurityContext context) {
// We may have a new session, so check also whether the context attribute
// is set SEC-1561
if (contextChanged(context) || httpSession.getAttribute(springSecurityContextKey) == null) {
httpSession.setAttribute(springSecurityContextKey, context);
HttpSessionSecurityContextRepository.this.saveContextInHttpSession(context, this.request);
this.isSaveContextInvoked = true;
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Stored %s to HttpSession [%s]", context, httpSession));
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,12 +27,19 @@
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.test.util.ReflectionTestUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
* @author Rob Winch
*
*/
public class SecurityContextLogoutHandlerTests {

Expand Down Expand Up @@ -76,4 +83,35 @@ public void disableClearsAuthentication() {
assertThat(beforeContext.getAuthentication()).isSameAs(beforeAuthentication);
}

@Test
public void logoutWhenSecurityContextRepositoryThenSaveEmptyContext() {
SecurityContextRepository repository = mock(SecurityContextRepository.class);
this.handler.setSecurityContextRepository(repository);
this.handler.logout(this.request, this.response, SecurityContextHolder.getContext().getAuthentication());
verify(repository).saveContext(eq(SecurityContextHolder.createEmptyContext()), any(), any());
}

@Test
public void logoutWhenClearAuthenticationFalseThenSaveEmptyContext() {
SecurityContextRepository repository = mock(SecurityContextRepository.class);
this.handler.setSecurityContextRepository(repository);
this.handler.setClearAuthentication(false);
this.handler.logout(this.request, this.response, SecurityContextHolder.getContext().getAuthentication());
verify(repository).saveContext(eq(SecurityContextHolder.createEmptyContext()), any(), any());
}

@Test
public void constructorWhenDefaultSecurityContextRepositoryThenHttpSessionSecurityContextRepository() {
SecurityContextRepository securityContextRepository = (SecurityContextRepository) ReflectionTestUtils
.getField(this.handler, "securityContextRepository");
assertThat(securityContextRepository).isInstanceOf(HttpSessionSecurityContextRepository.class);
}

@Test
public void setSecurityContextRepositoryWhenNullThenException() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.handler.setSecurityContextRepository(null))
.withMessage("securityContextRepository cannot be null");
}

}
Loading

0 comments on commit 2d52fb8

Please sign in to comment.