diff --git a/zxingorg/src/main/java/com/google/zxing/web/ChartDoSFilter.java b/zxingorg/src/main/java/com/google/zxing/web/ChartDoSFilter.java new file mode 100644 index 0000000000..b1d5d569de --- /dev/null +++ b/zxingorg/src/main/java/com/google/zxing/web/ChartDoSFilter.java @@ -0,0 +1,32 @@ +/* + * Copyright 2019 ZXing authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.zxing.web; + +import javax.servlet.annotation.WebFilter; +import javax.servlet.annotation.WebInitParam; + +/** + * Protect the /chart endpoint from too many requests. + */ +@WebFilter(urlPatterns = {"/w/chart"}, initParams = { + @WebInitParam(name = "maxAccessPerTime", value = "250"), + @WebInitParam(name = "accessTimeSec", value = "500"), + @WebInitParam(name = "maxEntries", value = "10000") +}) +public final class ChartDoSFilter extends DoSFilter { + // no additional implementation +} diff --git a/zxingorg/src/main/java/com/google/zxing/web/DecodeDoSFilter.java b/zxingorg/src/main/java/com/google/zxing/web/DecodeDoSFilter.java new file mode 100644 index 0000000000..d09172c863 --- /dev/null +++ b/zxingorg/src/main/java/com/google/zxing/web/DecodeDoSFilter.java @@ -0,0 +1,32 @@ +/* + * Copyright 2019 ZXing authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.zxing.web; + +import javax.servlet.annotation.WebFilter; +import javax.servlet.annotation.WebInitParam; + +/** + * Protect the /decode endpoint from too many requests. + */ +@WebFilter(urlPatterns = {"/w/decode"}, initParams = { + @WebInitParam(name = "maxAccessPerTime", value = "60"), + @WebInitParam(name = "accessTimeSec", value = "180"), + @WebInitParam(name = "maxEntries", value = "10000") +}) +public final class DecodeDoSFilter extends DoSFilter { + // no additional implementation +} diff --git a/zxingorg/src/main/java/com/google/zxing/web/DecodeServlet.java b/zxingorg/src/main/java/com/google/zxing/web/DecodeServlet.java index 65ee067a97..b7af4ccd08 100644 --- a/zxingorg/src/main/java/com/google/zxing/web/DecodeServlet.java +++ b/zxingorg/src/main/java/com/google/zxing/web/DecodeServlet.java @@ -59,6 +59,7 @@ import java.util.Map; import java.util.ResourceBundle; import java.util.Timer; +import java.util.TimerTask; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; @@ -89,8 +90,8 @@ fileSizeThreshold = 1 << 23, // ~8MB location = "/tmp") @WebServlet(value = "/w/decode", loadOnStartup = 1, initParams = { - @WebInitParam(name = "maxAccessPerTime", value = "150"), - @WebInitParam(name = "accessTimeSec", value = "300"), + @WebInitParam(name = "maxAccessPerTime", value = "120"), + @WebInitParam(name = "accessTimeSec", value = "120"), @WebInitParam(name = "maxEntries", value = "10000") }) public final class DecodeServlet extends HttpServlet { @@ -141,8 +142,17 @@ public void init(ServletConfig servletConfig) throws ServletException { long accessTimeMS = TimeUnit.MILLISECONDS.convert(accessTimeSec, TimeUnit.SECONDS); int maxEntries = Integer.parseInt(servletConfig.getInitParameter("maxEntries")); - timer = new Timer("DecodeServlet"); - destHostTracker = new DoSTracker(timer, maxAccessPerTime, accessTimeMS, maxEntries); + String name = getClass().getSimpleName(); + timer = new Timer(name); + destHostTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries); + // Hack to try to avoid odd OOM due to memory leak in JAI? + timer.scheduleAtFixedRate( + new TimerTask() { + @Override + public void run() { + System.gc(); + } + }, 0L, TimeUnit.MILLISECONDS.convert(10, TimeUnit.MINUTES)); } @Override @@ -433,10 +443,8 @@ private static void processImage(BufferedImage image, try { throw savedException == null ? NotFoundException.getNotFoundInstance() : savedException; } catch (FormatException | ChecksumException e) { - log.info(e.toString()); errorResponse(request, response, "format"); } catch (ReaderException e) { // Including NotFoundException - log.info(e.toString()); errorResponse(request, response, "notfound"); } return; diff --git a/zxingorg/src/main/java/com/google/zxing/web/DoSFilter.java b/zxingorg/src/main/java/com/google/zxing/web/DoSFilter.java index 0586779b6a..db5ad26cb5 100644 --- a/zxingorg/src/main/java/com/google/zxing/web/DoSFilter.java +++ b/zxingorg/src/main/java/com/google/zxing/web/DoSFilter.java @@ -16,19 +16,18 @@ package com.google.zxing.web; +import com.google.common.base.Preconditions; + import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; -import javax.servlet.annotation.WebFilter; -import javax.servlet.annotation.WebInitParam; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.Timer; -import java.util.TimerTask; import java.util.concurrent.TimeUnit; /** @@ -37,12 +36,7 @@ * * @author Sean Owen */ -@WebFilter(urlPatterns = {"/w/decode", "/w/chart"}, initParams = { - @WebInitParam(name = "maxAccessPerTime", value = "150"), - @WebInitParam(name = "accessTimeSec", value = "300"), - @WebInitParam(name = "maxEntries", value = "10000") -}) -public final class DoSFilter implements Filter { +public abstract class DoSFilter implements Filter { private Timer timer; private DoSTracker sourceAddrTracker; @@ -50,18 +44,16 @@ public final class DoSFilter implements Filter { @Override public void init(FilterConfig filterConfig) { int maxAccessPerTime = Integer.parseInt(filterConfig.getInitParameter("maxAccessPerTime")); + Preconditions.checkArgument(maxAccessPerTime > 0); int accessTimeSec = Integer.parseInt(filterConfig.getInitParameter("accessTimeSec")); + Preconditions.checkArgument(accessTimeSec > 0); long accessTimeMS = TimeUnit.MILLISECONDS.convert(accessTimeSec, TimeUnit.SECONDS); int maxEntries = Integer.parseInt(filterConfig.getInitParameter("maxEntries")); - timer = new Timer("DoSFilter"); - sourceAddrTracker = new DoSTracker(timer, maxAccessPerTime, accessTimeMS, maxEntries); - timer.scheduleAtFixedRate( - new TimerTask() { - @Override - public void run() { - System.gc(); - } - }, 0L, TimeUnit.MILLISECONDS.convert(15, TimeUnit.MINUTES)); + Preconditions.checkArgument(maxEntries > 0); + + String name = getClass().getSimpleName(); + timer = new Timer(name); + sourceAddrTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries); } @Override @@ -79,9 +71,17 @@ public void doFilter(ServletRequest request, } private boolean isBanned(HttpServletRequest request) { - String remoteIPAddress = request.getHeader("x-forwarded-for"); + String remoteHost = request.getHeader("x-forwarded-for"); + if (remoteHost != null) { + int comma = remoteHost.indexOf(','); + if (comma >= 0) { + remoteHost = remoteHost.substring(0, comma); + } + remoteHost = remoteHost.trim(); + } + // Non-short-circuit "|" below is on purpose return - (remoteIPAddress != null && sourceAddrTracker.isBanned(remoteIPAddress)) || + (remoteHost != null && sourceAddrTracker.isBanned(remoteHost)) | sourceAddrTracker.isBanned(request.getRemoteAddr()); } diff --git a/zxingorg/src/main/java/com/google/zxing/web/DoSTracker.java b/zxingorg/src/main/java/com/google/zxing/web/DoSTracker.java index 9884ea8af0..2cbbefe94d 100644 --- a/zxingorg/src/main/java/com/google/zxing/web/DoSTracker.java +++ b/zxingorg/src/main/java/com/google/zxing/web/DoSTracker.java @@ -35,7 +35,7 @@ final class DoSTracker { private final long maxAccessesPerTime; private final Map numRecentAccesses; - DoSTracker(Timer timer, final int maxAccessesPerTime, long accessTimeMS, int maxEntries) { + DoSTracker(Timer timer, final String name, final int maxAccessesPerTime, long accessTimeMS, int maxEntries) { this.maxAccessesPerTime = maxAccessesPerTime; this.numRecentAccesses = new LRUMap<>(maxEntries); timer.schedule(new TimerTask() { @@ -51,8 +51,8 @@ public void run() { accessIt.remove(); } else { // Else it exceeded the max, so log it (again) - log.warning("Blocking " + entry.getKey() + " (" + count + " outstanding)"); - // Reduce count of accesses held against the IP + log.warning(name + ": Blocking " + entry.getKey() + " (" + count + " outstanding)"); + // Reduce count of accesses held against the host count.getAndAdd(-maxAccessesPerTime); } } @@ -70,8 +70,8 @@ boolean isBanned(String event) { synchronized (numRecentAccesses) { count = numRecentAccesses.get(event); if (count == null) { - count = new AtomicLong(); - numRecentAccesses.put(event, count); + numRecentAccesses.put(event, new AtomicLong(1)); + return false; } } return count.incrementAndGet() > maxAccessesPerTime; diff --git a/zxingorg/src/test/java/com/google/zxing/web/DoSFilterTestCase.java b/zxingorg/src/test/java/com/google/zxing/web/DoSFilterTestCase.java index 2ef3365378..24dd06e3d4 100644 --- a/zxingorg/src/test/java/com/google/zxing/web/DoSFilterTestCase.java +++ b/zxingorg/src/test/java/com/google/zxing/web/DoSFilterTestCase.java @@ -23,32 +23,79 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import javax.servlet.Filter; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Arrays; /** - * Tests {@link DoSFilter}. + * Tests {@link DoSFilter} implementations. */ public final class DoSFilterTestCase extends Assert { + private static final int MAX_ACCESS_PER_TIME = 10; + @Test public void testRedirect() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/"); - request.setRemoteAddr("1.2.3.4"); - HttpServletResponse response = new MockHttpServletResponse(); - DoSFilter filter = new DoSFilter(); + for (DoSFilter filter : Arrays.asList(new ChartDoSFilter(), new DecodeDoSFilter())) { + initFilter(filter); + try { + for (int i = 0; i < MAX_ACCESS_PER_TIME; i++) { + testRequest(filter, "1.2.3.4", null, HttpServletResponse.SC_OK); + } + testRequest(filter, "1.2.3.4", null, HttpServletResponse.SC_FORBIDDEN); + } finally { + filter.destroy(); + } + } + } + + @Test + public void testNoRemoteHost() throws Exception { + Filter filter = new DecodeDoSFilter(); + initFilter(filter); + try { + testRequest(filter, null, null, HttpServletResponse.SC_FORBIDDEN); + testRequest(filter, null, "1.1.1.1", HttpServletResponse.SC_FORBIDDEN); + } finally { + filter.destroy(); + } + } + + @Test + public void testProxy() throws Exception { + Filter filter = new DecodeDoSFilter(); + initFilter(filter); + try { + for (int i = 0; i < MAX_ACCESS_PER_TIME; i++) { + testRequest(filter, "1.2.3.4", "1.1.1." + i + ", proxy1", HttpServletResponse.SC_OK); + } + testRequest(filter, "1.2.3.4", "1.1.1.0", HttpServletResponse.SC_FORBIDDEN); + } finally { + filter.destroy(); + } + } + + private void initFilter(Filter filter) throws ServletException { MockFilterConfig config = new MockFilterConfig(); - int maxAccessPerTime = 10; - config.addInitParameter("maxAccessPerTime", Integer.toString(maxAccessPerTime)); + config.addInitParameter("maxAccessPerTime", Integer.toString(MAX_ACCESS_PER_TIME)); config.addInitParameter("accessTimeSec", "60"); config.addInitParameter("maxEntries", "100"); filter.init(config); - for (int i = 0; i < maxAccessPerTime; i++) { - filter.doFilter(request, response, new MockFilterChain()); - assertEquals(HttpServletResponse.SC_OK, response.getStatus()); + } + + private void testRequest(Filter filter, String host, String proxy, int expectedStatus) + throws IOException, ServletException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setRequestURI("/"); + request.setRemoteAddr(host); + if (proxy != null) { + request.addHeader("X-Forwarded-For", proxy); } + HttpServletResponse response = new MockHttpServletResponse(); filter.doFilter(request, response, new MockFilterChain()); - assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus()); + assertEquals(expectedStatus, response.getStatus()); } } diff --git a/zxingorg/src/test/java/com/google/zxing/web/DoSTrackerTestCase.java b/zxingorg/src/test/java/com/google/zxing/web/DoSTrackerTestCase.java index 3471e51ce4..e6e77bda4c 100644 --- a/zxingorg/src/test/java/com/google/zxing/web/DoSTrackerTestCase.java +++ b/zxingorg/src/test/java/com/google/zxing/web/DoSTrackerTestCase.java @@ -31,7 +31,7 @@ public void testDoS() throws Exception { Timer timer = new Timer(); long timerTimeMS = 500; int maxAccessPerTime = 2; - DoSTracker tracker = new DoSTracker(timer, maxAccessPerTime, timerTimeMS, 3); + DoSTracker tracker = new DoSTracker(timer, "test", maxAccessPerTime, timerTimeMS, 3); // 2 requests allowed per time; 3rd should be banned assertFalse(tracker.isBanned("A"));