Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
package com.vaadin.flow.spring.fusionsecurityjwt;

import java.util.Base64;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.junit.Assert;
import org.junit.Test;
import org.openqa.selenium.By;
import org.openqa.selenium.Cookie;
import org.openqa.selenium.JavascriptExecutor;

import com.vaadin.testbench.TestBenchElement;

import elemental.json.Json;
import elemental.json.JsonObject;
Expand Down Expand Up @@ -104,6 +97,26 @@ public void stateless_for_anonymous_after_logout() {
assertPublicEndpointWorks();
}

@Override
public void reload_when_anonymous_session_expires() {
// Skip: the server session is not relevant in the stateless mode
}

@Override
public void reload_when_user_session_expires() {
// Skip: the server session is not relevant in the stateless mode
}

@Test
public void reload_when_user_jwt_expires() {
openLogin();
loginUser();
getDriver().manage().deleteCookieNamed("jwt.headerAndPayload");
getDriver().manage().deleteCookieNamed("jwt.signature");
navigateTo("private", false);
assertLoginViewShown();
}

private void openLogin() {
getDriver().get(getRootURL() + "/login");
}
Expand All @@ -113,7 +126,7 @@ private Cookie getSpringCsrfCookie() {
}

private Cookie getJwtCookie() {
return getDriver().manage().getCookieNamed("jwt" + ".headerAndPayload");
return getDriver().manage().getCookieNamed("jwt.headerAndPayload");
}

private void checkJwtUsername(String expectedUsername) {
Expand All @@ -126,43 +139,4 @@ private void checkJwtUsername(String expectedUsername) {
Assert.assertEquals(expectedUsername, payloadJson.getString("sub"));
}

private void simulateNewServer() {
TestBenchElement mainView = waitUntil(driver -> $("main-view").get(0));
callAsyncMethod(mainView, "invalidateSessionIfPresent");
}

private void assertPublicEndpointWorks() {
TestBenchElement publicView = waitUntil(
driver -> $("public-view").get(0));
TestBenchElement timeText = publicView.findElement(By.id("time"));
String timeBefore = timeText.getText();
Assert.assertNotNull(timeBefore);
callAsyncMethod(publicView, "updateTime");
String timeAfter = timeText.getText();
Assert.assertNotNull(timeAfter);
Assert.assertNotEquals(timeAfter, timeBefore);
}

private String formatArgumentRef(int index) {
return String.format("arguments[%d]", index);
}

private Object callAsyncMethod(TestBenchElement element, String methodName,
Object... args) {
String objectRef = formatArgumentRef(0);
String argRefs = IntStream.range(1, args.length + 1)
.mapToObj(this::formatArgumentRef)
.collect(Collectors.joining(","));
String callbackRef = formatArgumentRef(args.length + 1);
String script = String.format("%s.%s(%s).then(%s)", objectRef,
methodName, argRefs, callbackRef);
Object[] scriptArgs = Stream.concat(Stream.of(element), Stream.of(args))
.toArray();
return getJavascriptExecutor().executeAsyncScript(script, scriptArgs);
}

private JavascriptExecutor getJavascriptExecutor() {
return (JavascriptExecutor) getDriver();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import {
ConnectClient,
InvalidSessionMiddleware,
} from '@vaadin/fusion-frontend';

const client = new ConnectClient({
prefix: 'connect',
middlewares: [
new InvalidSessionMiddleware(async () => {
location.reload();
return {
error: true
}
})
],
});

export default client;
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.openqa.selenium.By;
import org.openqa.selenium.JavascriptExecutor;
import org.openqa.selenium.StaleElementReferenceException;

import com.vaadin.flow.component.button.testbench.ButtonElement;
import com.vaadin.flow.component.login.testbench.LoginFormElement;
Expand Down Expand Up @@ -37,8 +42,12 @@ public void tearDown() {
private void checkForBrowserErrors() {
checkLogsForErrors(msg -> {
return msg.contains(
"admin-only/secret.txt - Failed to load resource: the "
"/admin-only/secret.txt - Failed to load resource: the "
+ "server responded with a status of 403")
|| msg.contains("/connect/") && msg.contains("Failed to "
+ "load resource: the server responded with "
+ "a status of 401")
|| msg.contains("expected \"200 OK\" response, but got 401")
|| msg.contains("webpack-internal://");
});
}
Expand Down Expand Up @@ -233,6 +242,22 @@ public void public_app_resources_available_for_all() {
shouldBeTextFile.contains("Public document for all users"));
}

@Test
public void reload_when_anonymous_session_expires() {
open("");
simulateNewServer();
assertPublicEndpointReloadsPage();
}

@Test
public void reload_when_user_session_expires() {
open("login");
loginUser();
simulateNewServer();
navigateTo("private", false);
assertLoginViewShown();
}

protected void navigateTo(String path) {
navigateTo(path, true);
}
Expand All @@ -248,7 +273,7 @@ private TestBenchElement getMainView() {
return waitUntil(driver -> $("*").id("main-view"));
}

private void assertLoginViewShown() {
protected void assertLoginViewShown() {
assertPathShown("login");
waitUntil(driver -> $(LoginOverlayElement.class).exists());
}
Expand Down Expand Up @@ -329,4 +354,58 @@ protected List<MenuItem> getMenuItems() {
}).collect(Collectors.toList());
}

private TestBenchElement getPublicView() {
return waitUntil(driver -> $("public-view").get(0));
}

protected void simulateNewServer() {
TestBenchElement mainView = waitUntil(driver -> $("main-view").get(0));
callAsyncMethod(mainView, "invalidateSessionIfPresent");
}

protected void assertPublicEndpointReloadsPage() {
String timeBefore = getPublicView().findElement(By.id("time"))
.getText();
Assert.assertNotNull(timeBefore);
try {
getPublicView().callFunction("updateTime");
} catch (StaleElementReferenceException e) {
// Page reload causes the exception, ignore
}
String timeAfter = getPublicView().findElement(By.id("time")).getText();
Assert.assertNotNull(timeAfter);
Assert.assertNotEquals(timeAfter, timeBefore);
}

protected void assertPublicEndpointWorks() {
String timeBefore = getPublicView().findElement(By.id("time"))
.getText();
Assert.assertNotNull(timeBefore);
callAsyncMethod(getPublicView(), "updateTime");
String timeAfter = getPublicView().findElement(By.id("time")).getText();
Assert.assertNotNull(timeAfter);
Assert.assertNotEquals(timeAfter, timeBefore);
}

private String formatArgumentRef(int index) {
return String.format("arguments[%d]", index);
}

private JavascriptExecutor getJavascriptExecutor() {
return (JavascriptExecutor) getDriver();
}

private Object callAsyncMethod(TestBenchElement element, String methodName,
Object... args) {
String objectRef = formatArgumentRef(0);
String argRefs = IntStream.range(1, args.length + 1)
.mapToObj(this::formatArgumentRef)
.collect(Collectors.joining(","));
String callbackRef = formatArgumentRef(args.length + 1);
String script = String.format("%s.%s(%s).then(%s)", objectRef,
methodName, argRefs, callbackRef);
Object[] scriptArgs = Stream.concat(Stream.of(element), Stream.of(args))
.toArray();
return getJavascriptExecutor().executeAsyncScript(script, scriptArgs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,35 @@
package com.vaadin.flow.spring.security;

import javax.crypto.SecretKey;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.builders.WebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.annotation.web.configurers.ExpressionUrlAuthorizationConfigurer;
import org.springframework.security.config.annotation.web.configurers.FormLoginConfigurer;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.access.AccessDeniedHandlerImpl;
import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
import org.springframework.security.web.access.RequestMatcherDelegatingAccessDeniedHandler;
import org.springframework.security.web.authentication.HttpStatusEntryPoint;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.csrf.CsrfException;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;

Expand Down Expand Up @@ -93,6 +108,15 @@ protected void configure(HttpSecurity http) throws Exception {
SecurityContextHolder.setStrategyName(
VaadinAwareSecurityContextHolderStrategy.class.getName());

// Respond with 401 Unauthorized HTTP status code for unauthorized
// requests for protected Fusion endpoints, so that the response could
// be handled on the client side using e.g. `InvalidSessionMiddleware`.
http.exceptionHandling()
.accessDeniedHandler(createAccessDeniedHandler())
.defaultAuthenticationEntryPointFor(
new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED),
requestUtil::isEndpointRequest);

// Vaadin has its own CSRF protection.
// Spring CSRF is not compatible with Vaadin internal requests
http.csrf().ignoringRequestMatchers(
Expand Down Expand Up @@ -193,6 +217,9 @@ protected void setLoginView(HttpSecurity http, String fusionLoginViewPath,
formLogin.successHandler(
getVaadinSavedRequestAwareAuthenticationSuccessHandler(http));
http.logout().logoutSuccessUrl(logoutUrl);
http.exceptionHandling().defaultAuthenticationEntryPointFor(
new LoginUrlAuthenticationEntryPoint(fusionLoginViewPath),
AnyRequestMatcher.INSTANCE);
viewAccessChecker.setLoginView(fusionLoginViewPath);
}

Expand Down Expand Up @@ -248,6 +275,9 @@ protected void setLoginView(HttpSecurity http,
getVaadinSavedRequestAwareAuthenticationSuccessHandler(http));
http.csrf().ignoringAntMatchers(loginPath);
http.logout().logoutSuccessUrl(logoutUrl);
http.exceptionHandling().defaultAuthenticationEntryPointFor(
new LoginUrlAuthenticationEntryPoint(loginPath),
AnyRequestMatcher.INSTANCE);
viewAccessChecker.setLoginView(flowLoginView);
}

Expand Down Expand Up @@ -304,4 +334,32 @@ private VaadinSavedRequestAwareAuthenticationSuccessHandler getVaadinSavedReques
}
return vaadinSavedRequestAwareAuthenticationSuccessHandler;
}

private AccessDeniedHandler createAccessDeniedHandler() {
final AccessDeniedHandler defaultHandler = new AccessDeniedHandlerImpl();

final AccessDeniedHandler http401UnauthorizedHandler = new Http401UnauthorizedAccessDeniedHandler();

final LinkedHashMap<Class<? extends AccessDeniedException>, AccessDeniedHandler> exceptionHandlers = new LinkedHashMap<>();
exceptionHandlers.put(CsrfException.class, http401UnauthorizedHandler);

final LinkedHashMap<RequestMatcher, AccessDeniedHandler> matcherHandlers = new LinkedHashMap<>();
matcherHandlers.put(requestUtil::isEndpointRequest,
new DelegatingAccessDeniedHandler(exceptionHandlers,
new AccessDeniedHandlerImpl()));

return new RequestMatcherDelegatingAccessDeniedHandler(matcherHandlers,
defaultHandler);
}

private static class Http401UnauthorizedAccessDeniedHandler
implements AccessDeniedHandler {
@Override
public void handle(HttpServletRequest request,
HttpServletResponse response,
AccessDeniedException accessDeniedException)
throws IOException, ServletException {
response.setStatus(HttpStatus.UNAUTHORIZED.value());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ protected Stream<String> getExcludedPatterns() {
"com\\.vaadin\\.flow\\.spring\\.security\\.SerializedJwtSplitCookieRepository",
"com\\.vaadin\\.flow\\.spring\\.security\\.VaadinAwareSecurityContextHolderStrategy",
"com\\.vaadin\\.flow\\.spring\\.security\\.VaadinWebSecurityConfigurerAdapter",
"com\\.vaadin\\.flow\\.spring\\.security\\.VaadinWebSecurityConfigurerAdapter",
"com\\.vaadin\\.flow\\.spring\\.security\\.VaadinWebSecurityConfigurerAdapter\\$Http401UnauthorizedAccessDeniedHandler",
"com\\.vaadin\\.flow\\.spring\\.security\\.VaadinDefaultRequestCache",
"com\\.vaadin\\.flow\\.spring\\.security\\.VaadinSavedRequestAwareAuthenticationSuccessHandler",
"com\\.vaadin\\.flow\\.spring\\.security\\.VaadinSavedRequestAwareAuthenticationSuccessHandler\\$RedirectStrategy",
Expand Down