Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #3622 - Make sure RequestDispatcher.forward honors parameter precendence #3623

Merged
merged 1 commit into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpUpgradeHandler;
import java.security.Principal;
import java.util.Map;

/**
* The WebApplicationRequest API.
Expand All @@ -40,6 +41,13 @@
*/
public interface WebApplicationRequest extends HttpServletRequest {

/**
* Get the modifiable parameter map.
*
* @return the modifiable parameter map.
*/
Map<String, String[]> getModifiableParameterMap();

/**
* {@return the multipartConfig}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ public AsyncHttpDispatchWrapper(HttpServletRequest request) {
wrapperAttributes.add("piranha.response");
}

@Override
public Map<String, String[]> getModifiableParameterMap() {
return wrapperParameters;
}

@Override
public HttpServletRequest getRequest() {
return (HttpServletRequest) super.getRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ private void dispatchSyncForward(
*/
WebApplicationRequest request = unwrap(servletRequest, WebApplicationRequest.class);

/**
/*
* Set the forward attributes.
*/
request.setAttribute(FORWARD_CONTEXT_PATH, request.getContextPath());
Expand All @@ -733,12 +733,12 @@ private void dispatchSyncForward(
request.setAttribute(FORWARD_REQUEST_URI, request.getRequestURI());
request.setAttribute(FORWARD_SERVLET_PATH, request.getServletPath());

/**
/*
* Set the new dispatcher type.
*/
request.setDispatcherType(FORWARD);

/**
/*
* Set the new servlet path and the new query string.
*/
if (path != null) {
Expand All @@ -749,6 +749,33 @@ private void dispatchSyncForward(
request.setQueryString(request.getQueryString());
}

/*
* Aggregate the request parameters with giving the new parameter values
* precedence over the old ones.
*/
if (request.getQueryString() != null) {
String queryString = request.getQueryString();
Map<String, String[]> parameters = request.getModifiableParameterMap();
for (String param : queryString.split("&")) {
String[] pair = param.split("=");
String key = URLDecoder.decode(pair[0], StandardCharsets.UTF_8);
String value = "";
if (pair.length > 1) {
value = URLDecoder.decode(pair[1], StandardCharsets.UTF_8);
}
String[] values = parameters.get(key);
if (values == null) {
values = new String[]{value};
parameters.put(key, values);
} else {
String[] newValues = new String[values.length + 1];
System.arraycopy(values, 0, newValues, 1, values.length);
newValues[0] = value;
parameters.put(key, newValues);
}
}
}

invocationFinder.addFilters(FORWARD, servletInvocation, request.getServletPath(), "");

if (servletInvocation.getServletEnvironment() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,11 @@ public String getMethod() {
return method;
}

@Override
public Map<String, String[]> getModifiableParameterMap() {
return parameters;
}

@Override
public MultipartConfigElement getMultipartConfig() {
return multipartConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
*/
package cloud.piranha.core.impl;

import cloud.piranha.core.api.WebApplication;
import java.io.IOException;
import jakarta.servlet.RequestDispatcher;
import jakarta.servlet.ServletException;
Expand Down Expand Up @@ -86,12 +87,12 @@ void testForward2() throws Exception {

TestWebApplicationRequest request = new TestWebApplicationRequest();
request.setWebApplication(webApplication);

TestWebApplicationResponse response = new TestWebApplicationResponse();
response.setWebApplication(webApplication);

webApplication.linkRequestAndResponse(request, response);

RequestDispatcher dispatcher = webApplication.getRequestDispatcher("/Snoop");
dispatcher.forward(request, response);
String responseText = new String(response.getResponseBytes());
Expand Down Expand Up @@ -136,34 +137,83 @@ void testForward4() throws Exception {
RequestDispatcher dispatcher = webApp.getRequestDispatcher("/Runtime");
assertThrows(RuntimeException.class, () -> dispatcher.forward(request, response));
}

/**
* Test that a request given to the request dispatcher upon forward is the
* same as the original request.
*/
@Test
void testForwardNoWrapping() throws Exception {
WebApplication webApplication = new DefaultWebApplicationBuilder()
.servlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
resp.getWriter().print(req.toString());
}
})
.servletMapping("NoWrapping", "/nowrapping")
.build();
webApplication.initialize();
webApplication.start();
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
DefaultWebApplicationResponse response = new DefaultWebApplicationResponse();
ByteArrayOutputStream byteOutput = new ByteArrayOutputStream();
response.getWebApplicationOutputStream().setOutputStream(byteOutput);
response.setBodyOnly(true);
DefaultWebApplication webApplication = new DefaultWebApplication();
webApplication.addServlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
response.getWriter().print(req.toString());
}
});
webApplication.addServletMapping("NoWrapping", "/nowrapping");
webApplication.initialize();
webApplication.start();
RequestDispatcher dispatcher = webApplication.getRequestDispatcher("/nowrapping");
assertNotNull(dispatcher);
dispatcher.forward(request, response);
assertEquals(request.toString(), byteOutput.toString("UTF-8"));
}

/**
* Test a forward with a query string.
*/
@Test
void testForwardWithQueryString() throws Exception {
WebApplication webApplication = new DefaultWebApplicationBuilder()
.servlet("Forward", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
/*
* Force the original parameter to be parsed out of the query string.
*/
req.getParameter("p");
/*
* Dispatch with a new parameter value.
*/
getServletContext()
.getRequestDispatcher("/forward2?p=New")
.forward(req, resp);
}
})
.servletMapping("Forward", "/forward")
.servlet("Forward2", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
/*
* We should be getting the new parameter value here as it takes precendence.
*/
resp.getWriter().print(req.getParameterMap().get("p")[0]);
}
})
.servletMapping("Forward2", "/forward2")
.build();
webApplication.initialize();
webApplication.start();
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
request.setWebApplication(webApplication);
request.setServletPath("/forward");
request.setQueryString("p=Original");
DefaultWebApplicationResponse response = new DefaultWebApplicationResponse();
ByteArrayOutputStream byteOutput = new ByteArrayOutputStream();
response.getWebApplicationOutputStream().setOutputStream(byteOutput);
response.setBodyOnly(true);
response.setWebApplication(webApplication);
webApplication.service(request, response);
assertEquals("New", byteOutput.toString());
}

/**
* Test include method.
*
Expand All @@ -178,12 +228,12 @@ void testInclude() throws Exception {

TestWebApplicationRequest request = new TestWebApplicationRequest();
request.setWebApplication(webApplication);

TestWebApplicationResponse response = new TestWebApplicationResponse();
response.setWebApplication(webApplication);

webApplication.linkRequestAndResponse(request, response);

RequestDispatcher dispatcher = webApplication.getNamedDispatcher("Echo");
dispatcher.include(request, response);
response.flushBuffer();
Expand Down Expand Up @@ -213,29 +263,30 @@ void testInclude2() throws Exception {
webApp.stop();
assertTrue(responseText.contains("ECHO"));
}

/**
* Test that a request given to the request dispatcher upon include is the
* same as the original request.
*/
@Test
void testIncludeNoWrapping() throws Exception {
DefaultWebApplication webApplication = new DefaultWebApplication();
webApplication.addServlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
req.getRequestDispatcher("/nowrapping2").include(req, resp);
resp.flushBuffer();
}
});
webApplication.addServlet("NoWrapping2", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
resp.getWriter().print(req.toString());
}
});
webApplication.addServletMapping("NoWrapping", "/nowrapping");
webApplication.addServletMapping("NoWrapping2", "/nowrapping2");
WebApplication webApplication = new DefaultWebApplicationBuilder()
.servlet("NoWrapping", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
req.getRequestDispatcher("/nowrapping2").include(req, resp);
resp.flushBuffer();
}
})
.servletMapping("NoWrapping", "/nowrapping")
.servlet("NoWrapping2", new HttpServlet() {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
resp.getWriter().print(req.toString());
}
})
.servletMapping("NoWrapping2", "/nowrapping2")
.build();
webApplication.initialize();
webApplication.start();
DefaultWebApplicationRequest request = new DefaultWebApplicationRequest();
Expand Down Expand Up @@ -296,8 +347,9 @@ void testErrorDispatcher2() throws Exception {
assertTrue(responseText.contains(RequestDispatcher.ERROR_EXCEPTION_TYPE));
assertTrue(responseText.contains(IOException.class.getName()));
}

static class TestSendError extends HttpServlet {

@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
if (request.getParameter("send-error") != null) {
Expand Down