Skip to content

Commit

Permalink
Fixes #3622 - Make sure RequestDispatcher.forward honors parameter pr…
Browse files Browse the repository at this point in the history
…ecendence (#3623)
  • Loading branch information
mnriem committed Jan 14, 2024
1 parent 09cd53f commit 924accb
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpUpgradeHandler;
import java.security.Principal;
import java.util.Map;

/**
* The WebApplicationRequest API.
Expand All @@ -40,6 +41,13 @@
*/
public interface WebApplicationRequest extends HttpServletRequest {

/**
* Get the modifiable parameter map.
*
* @return the modifiable parameter map.
*/
Map<String, String[]> getModifiableParameterMap();

/**
* {@return the multipartConfig}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ public AsyncHttpDispatchWrapper(HttpServletRequest request) {
wrapperAttributes.add("piranha.response");
}

@Override
public Map<String, String[]> getModifiableParameterMap() {
return wrapperParameters;
}

@Override
public HttpServletRequest getRequest() {
return (HttpServletRequest) super.getRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ private void dispatchSyncForward(
*/
WebApplicationRequest request = unwrap(servletRequest, WebApplicationRequest.class);

/**
/*
* Set the forward attributes.
*/
request.setAttribute(FORWARD_CONTEXT_PATH, request.getContextPath());
Expand All @@ -733,12 +733,12 @@ private void dispatchSyncForward(
request.setAttribute(FORWARD_REQUEST_URI, request.getRequestURI());
request.setAttribute(FORWARD_SERVLET_PATH, request.getServletPath());

/**
/*
* Set the new dispatcher type.
*/
request.setDispatcherType(FORWARD);

/**
/*
* Set the new servlet path and the new query string.
*/
if (path != null) {
Expand All @@ -749,6 +749,33 @@ private void dispatchSyncForward(
request.setQueryString(request.getQueryString());
}

/*
* Aggregate the request parameters with giving the new parameter values
* precedence over the old ones.
*/
if (request.getQueryString() != null) {
String queryString = request.getQueryString();
Map<String, String[]> parameters = request.getModifiableParameterMap();
for (String param : queryString.split("&")) {
String[] pair = param.split("=");
String key = URLDecoder.decode(pair[0], StandardCharsets.UTF_8);
String value = "";
if (pair.length > 1) {
value = URLDecoder.decode(pair[1], StandardCharsets.UTF_8);
}
String[] values = parameters.get(key);
if (values == null) {
values = new String[]{value};
parameters.put(key, values);
} else {
String[] newValues = new String[values.length + 1];
System.arraycopy(values, 0, newValues, 1, values.length);
newValues[0] = value;
parameters.put(key, newValues);
}
}
}

invocationFinder.addFilters(FORWARD, servletInvocation, request.getServletPath(), "");

if (servletInvocation.getServletEnvironment() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,11 @@ public String getMethod() {
return method;
}

@Override
public Map<String, String[]> getModifiableParameterMap() {
return parameters;
}

@Override
public MultipartConfigElement getMultipartConfig() {
return multipartConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
*/
package cloud.piranha.core.impl;

import cloud.piranha.core.api.WebApplication;
import java.io.IOException;
import jakarta.servlet.RequestDispatcher;
import jakarta.servlet.ServletException;
Expand Down Expand Up @@ -86,12 +87,12 @@ void testForward2() throws Exception {

TestWebApplicationRequest request = new TestWebApplicationRequest();
request.setWebApplication(webApplication);

TestWebApplicationResponse response = new TestWebApplicationResponse();
response.setWebApplication(webApplication);

webApplication.linkRequestAndResponse(request, response);

RequestDispatcher dispatcher = webApplication.getRequestDispatcher("/Snoop");
dispatcher.forward(request, response);
String responseText = new String(response.getResponseBytes());
Expand Down Expand Up @@ -136,34 +137,83 @@ void testForward4() throws Exception {
RequestDispatcher dispatcher = webApp.getRequestDispatcher("/Runtime");
assertThrows(RuntimeException.class, () -> dispatcher.forward(request, response));
}

/**
* Test that a request given to the request dispatcher upon forward is the
* same as the original request.
*/
@Test
void testForwardNoWrapping() throws Exception {
WebApplication webApplication = new DefaultWebApplicationBuilder()
.servlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
resp.getWriter().print(req.toString());
}
})
.servletMapping("NoWrapping", "/nowrapping")
.build();
webApplication.initialize();
webApplication.start();
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
DefaultWebApplicationResponse response = new DefaultWebApplicationResponse();
ByteArrayOutputStream byteOutput = new ByteArrayOutputStream();
response.getWebApplicationOutputStream().setOutputStream(byteOutput);
response.setBodyOnly(true);
DefaultWebApplication webApplication = new DefaultWebApplication();
webApplication.addServlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
response.getWriter().print(req.toString());
}
});
webApplication.addServletMapping("NoWrapping", "/nowrapping");
webApplication.initialize();
webApplication.start();
RequestDispatcher dispatcher = webApplication.getRequestDispatcher("/nowrapping");
assertNotNull(dispatcher);
dispatcher.forward(request, response);
assertEquals(request.toString(), byteOutput.toString("UTF-8"));
}

/**
* Test a forward with a query string.
*/
@Test
void testForwardWithQueryString() throws Exception {
WebApplication webApplication = new DefaultWebApplicationBuilder()
.servlet("Forward", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
/*
* Force the original parameter to be parsed out of the query string.
*/
req.getParameter("p");
/*
* Dispatch with a new parameter value.
*/
getServletContext()
.getRequestDispatcher("/forward2?p=New")
.forward(req, resp);
}
})
.servletMapping("Forward", "/forward")
.servlet("Forward2", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
/*
* We should be getting the new parameter value here as it takes precendence.
*/
resp.getWriter().print(req.getParameterMap().get("p")[0]);
}
})
.servletMapping("Forward2", "/forward2")
.build();
webApplication.initialize();
webApplication.start();
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
request.setWebApplication(webApplication);
request.setServletPath("/forward");
request.setQueryString("p=Original");
DefaultWebApplicationResponse response = new DefaultWebApplicationResponse();
ByteArrayOutputStream byteOutput = new ByteArrayOutputStream();
response.getWebApplicationOutputStream().setOutputStream(byteOutput);
response.setBodyOnly(true);
response.setWebApplication(webApplication);
webApplication.service(request, response);
assertEquals("New", byteOutput.toString());
}

/**
* Test include method.
*
Expand All @@ -178,12 +228,12 @@ void testInclude() throws Exception {

TestWebApplicationRequest request = new TestWebApplicationRequest();
request.setWebApplication(webApplication);

TestWebApplicationResponse response = new TestWebApplicationResponse();
response.setWebApplication(webApplication);

webApplication.linkRequestAndResponse(request, response);

RequestDispatcher dispatcher = webApplication.getNamedDispatcher("Echo");
dispatcher.include(request, response);
response.flushBuffer();
Expand Down Expand Up @@ -213,29 +263,30 @@ void testInclude2() throws Exception {
webApp.stop();
assertTrue(responseText.contains("ECHO"));
}

/**
* Test that a request given to the request dispatcher upon include is the
* same as the original request.
*/
@Test
void testIncludeNoWrapping() throws Exception {
DefaultWebApplication webApplication = new DefaultWebApplication();
webApplication.addServlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
req.getRequestDispatcher("/nowrapping2").include(req, resp);
resp.flushBuffer();
}
});
webApplication.addServlet("NoWrapping2", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
resp.getWriter().print(req.toString());
}
});
webApplication.addServletMapping("NoWrapping", "/nowrapping");
webApplication.addServletMapping("NoWrapping2", "/nowrapping2");
WebApplication webApplication = new DefaultWebApplicationBuilder()
.servlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
req.getRequestDispatcher("/nowrapping2").include(req, resp);
resp.flushBuffer();
}
})
.servletMapping("NoWrapping", "/nowrapping")
.servlet("NoWrapping2", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
resp.getWriter().print(req.toString());
}
})
.servletMapping("NoWrapping2", "/nowrapping2")
.build();
webApplication.initialize();
webApplication.start();
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
Expand Down Expand Up @@ -296,8 +347,9 @@ void testErrorDispatcher2() throws Exception {
assertTrue(responseText.contains(RequestDispatcher.ERROR_EXCEPTION_TYPE));
assertTrue(responseText.contains(IOException.class.getName()));
}

static class TestSendError extends HttpServlet {

@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
if (request.getParameter("send-error") != null) {
Expand Down

0 comments on commit 924accb

Please sign in to comment.