Skip to content

Commit

Permalink
Add state and response wrapping to StandardServletAsyncWebRequest
Browse files Browse the repository at this point in the history
The wrapped response prevents use after AsyncListener onError or completion
to ensure compliance with Servlet Spec 2.3.3.4.

See gh-32342
  • Loading branch information
rstoyanchev committed Mar 1, 2024
1 parent 3478a70 commit 6432b13
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 21 deletions.
@@ -0,0 +1,44 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.context.request.async;

import java.io.IOException;

/**
* Raised when the response for an asynchronous request becomes unusable as
* indicated by a write failure, or a Servlet container error notification, or
* after the async request has completed.
*
* <p>The exception relies on response wrapping, and on {@code AsyncListener}
* notifications, managed by {@link StandardServletAsyncWebRequest}.
*
* @author Rossen Stoyanchev
* @since 5.3.33
*/
@SuppressWarnings("serial")
public class AsyncRequestNotUsableException extends IOException {


public AsyncRequestNotUsableException(String message) {
super(message);
}

public AsyncRequestNotUsableException(String message, Throwable cause) {
super(message, cause);
}

}
Expand Up @@ -19,14 +19,17 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;

import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand All @@ -45,8 +48,6 @@
*/
public class StandardServletAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest, AsyncListener {

private final AtomicBoolean asyncCompleted = new AtomicBoolean();

private final List<Runnable> timeoutHandlers = new ArrayList<>();

private final List<Consumer<Throwable>> exceptionHandlers = new ArrayList<>();
Expand All @@ -59,14 +60,43 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Nullable
private AsyncContext asyncContext;

private final AtomicReference<State> state;

private volatile boolean hasError;


/**
* Create a new instance for the given request/response pair.
* @param request current HTTP request
* @param response current HTTP response
*/
public StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
super(request, response);
this(request, response, null);
}

/**
* Constructor to wrap the request and response for the current dispatch that
* also picks up the state of the last (probably the REQUEST) dispatch.
* @param request current HTTP request
* @param response current HTTP response
* @param previousRequest the existing request from the last dispatch
* @since 5.3.33
*/
StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response,
@Nullable StandardServletAsyncWebRequest previousRequest) {

super(request, new LifecycleHttpServletResponse(response));

if (previousRequest != null) {
this.state = previousRequest.state;
this.hasError = previousRequest.hasError;
}
else {
this.state = new AtomicReference<>(State.ACTIVE);
}

//noinspection DataFlowIssue
((LifecycleHttpServletResponse) getResponse()).setParent(this);
}


Expand Down Expand Up @@ -107,7 +137,7 @@ public boolean isAsyncStarted() {
*/
@Override
public boolean isAsyncComplete() {
return this.asyncCompleted.get();
return (this.state.get() == State.COMPLETED);
}

@Override
Expand All @@ -117,6 +147,7 @@ public void startAsync() {
"in async request processing. This is done in Java code using the Servlet API " +
"or by adding \"<async-supported>true</async-supported>\" to servlet and " +
"filter declarations in web.xml.");

Assert.state(!isAsyncComplete(), "Async processing has already completed");

if (isAsyncStarted()) {
Expand All @@ -131,8 +162,10 @@ public void startAsync() {

@Override
public void dispatch() {
Assert.state(this.asyncContext != null, "Cannot dispatch without an AsyncContext");
this.asyncContext.dispatch();
Assert.state(this.asyncContext != null, "AsyncContext not yet initialized");
if (!this.isAsyncComplete()) {
this.asyncContext.dispatch();
}
}


Expand All @@ -151,14 +184,152 @@ public void onTimeout(AsyncEvent event) throws IOException {

@Override
public void onError(AsyncEvent event) throws IOException {
transitionToErrorState();
this.exceptionHandlers.forEach(consumer -> consumer.accept(event.getThrowable()));
}

private void transitionToErrorState() {
this.hasError = true;
this.state.compareAndSet(State.ACTIVE, State.ERROR);
}

@Override
public void onComplete(AsyncEvent event) throws IOException {
this.completionHandlers.forEach(Runnable::run);
this.asyncContext = null;
this.asyncCompleted.set(true);
this.state.set(State.COMPLETED);
}


/**
* Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}.
*/
private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {

@Nullable
private StandardServletAsyncWebRequest parent;

private ServletOutputStream outputStream;

public LifecycleHttpServletResponse(HttpServletResponse response) {
super(response);
}

public void setParent(StandardServletAsyncWebRequest parent) {
this.parent = parent;
}

@Override
public ServletOutputStream getOutputStream() {
if (this.outputStream == null) {
Assert.notNull(this.parent, "Not initialized");
this.outputStream = new LifecycleServletOutputStream((HttpServletResponse) getResponse(), this.parent);
}
return this.outputStream;
}
}


/**
* Wraps a ServletOutputStream to prevent use after Servlet container onError
* notifications, and after async request completion.
*/
private static final class LifecycleServletOutputStream extends ServletOutputStream {

private final HttpServletResponse response;

private final StandardServletAsyncWebRequest parent;

private LifecycleServletOutputStream(HttpServletResponse response, StandardServletAsyncWebRequest parent) {
this.response = response;
this.parent = parent;
}

@Override
public boolean isReady() {
return false;
}

@Override
public void setWriteListener(WriteListener writeListener) {
}

@Override
public void write(int b) throws IOException {
checkState();
try {
this.response.getOutputStream().write(b);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
}
}

public void write(byte[] buf, int offset, int len) throws IOException {
checkState();
try {
this.response.getOutputStream().write(buf, offset, len);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
}
}

@Override
public void flush() throws IOException {
checkState();
try {
this.response.getOutputStream().flush();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to flush");
}
}

@Override
public void close() throws IOException {
checkState();
try {
this.response.getOutputStream().close();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to close");
}
}

private void checkState() throws AsyncRequestNotUsableException {
if (this.parent.state.get() != State.ACTIVE) {
String reason = this.parent.state.get() == State.COMPLETED ?
"async request completion" : "Servlet container onError notification";
throw new AsyncRequestNotUsableException("Response not usable after " + reason + ".");
}
}

private void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
this.parent.transitionToErrorState();
throw new AsyncRequestNotUsableException(msg, ex);
}

}


/**
* Represents a state for {@link StandardServletAsyncWebRequest} to be in.
* <p><pre>
* ACTIVE ----+
* | |
* v |
* ERROR |
* | |
* v |
* COMPLETED <--+
* </pre>
* @since 5.3.33
*/
private enum State {

ACTIVE, ERROR, COMPLETED

}

}
Expand Up @@ -132,6 +132,15 @@ public void setAsyncWebRequest(AsyncWebRequest asyncWebRequest) {
WebAsyncUtils.WEB_ASYNC_MANAGER_ATTRIBUTE, RequestAttributes.SCOPE_REQUEST));
}

/**
* Return the current {@link AsyncWebRequest}.
* @since 5.3.33
*/
@Nullable
public AsyncWebRequest getAsyncWebRequest() {
return this.asyncWebRequest;
}

/**
* Configure an AsyncTaskExecutor for use with concurrent processing via
* {@link #startCallableProcessing(Callable, Object...)}.
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -82,7 +82,10 @@ public static WebAsyncManager getAsyncManager(WebRequest webRequest) {
* @return an AsyncWebRequest instance (never {@code null})
*/
public static AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
return new StandardServletAsyncWebRequest(request, response);
AsyncWebRequest prev = getAsyncManager(request).getAsyncWebRequest();
return (prev instanceof StandardServletAsyncWebRequest ?
new StandardServletAsyncWebRequest(request, response, (StandardServletAsyncWebRequest) prev) :
new StandardServletAsyncWebRequest(request, response));
}

}
Expand Up @@ -41,7 +41,8 @@ public class DisconnectedClientHelper {
new HashSet<>(Arrays.asList("broken pipe", "connection reset by peer"));

private static final Set<String> EXCEPTION_TYPE_NAMES =
new HashSet<>(Arrays.asList("AbortedException", "ClientAbortException", "EOFException", "EofException"));
new HashSet<>(Arrays.asList("AbortedException", "ClientAbortException",
"EOFException", "EofException", "AsyncRequestNotUsableException"));

private final Log logger;

Expand Down
Expand Up @@ -853,7 +853,21 @@ private SessionAttributesHandler getSessionAttributesHandler(HandlerMethod handl
protected ModelAndView invokeHandlerMethod(HttpServletRequest request,
HttpServletResponse response, HandlerMethod handlerMethod) throws Exception {

ServletWebRequest webRequest = new ServletWebRequest(request, response);
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
asyncWebRequest.setTimeout(this.asyncRequestTimeout);

asyncManager.setTaskExecutor(this.taskExecutor);
asyncManager.setAsyncWebRequest(asyncWebRequest);
asyncManager.registerCallableInterceptors(this.callableInterceptors);
asyncManager.registerDeferredResultInterceptors(this.deferredResultInterceptors);

// Obtain wrapped response to enforce lifecycle rule from Servlet spec, section 2.3.3.4
response = asyncWebRequest.getNativeResponse(HttpServletResponse.class);

ServletWebRequest webRequest = (asyncWebRequest instanceof ServletWebRequest ?
(ServletWebRequest) asyncWebRequest : new ServletWebRequest(request, response));

try {
WebDataBinderFactory binderFactory = getDataBinderFactory(handlerMethod);
ModelFactory modelFactory = getModelFactory(handlerMethod, binderFactory);
Expand All @@ -873,15 +887,6 @@ protected ModelAndView invokeHandlerMethod(HttpServletRequest request,
modelFactory.initModel(webRequest, mavContainer, invocableMethod);
mavContainer.setIgnoreDefaultModelOnRedirect(this.ignoreDefaultModelOnRedirect);

AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
asyncWebRequest.setTimeout(this.asyncRequestTimeout);

WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
asyncManager.setTaskExecutor(this.taskExecutor);
asyncManager.setAsyncWebRequest(asyncWebRequest);
asyncManager.registerCallableInterceptors(this.callableInterceptors);
asyncManager.registerDeferredResultInterceptors(this.deferredResultInterceptors);

if (asyncManager.hasConcurrentResult()) {
Object result = asyncManager.getConcurrentResult();
Object[] resultContext = asyncManager.getConcurrentResultContext();
Expand Down

0 comments on commit 6432b13

Please sign in to comment.