From 18e4b3596d25023ac5b81cbbd9139752399b1164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20B=C3=B8rlum?= Date: Fri, 29 Apr 2011 14:57:58 +0200 Subject: [PATCH] Check replication parameters and log errors. --- .../replication/RegistryServlet.java | 70 +++++++++++-------- 1 file changed, 42 insertions(+), 28 deletions(-) diff --git a/subprojects/replication/src/main/java/com/trifork/stamdata/replication/replication/RegistryServlet.java b/subprojects/replication/src/main/java/com/trifork/stamdata/replication/replication/RegistryServlet.java index 21abfca3..aa7b5c28 100644 --- a/subprojects/replication/src/main/java/com/trifork/stamdata/replication/replication/RegistryServlet.java +++ b/subprojects/replication/src/main/java/com/trifork/stamdata/replication/replication/RegistryServlet.java @@ -25,8 +25,10 @@ package com.trifork.stamdata.replication.replication; import static com.google.common.base.Preconditions.checkNotNull; +import static java.net.HttpURLConnection.HTTP_BAD_REQUEST; import static java.net.HttpURLConnection.HTTP_OK; import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; +import static org.slf4j.LoggerFactory.getLogger; import java.io.IOException; import java.util.Map; @@ -37,6 +39,7 @@ import javax.servlet.http.HttpServletResponse; import org.hibernate.ScrollableResults; +import org.slf4j.Logger; import com.google.inject.Inject; import com.google.inject.Provider; @@ -52,6 +55,8 @@ @Singleton public class RegistryServlet extends HttpServlet { + private static final Logger logger = getLogger(RegistryServlet.class); + private static final long serialVersionUID = -172563300590543180L; private static final int DEFAULT_PAGE_SIZE = 10000; @@ -64,11 +69,8 @@ public class RegistryServlet extends HttpServlet { private final Provider usageLogger; @Inject - RegistryServlet(@Registry Map> registry, - Provider usageLogger, - Provider securityManager, - Provider recordDao, - Provider writers) { + RegistryServlet(@Registry Map> registry, Provider usageLogger, Provider securityManager, Provider recordDao, Provider writers) { + this.usageLogger = usageLogger; this.registry = checkNotNull(registry); this.recordDao = checkNotNull(recordDao); @@ -78,6 +80,9 @@ public class RegistryServlet extends HttpServlet { @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + + // Check authorization. + if (isNotAuthorized(request)) { setUnauthorizedHeaders(response); return; @@ -86,10 +91,34 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t String viewName = getViewName(request); String clientId = securityManager.get().getClientId(request); + // Check the request parameters. + + String offsetParam = request.getParameter("offset"); + + if (offsetParam != null && !offsetParam.matches("[0-9]+")) { + + response.sendError(HTTP_BAD_REQUEST, "The 'offset' parameter must be a non-negative integer."); + logger.warn("Invalid parameter offset='{}'. ClientId='{}'.", offsetParam, securityManager.get().getClientId(request)); + return; + } + // Parse the offset query parameters. - HistoryOffset offset = new HistoryOffset(request.getParameter("offset")); - int count = getCount(request); + HistoryOffset offset = new HistoryOffset(offsetParam); + + // Get the count parameter. + + String countParam = request.getParameter("count"); + int count = DEFAULT_PAGE_SIZE; + + if (countParam != null && !countParam.matches("[1-9][0-9]+")) { + + response.sendError(HTTP_BAD_REQUEST, "The 'count' parameter must be a positive integer."); + logger.warn("Invalid parameter count='{}'. ClientId='{}'.", countParam, securityManager.get().getClientId(request)); + return; + } + + count = Integer.parseInt(countParam); // Determine what content type the client wants. // @@ -121,6 +150,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t View newestRecord = (View) records.get(0); response.addHeader("Link", WebLinking.createNextLink(viewName, newestRecord.getOffset())); } + records.beforeFirst(); response.setStatus(HTTP_OK); @@ -138,39 +168,23 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t } private boolean shouldBeLogged(Class entityType) { + UsageLogged usageLogged = entityType.getAnnotation(UsageLogged.class); return usageLogged == null || usageLogged.value(); } private boolean isNotAuthorized(HttpServletRequest request) { - // AUTHORIZE THE REQUEST - // - // The HTTP RFC specifies that if returning a 401 status - // the server MUST issue a challenge also. - // - // TODO: Log any unauthorized attempts. return !securityManager.get().isAuthorized(request); } private void setUnauthorizedHeaders(HttpServletResponse response) { - response.setStatus(HTTP_UNAUTHORIZED); - response.setHeader("WWW-Authenticate", "STAMDATA"); - } - private int getCount(HttpServletRequest request) { - String countParam = request.getParameter("count"); - int count = DEFAULT_PAGE_SIZE; + // The HTTP RFC specifies that if returning a 401 status + // the server MUST issue a challenge also. - if (countParam != null) { - try { - count = Integer.parseInt(countParam); - } - catch (NumberFormatException e) { - // Ignore this. We might decide to log this in the future. - } - } - return count; + response.setStatus(HTTP_UNAUTHORIZED); + response.setHeader("WWW-Authenticate", "STAMDATA"); } protected String getViewName(HttpServletRequest request) {