Skip to content

Commit

Permalink
GH-151: Added HTTP HEAD method request support.
Browse files Browse the repository at this point in the history
  • Loading branch information
jirenius committed Mar 19, 2020
1 parent 4bda354 commit 7c1daab
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 19 deletions.
44 changes: 25 additions & 19 deletions server/apiHandler.go
Expand Up @@ -60,7 +60,7 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) {
return
}
if err != nil {
httpError(w, err, s.enc)
httpError(w, err, s.enc, false)
return
}

Expand All @@ -77,8 +77,12 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) {
return
}

if r.Method == "GET" {
rid := PathToRID(path, r.URL.RawQuery, apiPath)
var rid, action string
switch r.Method {
case "GET":
fallthrough
case "HEAD":
rid = PathToRID(path, r.URL.RawQuery, apiPath)
if !codec.IsValidRID(rid, true) {
notFoundHandler(w, r, s.enc)
return
Expand All @@ -94,12 +98,10 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) {
})
})
return
}

var rid, action string
if r.Method == "POST" {
case "POST":
rid, action = PathToRIDAction(path, r.URL.RawQuery, apiPath)
} else {
default:
var m *string
switch r.Method {
case "PUT":
Expand All @@ -117,10 +119,9 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) {
}
// Return error if we have no mapping for the method
if m == nil {
httpError(w, reserr.ErrMethodNotAllowed, s.enc)
httpError(w, reserr.ErrMethodNotAllowed, s.enc, false)
return
}

rid = PathToRID(path, r.URL.RawQuery, apiPath)
action = *m
}
Expand All @@ -131,7 +132,9 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) {
func notFoundHandler(w http.ResponseWriter, r *http.Request, enc APIEncoder) {
w.Header().Set("Content-Type", enc.ContentType())
w.WriteHeader(http.StatusNotFound)
w.Write(enc.NotFoundError())
if r.Method != "HEAD" {
w.Write(enc.NotFoundError())
}
}

func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string, action string) {
Expand All @@ -143,15 +146,15 @@ func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string,
// Try to parse the body
b, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error reading request body: " + err.Error()}, s.enc)
httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error reading request body: " + err.Error()}, s.enc, false)
return
}

var params json.RawMessage
if strings.TrimSpace(string(b)) != "" {
err = json.Unmarshal(b, &params)
if err != nil {
httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error decoding request body: " + err.Error()}, s.enc)
httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error decoding request body: " + err.Error()}, s.enc, false)
return
}
}
Expand All @@ -174,7 +177,7 @@ func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string,
func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func(*wsConn, func([]byte, error))) {
c := s.newWSConn(nil, r, latestProtocol)
if c == nil {
httpError(w, reserr.ErrServiceUnavailable, s.enc)
httpError(w, reserr.ErrServiceUnavailable, s.enc, r.Method == "HEAD")
return
}

Expand All @@ -187,17 +190,19 @@ func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func(
// Convert system.methodNotFound to system.methodNotAllowed for PUT/DELETE/PATCH
if rerr, ok := err.(*reserr.Error); ok {
if rerr.Code == reserr.CodeMethodNotFound && (r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH") {
httpError(w, reserr.ErrMethodNotAllowed, s.enc)
httpError(w, reserr.ErrMethodNotAllowed, s.enc, false)
return
}
}
httpError(w, err, s.enc)
httpError(w, err, s.enc, r.Method == "HEAD")
return
}

if len(out) > 0 {
w.Header().Set("Content-Type", s.enc.ContentType())
w.Write(out)
if r.Method != "HEAD" {
w.Write(out)
}
return
}

Expand All @@ -215,9 +220,8 @@ func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func(
<-done
}

func httpError(w http.ResponseWriter, err error, enc APIEncoder) {
func httpError(w http.ResponseWriter, err error, enc APIEncoder, noBody bool) {
rerr := reserr.RESError(err)
out := enc.EncodeError(rerr)

var code int
switch rerr.Code {
Expand All @@ -243,5 +247,7 @@ func httpError(w http.ResponseWriter, err error, enc APIEncoder) {

w.Header().Set("Content-Type", enc.ContentType())
w.WriteHeader(code)
w.Write(out)
if !noBody {
w.Write(enc.EncodeError(rerr))
}
}
71 changes: 71 additions & 0 deletions test/23http_head_test.go
@@ -0,0 +1,71 @@
package test

import (
"encoding/json"
"fmt"
"net/http"
"testing"

"github.com/resgateio/resgate/server/reserr"
)

// Test invalid urls for HTTP get requests
func TestHTTPMethodHEAD_InvalidURLs_CorrectStatus(t *testing.T) {
tbl := []struct {
URL string // Url path
ExpectedCode int
}{
{"/wrong_prefix/test/model", http.StatusNotFound},
{"/api/", http.StatusNotFound},
{"/api/test.model", http.StatusNotFound},
{"/api/test/model/", http.StatusNotFound},
{"/api/test//model", http.StatusNotFound},
{"/api/test/mådel/action", http.StatusNotFound},
}

for i, l := range tbl {
runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) {
s.HTTPRequest("HEAD", l.URL, nil).
GetResponse(t).
AssertStatusCode(t, l.ExpectedCode).
AssertBody(t, nil)
})
}
}

func TestHTTPHead_OnSuccess_NoBody(t *testing.T) {
model := resourceData("test.model")
runTest(t, func(s *Session) {
hreq := s.HTTPRequest("HEAD", "/api/test/model", nil)

/// Handle model get and access request
mreqs := s.GetParallelRequests(t, 2)
req := mreqs.GetRequest(t, "access.test.model")
req.RespondSuccess(json.RawMessage(`{"get":true}`))
req = mreqs.GetRequest(t, "get.test.model")
req.RespondSuccess(json.RawMessage(`{"model":` + model + `}`))

// Validate http response
hreq.GetResponse(t).
AssertStatusCode(t, http.StatusOK).
AssertBody(t, nil)
})
}

func TestHTTPHead_OnError_NoBody(t *testing.T) {
runTest(t, func(s *Session) {
hreq := s.HTTPRequest("HEAD", "/api/test/model", nil)

/// Handle model get and access request
mreqs := s.GetParallelRequests(t, 2)
req := mreqs.GetRequest(t, "access.test.model")
req.RespondSuccess(json.RawMessage(`{"get":true}`))
req = mreqs.GetRequest(t, "get.test.model")
req.RespondError(reserr.ErrNotFound)

// Validate http response
hreq.GetResponse(t).
AssertStatusCode(t, http.StatusNotFound).
AssertBody(t, nil)
})
}

0 comments on commit 7c1daab

Please sign in to comment.