diff --git a/core/api/src/main/java/cloud/piranha/core/api/WebApplicationRequest.java b/core/api/src/main/java/cloud/piranha/core/api/WebApplicationRequest.java index 3d45cbef6..b2c707de7 100644 --- a/core/api/src/main/java/cloud/piranha/core/api/WebApplicationRequest.java +++ b/core/api/src/main/java/cloud/piranha/core/api/WebApplicationRequest.java @@ -32,6 +32,7 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpUpgradeHandler; import java.security.Principal; +import java.util.Map; /** * The WebApplicationRequest API. @@ -40,6 +41,13 @@ */ public interface WebApplicationRequest extends HttpServletRequest { + /** + * Get the modifiable parameter map. + * + * @return the modifiable parameter map. + */ + Map getModifiableParameterMap(); + /** * {@return the multipartConfig} */ diff --git a/core/impl/src/main/java/cloud/piranha/core/impl/AsyncHttpDispatchWrapper.java b/core/impl/src/main/java/cloud/piranha/core/impl/AsyncHttpDispatchWrapper.java index e04589dd9..c5f1d2fdc 100644 --- a/core/impl/src/main/java/cloud/piranha/core/impl/AsyncHttpDispatchWrapper.java +++ b/core/impl/src/main/java/cloud/piranha/core/impl/AsyncHttpDispatchWrapper.java @@ -112,6 +112,11 @@ public AsyncHttpDispatchWrapper(HttpServletRequest request) { wrapperAttributes.add("piranha.response"); } + @Override + public Map getModifiableParameterMap() { + return wrapperParameters; + } + @Override public HttpServletRequest getRequest() { return (HttpServletRequest) super.getRequest(); diff --git a/core/impl/src/main/java/cloud/piranha/core/impl/DefaultServletRequestDispatcher.java b/core/impl/src/main/java/cloud/piranha/core/impl/DefaultServletRequestDispatcher.java index 6c44951e7..dea6a8957 100644 --- a/core/impl/src/main/java/cloud/piranha/core/impl/DefaultServletRequestDispatcher.java +++ b/core/impl/src/main/java/cloud/piranha/core/impl/DefaultServletRequestDispatcher.java @@ -724,7 +724,7 @@ private void dispatchSyncForward( */ WebApplicationRequest request = unwrap(servletRequest, WebApplicationRequest.class); - /** + /* * Set the forward attributes. */ request.setAttribute(FORWARD_CONTEXT_PATH, request.getContextPath()); @@ -733,12 +733,12 @@ private void dispatchSyncForward( request.setAttribute(FORWARD_REQUEST_URI, request.getRequestURI()); request.setAttribute(FORWARD_SERVLET_PATH, request.getServletPath()); - /** + /* * Set the new dispatcher type. */ request.setDispatcherType(FORWARD); - /** + /* * Set the new servlet path and the new query string. */ if (path != null) { @@ -749,6 +749,33 @@ private void dispatchSyncForward( request.setQueryString(request.getQueryString()); } + /* + * Aggregate the request parameters with giving the new parameter values + * precedence over the old ones. + */ + if (request.getQueryString() != null) { + String queryString = request.getQueryString(); + Map parameters = request.getModifiableParameterMap(); + for (String param : queryString.split("&")) { + String[] pair = param.split("="); + String key = URLDecoder.decode(pair[0], StandardCharsets.UTF_8); + String value = ""; + if (pair.length > 1) { + value = URLDecoder.decode(pair[1], StandardCharsets.UTF_8); + } + String[] values = parameters.get(key); + if (values == null) { + values = new String[]{value}; + parameters.put(key, values); + } else { + String[] newValues = new String[values.length + 1]; + System.arraycopy(values, 0, newValues, 1, values.length); + newValues[0] = value; + parameters.put(key, newValues); + } + } + } + invocationFinder.addFilters(FORWARD, servletInvocation, request.getServletPath(), ""); if (servletInvocation.getServletEnvironment() != null) { diff --git a/core/impl/src/main/java/cloud/piranha/core/impl/DefaultWebApplicationRequest.java b/core/impl/src/main/java/cloud/piranha/core/impl/DefaultWebApplicationRequest.java index 892cb7519..d800ddd9f 100644 --- a/core/impl/src/main/java/cloud/piranha/core/impl/DefaultWebApplicationRequest.java +++ b/core/impl/src/main/java/cloud/piranha/core/impl/DefaultWebApplicationRequest.java @@ -600,6 +600,11 @@ public String getMethod() { return method; } + @Override + public Map getModifiableParameterMap() { + return parameters; + } + @Override public MultipartConfigElement getMultipartConfig() { return multipartConfig; diff --git a/core/impl/src/test/java/cloud/piranha/core/impl/DefaultServletRequestDispatcherTest.java b/core/impl/src/test/java/cloud/piranha/core/impl/DefaultServletRequestDispatcherTest.java index b9e57f038..d9c5c7b7e 100644 --- a/core/impl/src/test/java/cloud/piranha/core/impl/DefaultServletRequestDispatcherTest.java +++ b/core/impl/src/test/java/cloud/piranha/core/impl/DefaultServletRequestDispatcherTest.java @@ -27,6 +27,7 @@ */ package cloud.piranha.core.impl; +import cloud.piranha.core.api.WebApplication; import java.io.IOException; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletException; @@ -86,12 +87,12 @@ void testForward2() throws Exception { TestWebApplicationRequest request = new TestWebApplicationRequest(); request.setWebApplication(webApplication); - + TestWebApplicationResponse response = new TestWebApplicationResponse(); response.setWebApplication(webApplication); - + webApplication.linkRequestAndResponse(request, response); - + RequestDispatcher dispatcher = webApplication.getRequestDispatcher("/Snoop"); dispatcher.forward(request, response); String responseText = new String(response.getResponseBytes()); @@ -136,34 +137,83 @@ void testForward4() throws Exception { RequestDispatcher dispatcher = webApp.getRequestDispatcher("/Runtime"); assertThrows(RuntimeException.class, () -> dispatcher.forward(request, response)); } - + /** * Test that a request given to the request dispatcher upon forward is the * same as the original request. */ @Test void testForwardNoWrapping() throws Exception { + WebApplication webApplication = new DefaultWebApplicationBuilder() + .servlet("NoWrapping", new HttpServlet() { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + resp.getWriter().print(req.toString()); + } + }) + .servletMapping("NoWrapping", "/nowrapping") + .build(); + webApplication.initialize(); + webApplication.start(); DefaultWebApplicationRequest request = new DefaultWebApplicationRequest(); DefaultWebApplicationResponse response = new DefaultWebApplicationResponse(); ByteArrayOutputStream byteOutput = new ByteArrayOutputStream(); response.getWebApplicationOutputStream().setOutputStream(byteOutput); response.setBodyOnly(true); - DefaultWebApplication webApplication = new DefaultWebApplication(); - webApplication.addServlet("NoWrapping", new HttpServlet() { - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - response.getWriter().print(req.toString()); - } - }); - webApplication.addServletMapping("NoWrapping", "/nowrapping"); - webApplication.initialize(); - webApplication.start(); RequestDispatcher dispatcher = webApplication.getRequestDispatcher("/nowrapping"); assertNotNull(dispatcher); dispatcher.forward(request, response); assertEquals(request.toString(), byteOutput.toString("UTF-8")); } + /** + * Test a forward with a query string. + */ + @Test + void testForwardWithQueryString() throws Exception { + WebApplication webApplication = new DefaultWebApplicationBuilder() + .servlet("Forward", new HttpServlet() { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + /* + * Force the original parameter to be parsed out of the query string. + */ + req.getParameter("p"); + /* + * Dispatch with a new parameter value. + */ + getServletContext() + .getRequestDispatcher("/forward2?p=New") + .forward(req, resp); + } + }) + .servletMapping("Forward", "/forward") + .servlet("Forward2", new HttpServlet() { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + /* + * We should be getting the new parameter value here as it takes precendence. + */ + resp.getWriter().print(req.getParameterMap().get("p")[0]); + } + }) + .servletMapping("Forward2", "/forward2") + .build(); + webApplication.initialize(); + webApplication.start(); + DefaultWebApplicationRequest request = new DefaultWebApplicationRequest(); + request.setWebApplication(webApplication); + request.setServletPath("/forward"); + request.setQueryString("p=Original"); + DefaultWebApplicationResponse response = new DefaultWebApplicationResponse(); + ByteArrayOutputStream byteOutput = new ByteArrayOutputStream(); + response.getWebApplicationOutputStream().setOutputStream(byteOutput); + response.setBodyOnly(true); + response.setWebApplication(webApplication); + webApplication.service(request, response); + assertEquals("New", byteOutput.toString()); + } + /** * Test include method. * @@ -178,12 +228,12 @@ void testInclude() throws Exception { TestWebApplicationRequest request = new TestWebApplicationRequest(); request.setWebApplication(webApplication); - + TestWebApplicationResponse response = new TestWebApplicationResponse(); response.setWebApplication(webApplication); webApplication.linkRequestAndResponse(request, response); - + RequestDispatcher dispatcher = webApplication.getNamedDispatcher("Echo"); dispatcher.include(request, response); response.flushBuffer(); @@ -213,29 +263,30 @@ void testInclude2() throws Exception { webApp.stop(); assertTrue(responseText.contains("ECHO")); } - + /** * Test that a request given to the request dispatcher upon include is the * same as the original request. */ @Test void testIncludeNoWrapping() throws Exception { - DefaultWebApplication webApplication = new DefaultWebApplication(); - webApplication.addServlet("NoWrapping", new HttpServlet() { - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - req.getRequestDispatcher("/nowrapping2").include(req, resp); - resp.flushBuffer(); - } - }); - webApplication.addServlet("NoWrapping2", new HttpServlet() { - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - resp.getWriter().print(req.toString()); - } - }); - webApplication.addServletMapping("NoWrapping", "/nowrapping"); - webApplication.addServletMapping("NoWrapping2", "/nowrapping2"); + WebApplication webApplication = new DefaultWebApplicationBuilder() + .servlet("NoWrapping", new HttpServlet() { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + req.getRequestDispatcher("/nowrapping2").include(req, resp); + resp.flushBuffer(); + } + }) + .servletMapping("NoWrapping", "/nowrapping") + .servlet("NoWrapping2", new HttpServlet() { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + resp.getWriter().print(req.toString()); + } + }) + .servletMapping("NoWrapping2", "/nowrapping2") + .build(); webApplication.initialize(); webApplication.start(); DefaultWebApplicationRequest request = new DefaultWebApplicationRequest(); @@ -296,8 +347,9 @@ void testErrorDispatcher2() throws Exception { assertTrue(responseText.contains(RequestDispatcher.ERROR_EXCEPTION_TYPE)); assertTrue(responseText.contains(IOException.class.getName())); } - + static class TestSendError extends HttpServlet { + @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { if (request.getParameter("send-error") != null) {