diff --git a/router.go b/router.go index b09e08f..38d2c04 100644 --- a/router.go +++ b/router.go @@ -25,7 +25,7 @@ type Router struct { notfoundHandlers []HandlerFunc - groups []*Router + groups groupArr domains map[string]*Router @@ -50,70 +50,6 @@ func Default() *Router { return router } -// Group get group router -func (router *Router) Group(path string, handler ...HandlerFunc) *Router { - if len(path) == 0 || path[0] != '/' { - panic(fmt.Errorf("group [%s] must start with /", path)) - } - if strings.Index(path, "*") > -1 || strings.Index(path, ":") > -1 { - panic(fmt.Errorf("group path [%s] can not has parameter", path)) - } - prefix := utils.CleanPath(path + "/") - matchedGroup := router.matchGroup(prefix) - if matchedGroup != nil { - panic(fmt.Errorf("group [%s] conflicts with [%s]", prefix, matchedGroup.prefix)) - } - if router.prefix != "" { - prefix = utils.CleanPath(router.prefix + "/" + prefix) - } - r := &Router{ - prefix: prefix, - domain: router.domain, - trees: router.trees, - middlewares: router.middlewares, - notfoundHandlers: router.notfoundHandlers, - } - r.middlewares = append(r.middlewares, handler...) - router.groups = append(router.groups, r) - return r -} -func (router *Router) matchGroup(path string) *Router { - for _, g := range router.groups { - if len(g.groups) > 0 { - gg := g.matchGroup(path) - if gg != nil { - return gg - } - } - if matchGroup(g, path) { - return g - } - } - return nil -} -func matchGroup(router *Router, path string) bool { - if len(router.prefix) > 0 { - if strings.HasPrefix(path, router.prefix) { - return true - } - arrRP := strings.Split(router.prefix, "/") - arrPath := strings.Split(path, "/") - if len(arrPath) < len(arrRP) { - return false - } - - for i, j := 0, len(arrRP); i < j; i++ { - if i == j-1 && arrRP[i] == "" { - return true - } - if arrRP[i] != arrPath[i] { - return false - } - } - } - return false -} - // NotFound custom NotFoundHandler func (router *Router) NotFound(handler ...HandlerFunc) { router.notfoundHandlers = handler @@ -133,10 +69,8 @@ func (router *Router) addHandleFunc(method, path string, handler HandlerFunc) { } nodeAdded := router.trees[method].add(path, nil) nodeAdded.middleware = append(nodeAdded.middleware, router.middlewares...) - // nodeAdded.handler = handler nodeAdded.middleware = append(nodeAdded.middleware, handler) router.hasHandled = true - // debugPrint("list domain [%s]", router.domain) debugPrintRoute(method, router.domain+path, handler) } @@ -246,6 +180,7 @@ func (router *Router) Run(addr string) error { } r.groups = groupsNew } + router.sort() // 对group进行排序 debugPrint("Listening and serving HTTP on %s\n", addr) srv := &http.Server{ diff --git a/router_group.go b/router_group.go new file mode 100644 index 0000000..358b01a --- /dev/null +++ b/router_group.go @@ -0,0 +1,103 @@ +package cotton + +import ( + "fmt" + "sort" + "strings" + + "github.com/tonny-zhang/cotton/utils" +) + +type groupArr []*Router + +func (s groupArr) Len() int { return len(s) } +func (s groupArr) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s groupArr) Less(i, j int) bool { + return len(strings.Split(s[i].prefix, "/")) > len(strings.Split(s[j].prefix, "/")) +} + +// Group get group router +func (router *Router) Group(path string, handler ...HandlerFunc) *Router { + if len(path) == 0 || path[0] != '/' { + panic(fmt.Errorf("group [%s] must start with /", path)) + } + if strings.Index(path, "*") > -1 || strings.Index(path, ":") > -1 { + panic(fmt.Errorf("group path [%s] can not has parameter", path)) + } + prefix := utils.CleanPath(path + "/") + hasGroup := router.hasGroup(prefix) + if hasGroup { + panic(fmt.Errorf("group [%s] is setted", prefix)) + } + if router.prefix != "" { + prefix = utils.CleanPath(router.prefix + "/" + prefix) + } + r := &Router{ + prefix: prefix, + domain: router.domain, + trees: router.trees, + middlewares: router.middlewares, + notfoundHandlers: router.notfoundHandlers, + } + r.middlewares = append(r.middlewares, handler...) + router.groups = append(router.groups, r) + return r +} +func (router *Router) sort() { + if len(router.groups) > 0 { + for _, g := range router.groups { + g.sort() + } + sort.Sort(router.groups) + } +} +func (router *Router) hasGroup(path string) bool { + for _, g := range router.groups { + if len(g.groups) > 0 { + has := g.hasGroup(path) + if has { + return true + } + } + if g.prefix == path { + return true + } + } + return false +} +func (router *Router) matchGroup(path string) *Router { + for _, g := range router.groups { + if len(g.groups) > 0 { + matchedGroup := g.matchGroup(path) + if matchedGroup != nil { + return matchedGroup + } + } + if matchGroup(g, path) { + return g + } + } + return nil +} +func matchGroup(router *Router, path string) bool { + if len(router.prefix) > 0 { + if strings.HasPrefix(path, router.prefix) { + return true + } + arrRP := strings.Split(router.prefix, "/") + arrPath := strings.Split(path, "/") + if len(arrPath) < len(arrRP) { + return false + } + + for i, j := 0, len(arrRP); i < j; i++ { + if i == j-1 && arrRP[i] == "" { + return true + } + if arrRP[i] != arrPath[i] { + return false + } + } + } + return false +} diff --git a/router_group_test.go b/router_group_test.go index c08bead..74b7cc4 100644 --- a/router_group_test.go +++ b/router_group_test.go @@ -49,12 +49,18 @@ func TestGroupPanic(t *testing.T) { router.Group("abc") }) - assert.PanicsWithError(t, "group [/a/] conflicts with [/a/]", func() { + assert.PanicsWithError(t, "group [/a/] is setted", func() { router := NewRouter() router.Group("/a") router.Group("/a") }) + assert.PanicsWithError(t, "group [/a/b/] is setted", func() { + router := NewRouter() + router.Group("/a").Group("/b") + router.Group("/a//b") + }) + assert.PanicsWithError(t, "group path [/:method] can not has parameter", func() { router := NewRouter() router.Group("/:method") @@ -145,3 +151,50 @@ func TestGroupMulty(t *testing.T) { w = doRequest(router, http.MethodGet, "/a/b/test") assert.Equal(t, "/a/b/test", w.Body.String()) } + +func TestCustomGroupNotFoundOrder(t *testing.T) { + + infoCustomNotFound := "not found from custom" + infoCustomGroupNotFound := "not found from custom group" + infoCustomGroupUserNotFound := "not found from custom group user" + + { + router := NewRouter() + router.NotFound(func(ctx *Context) { + ctx.String(http.StatusNotFound, infoCustomNotFound) + }) + g := router.Group("/v1") + g.NotFound(func(ctx *Context) { + ctx.String(http.StatusNotFound, infoCustomGroupNotFound) + }) + gUser := router.Group("/v1/user") + gUser.NotFound(func(ctx *Context) { + ctx.String(http.StatusNotFound, infoCustomGroupUserNotFound) + }) + + router.sort() + w := doRequest(router, http.MethodGet, "/v1/user/path404") + assert.Equal(t, http.StatusNotFound, w.Code) + assert.Equal(t, infoCustomGroupUserNotFound, w.Body.String()) + } + + { + router := NewRouter() + router.NotFound(func(ctx *Context) { + ctx.String(http.StatusNotFound, infoCustomNotFound) + }) + gUser := router.Group("/v1/user") + gUser.NotFound(func(ctx *Context) { + ctx.String(http.StatusNotFound, infoCustomGroupUserNotFound) + }) + g := router.Group("/v1") + g.NotFound(func(ctx *Context) { + ctx.String(http.StatusNotFound, infoCustomGroupNotFound) + }) + + router.sort() + w := doRequest(router, http.MethodGet, "/v1/user/path404") + assert.Equal(t, http.StatusNotFound, w.Code) + assert.Equal(t, infoCustomGroupUserNotFound, w.Body.String()) + } +}