Skip to content

Commit

Permalink
Merge pull request #54 from allisson/master
Browse files Browse the repository at this point in the history
Add support for HEAD method
  • Loading branch information
xujiajun committed Sep 27, 2019
2 parents 9cbaecd + 71d212d commit aa8f99d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 21 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go:
- 1.9.x
- 1.10.x
- 1.11.x
- 1.12.x
- 1.13.x
- tip
before_install:
- go get golang.org/x/tools/cmd/cover
Expand Down
39 changes: 26 additions & 13 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var (
http.MethodPut: {},
http.MethodDelete: {},
http.MethodPatch: {},
http.MethodHead: {},
}
)

Expand Down Expand Up @@ -101,6 +102,12 @@ func (r *Router) PATCH(path string, handle http.HandlerFunc) {
r.Handle(http.MethodPatch, path, handle)
}

// HEAD adds the route `path` that matches a HEAD http method to
// execute the `handle` http.HandlerFunc.
func (r *Router) HEAD(path string, handle http.HandlerFunc) {
r.Handle(http.MethodHead, path, handle)
}

// GETAndName is short for `GET` and Named routeName
func (r *Router) GETAndName(path string, handle http.HandlerFunc, routeName string) {
r.parameters.routeName = routeName
Expand Down Expand Up @@ -131,6 +138,12 @@ func (r *Router) PATCHAndName(path string, handle http.HandlerFunc, routeName st
r.PATCH(path, handle)
}

// HEADAndName is short for `HEAD` and Named routeName
func (r *Router) HEADAndName(path string, handle http.HandlerFunc, routeName string) {
r.parameters.routeName = routeName
r.HEAD(path, handle)
}

// Group define routes groups if there is a path prefix that uses `prefix`
func (r *Router) Group(prefix string) *Router {
return &Router{
Expand Down Expand Up @@ -243,7 +256,7 @@ func GetAllParams(r *http.Request) paramsMapType {

// ServeHTTP makes the router implement the http.Handler interface.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
requestUrl := req.URL.Path
requestURL := req.URL.Path

if r.PanicHandler != nil {
defer func() {
Expand All @@ -258,29 +271,29 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}

nodes := r.trees[req.Method].Find(requestUrl, false)
nodes := r.trees[req.Method].Find(requestURL, false)
if len(nodes) > 0 {
node := nodes[0]

if node.handle != nil {
if node.path == requestUrl {
if node.path == requestURL {
handle(w, req, node.handle, node.middleware)
return
}
if node.path == requestUrl[1:] {
if node.path == requestURL[1:] {
handle(w, req, node.handle, node.middleware)
return
}
}
}

if len(nodes) == 0 {
res := strings.Split(requestUrl, "/")
res := strings.Split(requestURL, "/")
prefix := res[1]
nodes := r.trees[req.Method].Find(prefix, true)
for _, node := range nodes {
if handler := node.handle; handler != nil && node.path != requestUrl {
if matchParamsMap, ok := r.matchAndParse(requestUrl, node.path); ok {
if handler := node.handle; handler != nil && node.path != requestURL {
if matchParamsMap, ok := r.matchAndParse(requestURL, node.path); ok {
ctx := context.WithValue(req.Context(), contextKey, matchParamsMap)
req = req.WithContext(ctx)
handle(w, req, handler, node.middleware)
Expand Down Expand Up @@ -321,13 +334,13 @@ func handle(w http.ResponseWriter, req *http.Request, handler http.HandlerFunc,
}

// Match checks if the request matches the route pattern
func (r *Router) Match(requestUrl string, path string) bool {
_, ok := r.matchAndParse(requestUrl, path)
func (r *Router) Match(requestURL string, path string) bool {
_, ok := r.matchAndParse(requestURL, path)
return ok
}

// matchAndParse checks if the request matches the route path and returns a map of the parsed
func (r *Router) matchAndParse(requestUrl string, path string) (matchParams paramsMapType, b bool) {
func (r *Router) matchAndParse(requestURL string, path string) (matchParams paramsMapType, b bool) {
var (
matchName []string
pattern string
Expand Down Expand Up @@ -364,13 +377,13 @@ func (r *Router) matchAndParse(requestUrl string, path string) (matchParams para
}
}

if strings.HasSuffix(requestUrl, "/") {
if strings.HasSuffix(requestURL, "/") {
pattern = pattern + "/"
}

re := regexp.MustCompile(pattern)
if subMatch := re.FindSubmatch([]byte(requestUrl)); subMatch != nil {
if string(subMatch[0]) == requestUrl {
if subMatch := re.FindSubmatch([]byte(requestURL)); subMatch != nil {
if string(subMatch[0]) == requestURL {
subMatch = subMatch[1:]
for k, v := range subMatch {
matchParams[matchName[k]] = string(v)
Expand Down
49 changes: 41 additions & 8 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ func TestRouter_PUT(t *testing.T) {
}
}

func TestRouter_HEAD(t *testing.T) {
router := New()
rr := httptest.NewRecorder()

req, err := http.NewRequest(http.MethodHead, "/hi", nil)

if err != nil {
t.Fatal(err)
}

router.HEAD("/hi", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, expected)
})
router.ServeHTTP(rr, req)

if rr.Body.String() != expected {
t.Errorf(errorFormat,
rr.Body.String(), expected)
}
}

func TestRouter_Group(t *testing.T) {
router := New()

Expand Down Expand Up @@ -239,7 +260,6 @@ func TestRouter_CustomPanicHandler(t *testing.T) {

router.POST("/aaa", func(w http.ResponseWriter, r *http.Request) {
panic("err")
fmt.Fprint(w, expected)
})
router.ServeHTTP(rr, req)
}
Expand Down Expand Up @@ -465,30 +485,30 @@ func TestRouter_HandlePanic(t *testing.T) {

func TestRouter_Match(t *testing.T) {
router := New()
requestUrl := "/xxx/1/yyy/2"
requestURL := "/xxx/1/yyy/2"

ok := router.Match(requestUrl, "/xxx/:param1/yyy/:param2")
ok := router.Match(requestURL, "/xxx/:param1/yyy/:param2")

if !ok {
t.Fatal("TestRouter_Match test fail")
}

errorRequestUrl := "#xxx#1#yyy#2"
ok = router.Match(errorRequestUrl, "/xxx/:param1/yyy/:param2")
errorRequestURL := "#xxx#1#yyy#2"
ok = router.Match(errorRequestURL, "/xxx/:param1/yyy/:param2")

if ok {
t.Fatal("TestRouter_Match test fail")
}

errorPath := "#xxx#1#yyy#2"
ok = router.Match(requestUrl, errorPath)
ok = router.Match(requestURL, errorPath)

if ok {
t.Fatal("TestRouter_Match test fail")
}

missRequestUrl := "/xxx/1/yyy/###"
ok = router.Match(missRequestUrl, "/xxx/:param1/yyy/:param2")
missRequestURL := "/xxx/1/yyy/###"
ok = router.Match(missRequestURL, "/xxx/:param1/yyy/:param2")

if ok {
t.Fatal("TestRouter_Match test fail")
Expand Down Expand Up @@ -612,4 +632,17 @@ func TestRouter_Generate(t *testing.T) {
if _, err := mux.Generate("METHOD", routeName5, params); err == nil {
t.Fatal("TestRouter_Generate test fail")
}

routeName8 := "user_event"
params = make(map[string]string)
params["user"] = "xujiajun"

//HEADAndName
mux.HEADAndName("/users/:user/events", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/users/:user/events"))
}, routeName8)

if url, _ := mux.Generate(http.MethodHead, routeName1, params); url != "/users/xujiajun/events" {
t.Fatal("TestRouter_Generate test fail")
}
}

0 comments on commit aa8f99d

Please sign in to comment.