diff --git a/servlet/src/main/java/io/undertow/servlet/core/ManagedServlet.java b/servlet/src/main/java/io/undertow/servlet/core/ManagedServlet.java index f2b0af8268..0742901c47 100644 --- a/servlet/src/main/java/io/undertow/servlet/core/ManagedServlet.java +++ b/servlet/src/main/java/io/undertow/servlet/core/ManagedServlet.java @@ -20,7 +20,9 @@ import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Date; import java.util.List; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; import javax.servlet.MultipartConfigElement; import javax.servlet.Servlet; @@ -61,6 +63,11 @@ public class ManagedServlet implements Lifecycle { private FormParserFactory formParserFactory; private MultipartConfigElement multipartConfig; + private static final AtomicLongFieldUpdater unavailableUntilUpdater = AtomicLongFieldUpdater.newUpdater(ManagedServlet.class, "unavailableUntil"); + + @SuppressWarnings("unused") + private volatile long unavailableUntil = 0; + public ManagedServlet(final ServletInfo servletInfo, final ServletContextImpl servletContext) { this.servletInfo = servletInfo; this.servletContext = servletContext; @@ -157,6 +164,18 @@ public boolean isPermanentlyUnavailable() { return permanentlyUnavailable; } + public boolean isTemporarilyUnavailable() { + long until = unavailableUntil; + if (until != 0) { + if (System.currentTimeMillis() < until) { + return true; + } else { + unavailableUntilUpdater.compareAndSet(this, until, 0); + } + } + return false; + } + public void setPermanentlyUnavailable(final boolean permanentlyUnavailable) { this.permanentlyUnavailable = permanentlyUnavailable; } @@ -184,13 +203,29 @@ public void forceInit() throws ServletException { } synchronized (this) { if (!started) { - instanceStrategy.start(); + try { + instanceStrategy.start(); + } catch (UnavailableException e) { + handleUnavailableException(e); + } started = true; } } } } + public void handleUnavailableException(UnavailableException e) { + if (e.isPermanent()) { + UndertowServletLogger.REQUEST_LOGGER.stoppingServletDueToPermanentUnavailability(getServletInfo().getName(), e); + stop(); + setPermanentlyUnavailable(true); + } else { + long until = System.currentTimeMillis() + e.getUnavailableSeconds() * 1000; + unavailableUntilUpdater.set(this, until); + UndertowServletLogger.REQUEST_LOGGER.stoppingServletUntilDueToTemporaryUnavailability(getServletInfo().getName(), new Date(until), e); + } + } + public ServletInfo getServletInfo() { return servletInfo; } diff --git a/servlet/src/main/java/io/undertow/servlet/handlers/ServletHandler.java b/servlet/src/main/java/io/undertow/servlet/handlers/ServletHandler.java index b9143e677c..44274f0125 100644 --- a/servlet/src/main/java/io/undertow/servlet/handlers/ServletHandler.java +++ b/servlet/src/main/java/io/undertow/servlet/handlers/ServletHandler.java @@ -19,8 +19,6 @@ package io.undertow.servlet.handlers; import java.io.IOException; -import java.util.Date; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; import javax.servlet.Servlet; import javax.servlet.ServletException; @@ -46,10 +44,6 @@ public class ServletHandler implements HttpHandler { private final ManagedServlet managedServlet; - private static final AtomicLongFieldUpdater unavailableUntilUpdater = AtomicLongFieldUpdater.newUpdater(ServletHandler.class, "unavailableUntil"); - - @SuppressWarnings("unused") - private volatile long unavailableUntil = 0; public ServletHandler(final ManagedServlet managedServlet) { this.managedServlet = managedServlet; @@ -63,15 +57,10 @@ public void handleRequest(final HttpServerExchange exchange) throws IOException, return; } - long until = unavailableUntil; - if (until != 0) { + if (managedServlet.isTemporarilyUnavailable()) { UndertowServletLogger.REQUEST_LOGGER.debugf("Returning 503 for servlet %s due to temporary unavailability", managedServlet.getServletInfo().getName()); - if (System.currentTimeMillis() < until) { - exchange.setStatusCode(StatusCodes.SERVICE_UNAVAILABLE); - return; - } else { - unavailableUntilUpdater.compareAndSet(this, until, 0); - } + exchange.setStatusCode(StatusCodes.SERVICE_UNAVAILABLE); + return; } final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY); if(!managedServlet.getServletInfo().isAsyncSupported()) { @@ -95,14 +84,10 @@ public void handleRequest(final HttpServerExchange exchange) throws IOException, // } //} } catch (UnavailableException e) { + managedServlet.handleUnavailableException(e); if (e.isPermanent()) { - UndertowServletLogger.REQUEST_LOGGER.stoppingServletDueToPermanentUnavailability(managedServlet.getServletInfo().getName(), e); - managedServlet.stop(); - managedServlet.setPermanentlyUnavailable(true); exchange.setStatusCode(StatusCodes.NOT_FOUND); } else { - unavailableUntilUpdater.set(this, System.currentTimeMillis() + e.getUnavailableSeconds() * 1000); - UndertowServletLogger.REQUEST_LOGGER.stoppingServletUntilDueToTemporaryUnavailability(managedServlet.getServletInfo().getName(), new Date(until), e); exchange.setStatusCode(StatusCodes.SERVICE_UNAVAILABLE); } } finally { diff --git a/servlet/src/test/java/io/undertow/servlet/test/spec/UnavailableServlet.java b/servlet/src/test/java/io/undertow/servlet/test/spec/UnavailableServlet.java new file mode 100644 index 0000000000..98e1b4a11a --- /dev/null +++ b/servlet/src/test/java/io/undertow/servlet/test/spec/UnavailableServlet.java @@ -0,0 +1,48 @@ +package io.undertow.servlet.test.spec; + +import javax.servlet.Servlet; +import javax.servlet.ServletConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.UnavailableException; +import java.io.IOException; + +/** + * @author Stuart Douglas + */ +public class UnavailableServlet implements Servlet { + + static final String PERMANENT = "permanent"; + static boolean first = true; + + @Override + public void init(ServletConfig config) throws ServletException { + if(config.getInitParameter(PERMANENT) != null) { + throw new UnavailableException("msg"); + } else if(first){ + first = false; + throw new UnavailableException("msg", 1); + } + } + + @Override + public ServletConfig getServletConfig() { + return null; + } + + @Override + public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException { + + } + + @Override + public String getServletInfo() { + return null; + } + + @Override + public void destroy() { + + } +} diff --git a/servlet/src/test/java/io/undertow/servlet/test/spec/UnavailableServletTestCase.java b/servlet/src/test/java/io/undertow/servlet/test/spec/UnavailableServletTestCase.java new file mode 100644 index 0000000000..2f739ffd29 --- /dev/null +++ b/servlet/src/test/java/io/undertow/servlet/test/spec/UnavailableServletTestCase.java @@ -0,0 +1,71 @@ +package io.undertow.servlet.test.spec; + +import io.undertow.servlet.test.util.DeploymentUtils; +import io.undertow.testutils.DefaultServer; +import io.undertow.testutils.HttpClientUtils; +import io.undertow.testutils.TestHttpClient; +import io.undertow.util.StatusCodes; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; + +import javax.servlet.ServletException; + +import java.io.IOException; + +import static io.undertow.servlet.Servlets.servlet; + +/** + * @author Stuart Douglas + */ +@RunWith(DefaultServer.class) +public class UnavailableServletTestCase { + + + @BeforeClass + public static void setup() throws ServletException { + DeploymentUtils.setupServlet( + servlet("p", UnavailableServlet.class) + .addInitParam(UnavailableServlet.PERMANENT, "1") + .addMapping("/p"), + servlet("t", UnavailableServlet.class) + .addMapping("/t")); + + } + + @Test + public void testPermanentUnavailableServlet() throws IOException { + TestHttpClient client = new TestHttpClient(); + try { + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/p"); + HttpResponse result = client.execute(get); + Assert.assertEquals(StatusCodes.NOT_FOUND, result.getStatusLine().getStatusCode()); + HttpClientUtils.readResponse(result); + } finally { + client.getConnectionManager().shutdown(); + } + } + + + @Test + public void testTempUnavailableServlet() throws IOException, InterruptedException { + TestHttpClient client = new TestHttpClient(); + try { + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/t"); + HttpResponse result = client.execute(get); + Assert.assertEquals(StatusCodes.SERVICE_UNAVAILABLE, result.getStatusLine().getStatusCode()); + HttpClientUtils.readResponse(result); + Thread.sleep(1001); + + get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/t"); + result = client.execute(get); + Assert.assertEquals(StatusCodes.OK, result.getStatusLine().getStatusCode()); + HttpClientUtils.readResponse(result); + } finally { + client.getConnectionManager().shutdown(); + } + } +}