diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequest.java index 2cf8f3ac2316..6d7bd180bfc5 100644 --- a/spring-web/src/main/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequest.java +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -112,13 +112,21 @@ private void parseRequest(HttpServletRequest request) { } protected void handleParseFailure(Throwable ex) { - String msg = ex.getMessage(); - if (msg != null) { - msg = msg.toLowerCase(); - if (msg.contains("size") && msg.contains("exceed")) { - throw new MaxUploadSizeExceededException(-1, ex); + // MaxUploadSizeExceededException ? + Throwable cause = ex; + do { + String msg = cause.getMessage(); + if (msg != null) { + msg = msg.toLowerCase(); + if (msg.contains("exceed") && (msg.contains("size") || msg.contains("length"))) { + throw new MaxUploadSizeExceededException(-1, ex); + } } + cause = cause.getCause(); } + while (cause != null); + + // General MultipartException throw new MultipartException("Failed to parse multipart servlet request", ex); } diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java index 5876b3d659cd..79f98200bab9 100644 --- a/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java @@ -18,23 +18,29 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collection; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.Part; import org.junit.jupiter.api.Test; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MaxUploadSizeExceededException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.testfixture.http.MockHttpOutputMessage; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockPart; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Unit tests for {@link StandardMultipartHttpServletRequest}. * * @author Rossen Stoyanchev + * @author Juergen Hoeller */ class StandardMultipartHttpServletRequestTests { @@ -92,6 +98,31 @@ void multipartFileResource() throws IOException { """.replace("\n", "\r\n")); } + @Test + void plainSizeExceededServletException() { + ServletException ex = new ServletException("Request size exceeded"); + + assertThatExceptionOfType(MaxUploadSizeExceededException.class) + .isThrownBy(() -> requestWithException(ex)).withCause(ex); + } + + @Test // gh-28759 + void jetty94MaxRequestSizeException() { + ServletException ex = new ServletException(new IllegalStateException("Request exceeds maxRequestSize")); + + assertThatExceptionOfType(MaxUploadSizeExceededException.class) + .isThrownBy(() -> requestWithException(ex)).withCause(ex); + } + + @Test // gh-31850 + void jetty12MaxLengthExceededException() { + ServletException ex = new ServletException(new RuntimeException("400: bad multipart", + new IllegalStateException("max length exceeded"))); + + assertThatExceptionOfType(MaxUploadSizeExceededException.class) + .isThrownBy(() -> requestWithException(ex)).withCause(ex); + } + private static StandardMultipartHttpServletRequest requestWithPart(String name, String disposition, String content) { MockHttpServletRequest request = new MockHttpServletRequest(); @@ -101,4 +132,14 @@ private static StandardMultipartHttpServletRequest requestWithPart(String name, return new StandardMultipartHttpServletRequest(request); } + private static StandardMultipartHttpServletRequest requestWithException(ServletException ex) { + MockHttpServletRequest request = new MockHttpServletRequest() { + @Override + public Collection getParts() throws ServletException { + throw ex; + } + }; + return new StandardMultipartHttpServletRequest(request); + } + }