diff --git a/node.go b/node.go index 81cfb16..8f72ebc 100644 --- a/node.go +++ b/node.go @@ -34,11 +34,13 @@ func (n *node) regexpToString() string { func (n *node) setRegexp(exp string) { reg, err := regexp.Compile(exp) - if err == nil { - n.regexp = reg - n.isRegexp = true - n.isWildcard = true + if err != nil { + panic(err) } + + n.regexp = reg + n.isRegexp = true + n.isWildcard = true } func (n *node) setRoute(r *route) { diff --git a/router_test.go b/router_test.go index 776dba2..c571ce6 100644 --- a/router_test.go +++ b/router_test.go @@ -172,12 +172,7 @@ func TestOPTIONS(t *testing.T) { router := New().(*router) testBasicMethod(t, router, router.OPTIONS, OPTIONS) -} - -func TestOPTIONSWithoutHandler(t *testing.T) { - t.Parallel() - router := New().(*router) handler := &mockHandler{} router.GET("/x/y", handler) router.POST("/x/y", handler) @@ -185,6 +180,8 @@ func TestOPTIONSWithoutHandler(t *testing.T) { checkIfHasRootRoute(t, router, GET) w := httptest.NewRecorder() + + // test all routes "*" paths req, err := http.NewRequest(OPTIONS, "*", nil) if err != nil { t.Fatal(err) @@ -195,6 +192,18 @@ func TestOPTIONSWithoutHandler(t *testing.T) { if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" { t.Errorf("Allow header incorrect value: %s", allow) } + + // test specific path + req, err = http.NewRequest(OPTIONS, "/x/y", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" { + t.Errorf("Allow header incorrect value: %s", allow) + } } func TestNotFound(t *testing.T) { @@ -483,7 +492,7 @@ func TestChainCalls(t *testing.T) { router := New().(*router) serverd := false - router.GET("/users/{user}/starred", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + router.GET("/users/{user:[a-z0-9]+)}/starred", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { serverd = true params, ok := FromContext(r.Context()) diff --git a/tree.go b/tree.go index 1d90a7d..d0e37ea 100644 --- a/tree.go +++ b/tree.go @@ -65,7 +65,7 @@ func (t *tree) getByID(id string) *node { } for _, child := range t.regexps { - if child.regexp.MatchString(id) { + if child.regexp != nil && child.regexp.MatchString(id) { return child } } @@ -96,7 +96,7 @@ func (t *tree) getByPath(path string) (*node, string, string) { } for _, child := range t.regexps { - if child.regexp.MatchString(part) { + if child.regexp != nil && child.regexp.MatchString(part) { return child, part, path[len(part):] } }