Skip to content

Commit

Permalink
[RESTEASY-1721] ResourceMethodInvoker checks if @context HttpServletR…
Browse files Browse the repository at this point in the history
…esponse was written to. (#2752)
  • Loading branch information
ronsigal committed Apr 14, 2021
1 parent 8a80aab commit abfcd15
Show file tree
Hide file tree
Showing 4 changed files with 719 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
import org.jboss.resteasy.spi.ValueInjector;
import org.jboss.resteasy.spi.util.Types;

import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.ws.rs.container.ResourceInfo;
import javax.ws.rs.core.Application;
import javax.ws.rs.ext.Providers;
import javax.ws.rs.sse.Sse;
import javax.ws.rs.sse.SseEventSink;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
Expand All @@ -22,7 +26,9 @@
import java.lang.reflect.Type;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletionStage;

/**
Expand All @@ -32,13 +38,14 @@
@SuppressWarnings("unchecked")
public class ContextParameterInjector implements ValueInjector
{
private Class rawType;
private Class proxy;
private Class<?> rawType;
private Class<?> proxy;
private ResteasyProviderFactory factory;
private Type genericType;
private Annotation[] annotations;
private volatile boolean outputStreamWasWritten = false;

public ContextParameterInjector(final Class proxy, final Class rawType, final Type genericType, final Annotation[] annotations, final ResteasyProviderFactory factory)
public ContextParameterInjector(final Class<?> proxy, final Class<?> rawType, final Type genericType, final Annotation[] annotations, final ResteasyProviderFactory factory)
{
this.rawType = rawType;
this.genericType = genericType;
Expand All @@ -61,7 +68,7 @@ else if (rawType.equals(Sse.class))
{
return new SseImpl();
} else if (rawType == CompletionStage.class) {
return new CompletionStageHolder((CompletionStage)createProxy());
return new CompletionStageHolder((CompletionStage<?>)createProxy());
}
return createProxy();
}
Expand Down Expand Up @@ -94,7 +101,7 @@ private Object unwrapIfRequired(HttpRequest request, Object contextData, boolean
}
return (CompletionStage<Object>) contextData;
} else if (rawType == CompletionStage.class && contextData instanceof CompletionStage) {
return new CompletionStageHolder((CompletionStage)contextData);
return new CompletionStageHolder((CompletionStage<?>)contextData);
} else if (!unwrapAsync && rawType != CompletionStage.class && contextData instanceof CompletionStage) {
throw new LoggableFailure(Messages.MESSAGES.shouldBeUnreachable());
}
Expand Down Expand Up @@ -123,6 +130,15 @@ public Object invoke(Object o, Method method, Object[] objects) throws Throwable
}
throw new LoggableFailure(Messages.MESSAGES.unableToFindContextualData(rawType.getName()));
}
// Fix for RESTEASY-1721
if ("javax.servlet.http.HttpServletResponse".equals(rawType.getName()))
{
if ("getOutputStream".equals(method.getName()))
{
ServletOutputStream sos = (ServletOutputStream) method.invoke(delegate, objects);
return new ContextOutputStream(sos);
}
}
return method.invoke(delegate, objects);
}
catch (IllegalAccessException e)
Expand Down Expand Up @@ -158,11 +174,11 @@ else if (!rawType.isInterface())
if (delegate != null) return unwrapIfRequired(null, delegate, unwrapAsync);
else throw new RuntimeException(Messages.MESSAGES.illegalToInjectNonInterfaceType());
} else if (rawType == CompletionStage.class) {
return new CompletionStageHolder((CompletionStage)createProxy());
return new CompletionStageHolder((CompletionStage<?>)createProxy());
}

return createProxy();
}
}

protected Object createProxy()
{
Expand All @@ -179,20 +195,195 @@ protected Object createProxy()
}
else
{
Class[] intfs = {rawType};
Object delegate = factory.getContextData(rawType, genericType, annotations, false);
Class<?>[] intfs = computeInterfaces(delegate, rawType);
ClassLoader clazzLoader = null;
final SecurityManager sm = System.getSecurityManager();
if (sm == null) {
clazzLoader = rawType.getClassLoader();
clazzLoader = delegate == null ? rawType.getClassLoader() : delegate.getClass().getClassLoader();
} else {
clazzLoader = AccessController.doPrivileged(new PrivilegedAction<ClassLoader>() {
@Override
public ClassLoader run() {
return rawType.getClassLoader();
return delegate == null ? rawType.getClassLoader() : delegate.getClass().getClassLoader();
}
});
}
return Proxy.newProxyInstance(clazzLoader, intfs, new GenericDelegatingProxy());
}
}

boolean outputStreamWasWrittenTo()
{
return outputStreamWasWritten;
}

protected Class<?>[] computeInterfaces(Object delegate, Class<?> cls)
{
Set<Class<?>> set = new HashSet<>();
set.add(cls);
if (delegate != null)
{
Class<?> delegateClass = delegate.getClass();
while (delegateClass != null)
{
for (Class<?> intf : delegateClass.getInterfaces())
{
set.add(intf);
for (Class<?> superIntf : intf.getInterfaces())
{
set.add(superIntf);
}
}
delegateClass = delegateClass.getSuperclass();
}
}
return set.toArray(new Class<?>[]{});
}

private final class ContextOutputStream extends ServletOutputStream
{
private ServletOutputStream delegate;

ContextOutputStream(final ServletOutputStream delegate)
{
this.delegate = delegate;
}

////////////////////////////////////////////////////////////////
/// ServletOutputStream methods
////////////////////////////////////////////////////////////////
@Override
public void print(String s) throws IOException
{
delegate.print(s);
outputStreamWasWritten = true;
}
@Override
public void print(boolean b) throws IOException
{
delegate.print(b);
outputStreamWasWritten = true;
}
@Override
public void print(char c) throws IOException
{
delegate.print(c);
outputStreamWasWritten = true;
}
@Override
public void print(int i) throws IOException
{
delegate.print(i);
outputStreamWasWritten = true;
}
@Override
public void print(long l) throws IOException
{
delegate.print(l);
outputStreamWasWritten = true;
}
@Override
public void print(float f) throws IOException
{
delegate.print(f);
outputStreamWasWritten = true;
}
@Override
public void print(double d) throws IOException
{
delegate.print(d);
outputStreamWasWritten = true;
}
@Override
public void println() throws IOException
{
delegate.println();
outputStreamWasWritten = true;
}
@Override
public void println(String s) throws IOException
{
delegate.println(s);
outputStreamWasWritten = true;
}
@Override
public void println(boolean b) throws IOException
{
delegate.println(b);
outputStreamWasWritten = true;
}
@Override
public void println(char c) throws IOException
{
delegate.print(c);
outputStreamWasWritten = true;
}
@Override
public void println(int i) throws IOException
{
delegate.println(i);
outputStreamWasWritten = true;
}
@Override
public void println(long l) throws IOException
{
delegate.println(l);
outputStreamWasWritten = true;
}
@Override
public void println(float f) throws IOException
{
delegate.println(f);
outputStreamWasWritten = true;
}
@Override
public void println(double d) throws IOException
{
delegate.println(d);
outputStreamWasWritten = true;
}
@Override
public boolean isReady()
{
return delegate.isReady();
}
@Override
public void setWriteListener(WriteListener writeListener)
{
delegate.setWriteListener(writeListener);
}

////////////////////////////////////////////////////////////////
/// OutputStream methods
////////////////////////////////////////////////////////////////
@Override
public void write(byte[] b) throws IOException
{
delegate.write(b);
outputStreamWasWritten = true;
}
@Override
public void write(byte[] b, int off, int len) throws IOException
{
delegate.write(b, off, len);
outputStreamWasWritten = true;
}
@Override
public void write(int b) throws IOException
{
delegate.write(b);
outputStreamWasWritten = true;
}
@Override
public void flush() throws IOException
{
delegate.flush();
}
@Override
public void close() throws IOException
{
delegate.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ private BuiltResponse afterInvoke(HttpRequest request, AsyncResponseConsumer asy
{
return null;
}
if (rtn == null || method.getReturnType().equals(void.class))
if (!contextOutputStreamWrittenTo() && (rtn == null || method.getReturnType().equals(void.class)))
{
BuiltResponse build = (BuiltResponse) Response.noContent().build();
build.addMethodAnnotations(getMethodAnnotations());
Expand Down Expand Up @@ -808,7 +808,22 @@ protected MediaType resolveContentTypeByAccept(List<MediaType> accepts, Object e
return MediaType.WILDCARD_TYPE;
}


/**
* Checks if any bytes were written to a @Context HttpServletResponse
* @see ContextParameterInjector for details
* Fix for RESTEASY-1721
*/
private boolean contextOutputStreamWrittenTo()
{
for (ValueInjector vi : methodInjector.getParams())
{
if (vi instanceof ContextParameterInjector)
{
return ((ContextParameterInjector) vi).outputStreamWasWrittenTo();
}
}
return false;
}


public Set<String> getHttpMethods()
Expand Down
Loading

0 comments on commit abfcd15

Please sign in to comment.