diff --git a/server/apiHandler.go b/server/apiHandler.go index d7ed1cf..be453d4 100644 --- a/server/apiHandler.go +++ b/server/apiHandler.go @@ -60,7 +60,7 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { return } if err != nil { - httpError(w, err, s.enc) + httpError(w, err, s.enc, false) return } @@ -77,8 +77,12 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { return } - if r.Method == "GET" { - rid := PathToRID(path, r.URL.RawQuery, apiPath) + var rid, action string + switch r.Method { + case "GET": + fallthrough + case "HEAD": + rid = PathToRID(path, r.URL.RawQuery, apiPath) if !codec.IsValidRID(rid, true) { notFoundHandler(w, r, s.enc) return @@ -94,12 +98,10 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { }) }) return - } - var rid, action string - if r.Method == "POST" { + case "POST": rid, action = PathToRIDAction(path, r.URL.RawQuery, apiPath) - } else { + default: var m *string switch r.Method { case "PUT": @@ -117,10 +119,9 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { } // Return error if we have no mapping for the method if m == nil { - httpError(w, reserr.ErrMethodNotAllowed, s.enc) + httpError(w, reserr.ErrMethodNotAllowed, s.enc, false) return } - rid = PathToRID(path, r.URL.RawQuery, apiPath) action = *m } @@ -131,7 +132,9 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { func notFoundHandler(w http.ResponseWriter, r *http.Request, enc APIEncoder) { w.Header().Set("Content-Type", enc.ContentType()) w.WriteHeader(http.StatusNotFound) - w.Write(enc.NotFoundError()) + if r.Method != "HEAD" { + w.Write(enc.NotFoundError()) + } } func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string, action string) { @@ -143,7 +146,7 @@ func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string, // Try to parse the body b, err := ioutil.ReadAll(r.Body) if err != nil { - httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error reading request body: " + err.Error()}, s.enc) + httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error reading request body: " + err.Error()}, s.enc, false) return } @@ -151,7 +154,7 @@ func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string, if strings.TrimSpace(string(b)) != "" { err = json.Unmarshal(b, ¶ms) if err != nil { - httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error decoding request body: " + err.Error()}, s.enc) + httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error decoding request body: " + err.Error()}, s.enc, false) return } } @@ -174,7 +177,7 @@ func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string, func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func(*wsConn, func([]byte, error))) { c := s.newWSConn(nil, r, latestProtocol) if c == nil { - httpError(w, reserr.ErrServiceUnavailable, s.enc) + httpError(w, reserr.ErrServiceUnavailable, s.enc, r.Method == "HEAD") return } @@ -187,17 +190,19 @@ func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func( // Convert system.methodNotFound to system.methodNotAllowed for PUT/DELETE/PATCH if rerr, ok := err.(*reserr.Error); ok { if rerr.Code == reserr.CodeMethodNotFound && (r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH") { - httpError(w, reserr.ErrMethodNotAllowed, s.enc) + httpError(w, reserr.ErrMethodNotAllowed, s.enc, false) return } } - httpError(w, err, s.enc) + httpError(w, err, s.enc, r.Method == "HEAD") return } if len(out) > 0 { w.Header().Set("Content-Type", s.enc.ContentType()) - w.Write(out) + if r.Method != "HEAD" { + w.Write(out) + } return } @@ -215,9 +220,8 @@ func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func( <-done } -func httpError(w http.ResponseWriter, err error, enc APIEncoder) { +func httpError(w http.ResponseWriter, err error, enc APIEncoder, noBody bool) { rerr := reserr.RESError(err) - out := enc.EncodeError(rerr) var code int switch rerr.Code { @@ -243,5 +247,7 @@ func httpError(w http.ResponseWriter, err error, enc APIEncoder) { w.Header().Set("Content-Type", enc.ContentType()) w.WriteHeader(code) - w.Write(out) + if !noBody { + w.Write(enc.EncodeError(rerr)) + } } diff --git a/test/23http_head_test.go b/test/23http_head_test.go new file mode 100644 index 0000000..cebcd2c --- /dev/null +++ b/test/23http_head_test.go @@ -0,0 +1,71 @@ +package test + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/resgateio/resgate/server/reserr" +) + +// Test invalid urls for HTTP get requests +func TestHTTPMethodHEAD_InvalidURLs_CorrectStatus(t *testing.T) { + tbl := []struct { + URL string // Url path + ExpectedCode int + }{ + {"/wrong_prefix/test/model", http.StatusNotFound}, + {"/api/", http.StatusNotFound}, + {"/api/test.model", http.StatusNotFound}, + {"/api/test/model/", http.StatusNotFound}, + {"/api/test//model", http.StatusNotFound}, + {"/api/test/mådel/action", http.StatusNotFound}, + } + + for i, l := range tbl { + runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) { + s.HTTPRequest("HEAD", l.URL, nil). + GetResponse(t). + AssertStatusCode(t, l.ExpectedCode). + AssertBody(t, nil) + }) + } +} + +func TestHTTPHead_OnSuccess_NoBody(t *testing.T) { + model := resourceData("test.model") + runTest(t, func(s *Session) { + hreq := s.HTTPRequest("HEAD", "/api/test/model", nil) + + /// Handle model get and access request + mreqs := s.GetParallelRequests(t, 2) + req := mreqs.GetRequest(t, "access.test.model") + req.RespondSuccess(json.RawMessage(`{"get":true}`)) + req = mreqs.GetRequest(t, "get.test.model") + req.RespondSuccess(json.RawMessage(`{"model":` + model + `}`)) + + // Validate http response + hreq.GetResponse(t). + AssertStatusCode(t, http.StatusOK). + AssertBody(t, nil) + }) +} + +func TestHTTPHead_OnError_NoBody(t *testing.T) { + runTest(t, func(s *Session) { + hreq := s.HTTPRequest("HEAD", "/api/test/model", nil) + + /// Handle model get and access request + mreqs := s.GetParallelRequests(t, 2) + req := mreqs.GetRequest(t, "access.test.model") + req.RespondSuccess(json.RawMessage(`{"get":true}`)) + req = mreqs.GetRequest(t, "get.test.model") + req.RespondError(reserr.ErrNotFound) + + // Validate http response + hreq.GetResponse(t). + AssertStatusCode(t, http.StatusNotFound). + AssertBody(t, nil) + }) +}