Skip to content

Commit

Permalink
feat: relax group; fix: exact match group not found handle
Browse files Browse the repository at this point in the history
  • Loading branch information
tonny-zhang committed Mar 2, 2022
1 parent af38a26 commit 48edaf5
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 68 deletions.
69 changes: 2 additions & 67 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Router struct {

notfoundHandlers []HandlerFunc

groups []*Router
groups groupArr

domains map[string]*Router

Expand All @@ -50,70 +50,6 @@ func Default() *Router {
return router
}

// Group get group router
func (router *Router) Group(path string, handler ...HandlerFunc) *Router {
if len(path) == 0 || path[0] != '/' {
panic(fmt.Errorf("group [%s] must start with /", path))
}
if strings.Index(path, "*") > -1 || strings.Index(path, ":") > -1 {
panic(fmt.Errorf("group path [%s] can not has parameter", path))
}
prefix := utils.CleanPath(path + "/")
matchedGroup := router.matchGroup(prefix)
if matchedGroup != nil {
panic(fmt.Errorf("group [%s] conflicts with [%s]", prefix, matchedGroup.prefix))
}
if router.prefix != "" {
prefix = utils.CleanPath(router.prefix + "/" + prefix)
}
r := &Router{
prefix: prefix,
domain: router.domain,
trees: router.trees,
middlewares: router.middlewares,
notfoundHandlers: router.notfoundHandlers,
}
r.middlewares = append(r.middlewares, handler...)
router.groups = append(router.groups, r)
return r
}
func (router *Router) matchGroup(path string) *Router {
for _, g := range router.groups {
if len(g.groups) > 0 {
gg := g.matchGroup(path)
if gg != nil {
return gg
}
}
if matchGroup(g, path) {
return g
}
}
return nil
}
func matchGroup(router *Router, path string) bool {
if len(router.prefix) > 0 {
if strings.HasPrefix(path, router.prefix) {
return true
}
arrRP := strings.Split(router.prefix, "/")
arrPath := strings.Split(path, "/")
if len(arrPath) < len(arrRP) {
return false
}

for i, j := 0, len(arrRP); i < j; i++ {
if i == j-1 && arrRP[i] == "" {
return true
}
if arrRP[i] != arrPath[i] {
return false
}
}
}
return false
}

// NotFound custom NotFoundHandler
func (router *Router) NotFound(handler ...HandlerFunc) {
router.notfoundHandlers = handler
Expand All @@ -133,10 +69,8 @@ func (router *Router) addHandleFunc(method, path string, handler HandlerFunc) {
}
nodeAdded := router.trees[method].add(path, nil)
nodeAdded.middleware = append(nodeAdded.middleware, router.middlewares...)
// nodeAdded.handler = handler
nodeAdded.middleware = append(nodeAdded.middleware, handler)
router.hasHandled = true
// debugPrint("list domain [%s]", router.domain)
debugPrintRoute(method, router.domain+path, handler)
}

Expand Down Expand Up @@ -246,6 +180,7 @@ func (router *Router) Run(addr string) error {
}
r.groups = groupsNew
}
router.sort() // 对group进行排序
debugPrint("Listening and serving HTTP on %s\n", addr)

srv := &http.Server{
Expand Down
103 changes: 103 additions & 0 deletions router_group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package cotton

import (
"fmt"
"sort"
"strings"

"github.com/tonny-zhang/cotton/utils"
)

type groupArr []*Router

func (s groupArr) Len() int { return len(s) }
func (s groupArr) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s groupArr) Less(i, j int) bool {
return len(strings.Split(s[i].prefix, "/")) > len(strings.Split(s[j].prefix, "/"))
}

// Group get group router
func (router *Router) Group(path string, handler ...HandlerFunc) *Router {
if len(path) == 0 || path[0] != '/' {
panic(fmt.Errorf("group [%s] must start with /", path))
}
if strings.Index(path, "*") > -1 || strings.Index(path, ":") > -1 {
panic(fmt.Errorf("group path [%s] can not has parameter", path))
}
prefix := utils.CleanPath(path + "/")
hasGroup := router.hasGroup(prefix)
if hasGroup {
panic(fmt.Errorf("group [%s] is setted", prefix))
}
if router.prefix != "" {
prefix = utils.CleanPath(router.prefix + "/" + prefix)
}
r := &Router{
prefix: prefix,
domain: router.domain,
trees: router.trees,
middlewares: router.middlewares,
notfoundHandlers: router.notfoundHandlers,
}
r.middlewares = append(r.middlewares, handler...)
router.groups = append(router.groups, r)
return r
}
func (router *Router) sort() {
if len(router.groups) > 0 {
for _, g := range router.groups {
g.sort()
}
sort.Sort(router.groups)
}
}
func (router *Router) hasGroup(path string) bool {
for _, g := range router.groups {
if len(g.groups) > 0 {
has := g.hasGroup(path)
if has {
return true
}
}
if g.prefix == path {
return true
}
}
return false
}
func (router *Router) matchGroup(path string) *Router {
for _, g := range router.groups {
if len(g.groups) > 0 {
matchedGroup := g.matchGroup(path)
if matchedGroup != nil {
return matchedGroup
}
}
if matchGroup(g, path) {
return g
}
}
return nil
}
func matchGroup(router *Router, path string) bool {
if len(router.prefix) > 0 {
if strings.HasPrefix(path, router.prefix) {
return true
}
arrRP := strings.Split(router.prefix, "/")
arrPath := strings.Split(path, "/")
if len(arrPath) < len(arrRP) {
return false
}

for i, j := 0, len(arrRP); i < j; i++ {
if i == j-1 && arrRP[i] == "" {
return true
}
if arrRP[i] != arrPath[i] {
return false
}
}
}
return false
}
55 changes: 54 additions & 1 deletion router_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,18 @@ func TestGroupPanic(t *testing.T) {
router.Group("abc")
})

assert.PanicsWithError(t, "group [/a/] conflicts with [/a/]", func() {
assert.PanicsWithError(t, "group [/a/] is setted", func() {
router := NewRouter()
router.Group("/a")
router.Group("/a")
})

assert.PanicsWithError(t, "group [/a/b/] is setted", func() {
router := NewRouter()
router.Group("/a").Group("/b")
router.Group("/a//b")
})

assert.PanicsWithError(t, "group path [/:method] can not has parameter", func() {
router := NewRouter()
router.Group("/:method")
Expand Down Expand Up @@ -145,3 +151,50 @@ func TestGroupMulty(t *testing.T) {
w = doRequest(router, http.MethodGet, "/a/b/test")
assert.Equal(t, "/a/b/test", w.Body.String())
}

func TestCustomGroupNotFoundOrder(t *testing.T) {

infoCustomNotFound := "not found from custom"
infoCustomGroupNotFound := "not found from custom group"
infoCustomGroupUserNotFound := "not found from custom group user"

{
router := NewRouter()
router.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomNotFound)
})
g := router.Group("/v1")
g.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomGroupNotFound)
})
gUser := router.Group("/v1/user")
gUser.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomGroupUserNotFound)
})

router.sort()
w := doRequest(router, http.MethodGet, "/v1/user/path404")
assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomGroupUserNotFound, w.Body.String())
}

{
router := NewRouter()
router.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomNotFound)
})
gUser := router.Group("/v1/user")
gUser.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomGroupUserNotFound)
})
g := router.Group("/v1")
g.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomGroupNotFound)
})

router.sort()
w := doRequest(router, http.MethodGet, "/v1/user/path404")
assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomGroupUserNotFound, w.Body.String())
}
}

0 comments on commit 48edaf5

Please sign in to comment.