Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #3620 - Make sure a RequestDispatcher include does not wrap the request #3621

Merged
merged 3 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ default void setRequestedSessionId(String requestedSessionId) {
*/
void setWebApplication(WebApplication webApplication);

/**
* Set the path info.
*
* @param pathInfo the path info.
*/
default void setPathInfo(String pathInfo) {
}

/**
* Set the query string.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Stream;

/**
* The default ServletRequestDispatcher.
Expand Down Expand Up @@ -250,76 +249,67 @@ public void forward(ServletRequest servletRequest, ServletResponse servletRespon
*/
@Override
public void include(ServletRequest servletRequest, ServletResponse servletResponse) throws ServletException, IOException {
DefaultWebApplicationRequest includedRequest = new DefaultWebApplicationRequest();
HttpServletRequest originalRequest = unwrap(servletRequest, HttpServletRequest.class);

// Change the underlying request if the request was wrapped
ServletRequestWrapper wrapper = (servletRequest instanceof ServletRequestWrapper wrapped)
? getLastWrapper(wrapped)
: new HttpServletRequestWrapper(originalRequest);
wrapper.setRequest(includedRequest);

includedRequest.setWebApplication(servletEnvironment.getWebApplication());
includedRequest.setMultipartConfig(servletEnvironment.getMultipartConfig());
includedRequest.setContextPath(originalRequest.getContextPath());

includedRequest.setServletPath(path == null ? "/" + servletEnvironment.getServletName() : getServletPath(path));
includedRequest.setDispatcherType(INCLUDE);
includedRequest.setPathInfo(null);
includedRequest.setQueryString(originalRequest.getQueryString());
HttpSession session = originalRequest.getSession(false);
if (session != null) {
includedRequest.setCurrentSessionId(session.getId());
/*
* We only do includes for a HttpServletRequest/Response pair.
*/
if (!(servletRequest instanceof HttpServletRequest)) {
throw new ServletException("Request is not a HttpServletRequest");
}

copyAttributesFromRequest(originalRequest, includedRequest, attributeName -> true);

if (path != null) {
includedRequest.setAttribute(INCLUDE_CONTEXT_PATH, includedRequest.getContextPath());
includedRequest.setAttribute(INCLUDE_SERVLET_PATH, includedRequest.getServletPath());
includedRequest.setAttribute(INCLUDE_PATH_INFO, includedRequest.getPathInfo());
includedRequest.setAttribute(INCLUDE_REQUEST_URI, includedRequest.getRequestURI());
includedRequest.setAttribute(INCLUDE_QUERY_STRING, getQueryString(path));
if (!(servletResponse instanceof HttpServletResponse)) {
throw new ServletException("Response is not a HttpServletResponse");
}
CurrentRequestHolder currentRequestHolder = updateCurrentRequest(originalRequest, includedRequest);

invocationFinder.addFilters(INCLUDE, servletInvocation, includedRequest.getServletPath(), "");
/*
* Unwrap the passed ServletRequest to the underlying WebApplicationRequest.
*/
WebApplicationRequest request = unwrap(servletRequest, WebApplicationRequest.class);

/**
* Set the include attributes.
*/
request.setAttribute(INCLUDE_CONTEXT_PATH, request.getContextPath());
request.setAttribute(INCLUDE_PATH_INFO, request.getPathInfo());
request.setAttribute(INCLUDE_QUERY_STRING, request.getQueryString());
request.setAttribute(INCLUDE_REQUEST_URI, request.getRequestURI());
request.setAttribute(INCLUDE_SERVLET_PATH, request.getServletPath());

/**
* Set the new dispatcher type.
*/
request.setDispatcherType(INCLUDE);
request.setServletPath(path == null ? "/" + servletEnvironment.getServletName() : getServletPath(path));
request.setPathInfo(null);

invocationFinder.addFilters(INCLUDE, servletInvocation, request.getServletPath(), "");

if (originalRequest instanceof DefaultWebApplicationRequest defaultRequest) {
// 12.3.1
// "the HttpServletMapping is not available for servlets that have been obtained with a call to
// ServletContext.getNamedDispatcher()."
//
// Although not spelled out, this means in practice the included Servlet uses the
// mapping of the forwarding servlet.
includedRequest.setHttpServletMapping(defaultRequest.getHttpServletMapping());
if (servletInvocation.getServletEnvironment() != null) {
request.setAsyncSupported(request.isAsyncSupported() && isAsyncSupportedInChain());
}

// After setting the include attributes and adding filters, reset the servlet path
includedRequest.setServletPath(originalRequest.getServletPath());

/*
* Execute the filter chain.
*/
try {
servletEnvironment.getWebApplication().linkRequestAndResponse(includedRequest, servletResponse);

servletInvocation.getFilterChain().doFilter(wrapper, servletResponse);

// After the include, we need to copy the attributes that were set in the new request to the old one
// but not include the "INCLUDE_" attributes that were set previously
copyAttributesFromRequest(includedRequest, originalRequest, attributeName
-> originalRequest.getAttribute(attributeName) == null && Stream.of(
INCLUDE_QUERY_STRING,
INCLUDE_CONTEXT_PATH,
INCLUDE_MAPPING,
INCLUDE_PATH_INFO,
INCLUDE_REQUEST_URI,
INCLUDE_SERVLET_PATH).noneMatch(attributeName::equals));

servletEnvironment.getWebApplication().unlinkRequestAndResponse(includedRequest, servletResponse);
} catch (Exception e) {
servletEnvironment.getWebApplication().linkRequestAndResponse(servletRequest, servletResponse);
servletInvocation.getFilterChain().doFilter(servletRequest, servletResponse);
servletEnvironment.getWebApplication().unlinkRequestAndResponse(servletRequest, servletResponse);
} catch(Exception e) {
rethrow(e);
} finally {
restoreCurrentRequest(currentRequestHolder, originalRequest);
wrapper.setRequest(originalRequest);
/*
* Set servlet path and path info back to original values.
*/
request.setServletPath((String) request.getAttribute(INCLUDE_SERVLET_PATH));
request.setPathInfo((String) request.getAttribute(INCLUDE_PATH_INFO));

/*
* Remove include attributes.
*/
request.removeAttribute(INCLUDE_CONTEXT_PATH);
request.removeAttribute(INCLUDE_PATH_INFO);
request.removeAttribute(INCLUDE_QUERY_STRING);
request.removeAttribute(INCLUDE_REQUEST_URI);
request.removeAttribute(INCLUDE_SERVLET_PATH);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,8 @@ public RequestDispatcher getRequestDispatcher(String path) {
if (!resolved.startsWith(rootContext)) {
resolved = rootContext.resolveSibling(resolved);
}
return webApplication.getRequestDispatcher(resolved.toString());
String servletPath = resolved.toString().replace('\\', '/');
return webApplication.getRequestDispatcher(servletPath);
}

@Override
Expand Down Expand Up @@ -1260,11 +1261,7 @@ public void setParameter(String name, String[] values) {
parameters.put(name, values);
}

/**
* Set the path info.
*
* @param pathInfo the path info.
*/
@Override
public void setPathInfo(String pathInfo) {
this.pathInfo = pathInfo;
}
Expand Down Expand Up @@ -1433,11 +1430,6 @@ public AsyncContext startAsync(ServletRequest request, ServletResponse response)
return asyncContext;
}

@Override
public String toString() {
return getRequestURIWithQueryString() + " " + super.toString();
}

/**
* Unwrap the request.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void testForward4() throws Exception {
* same as the original request.
*/
@Test
void testForward5() throws Exception {
void testForwardNoWrapping() throws Exception {
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
DefaultWebApplicationResponse response = new DefaultWebApplicationResponse();
ByteArrayOutputStream byteOutput = new ByteArrayOutputStream();
Expand Down Expand Up @@ -213,6 +213,45 @@ void testInclude2() throws Exception {
webApp.stop();
assertTrue(responseText.contains("ECHO"));
}

/**
* Test that a request given to the request dispatcher upon include is the
* same as the original request.
*/
@Test
void testIncludeNoWrapping() throws Exception {
DefaultWebApplication webApplication = new DefaultWebApplication();
webApplication.addServlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
req.getRequestDispatcher("/nowrapping2").include(req, resp);
resp.flushBuffer();
}
});
webApplication.addServlet("NoWrapping2", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
resp.getWriter().print(req.toString());
}
});
webApplication.addServletMapping("NoWrapping", "/nowrapping");
webApplication.addServletMapping("NoWrapping2", "/nowrapping2");
webApplication.initialize();
webApplication.start();
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
request.setWebApplication(webApplication);
DefaultWebApplicationResponse response = new DefaultWebApplicationResponse();
response.setWebApplication(webApplication);
ByteArrayOutputStream byteOutput = new ByteArrayOutputStream();
response.getWebApplicationOutputStream().setOutputStream(byteOutput);
response.setBodyOnly(true);
webApplication.linkRequestAndResponse(request, response);
RequestDispatcher dispatcher = webApplication.getRequestDispatcher("/nowrapping");
assertNotNull(dispatcher);
dispatcher.forward(request, response);
assertEquals(request.toString(), byteOutput.toString("UTF-8"));
webApplication.unlinkRequestAndResponse(request, response);
}

@Test
void testErrorDispatcher() throws Exception {
Expand Down