diff --git a/runtime/router/router.go b/runtime/router/router.go index 6b1ebda28..51b3036ca 100644 --- a/runtime/router/router.go +++ b/runtime/router/router.go @@ -23,6 +23,8 @@ package router import ( "context" "net/http" + "sort" + "strings" ) // Router dispatches http requests to a registered http.Handler. @@ -30,7 +32,7 @@ import ( // the main differences are: // 1. this router does not treat "/a/:b" and "/a/b/c" as conflicts (https://github.com/julienschmidt/httprouter/issues/175) // 2. this router does not treat "/a/:b" and "/a/:c" as different routes (https://github.com/julienschmidt/httprouter/issues/6) -// 3. this router does tno treat "/a" and "/a/" as different routes +// 3. this router does not treat "/a" and "/a/" as different routes type Router struct { tries map[string]*Trie @@ -132,21 +134,20 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (r *Router) allowed(path, reqMethod string) string { - allow := "" + var allow []string for method, trie := range r.tries { - if method == reqMethod || method == "OPTIONS" { + if method == reqMethod || method == http.MethodOptions { continue } if _, _, err := trie.Get(path); err == nil { - if len(allow) == 0 { - allow = method - } else { - allow += ", " + method - } + allow = append(allow, method) } } + sort.Slice(allow, func(i, j int) bool { + return allow[i] < allow[j] + }) - return allow + return strings.Join(allow, ", ") } diff --git a/runtime/router/trie.go b/runtime/router/trie.go index 10700fae6..5f3b19511 100644 --- a/runtime/router/trie.go +++ b/runtime/router/trie.go @@ -71,7 +71,12 @@ func NewTrie() *Trie { } } -// Set sets the value for given path, returns error if path already set +// Set sets the value for given path, returns error if path already set. +// When a http.Handler is registered for a given path, a subsequent Get returns the registered +// handler if the url passed to Get call matches the set path. Match in this context could mean either +// equality (e.g. url is "/foo" and path is "/foo") or url matches path pattern, which has two forms: +// - path ends with "/*", e.g. url "/foo" and "/foo/bar" both matches path "/*" +// - path contains colon wildcard ("/:"), e.g. url "/a/b" and "/a/c" bot matches path "/a/:var" func (t *Trie) Set(path string, value http.Handler) error { if path == "" || strings.Contains(path, "//") { return errPath @@ -90,14 +95,16 @@ func (t *Trie) Set(path string, value http.Handler) error { return errors.New("path can not contain more than one *") } - err := t.root.set(path, value) + err := t.root.set(path, value, false, false) if e, ok := err.(*paramMismatch); ok { return fmt.Errorf("path %q has a different param key %q, it should be the same key %q as in existing path %q", path, e.actual, e.expected, e.existingPath) } return err } -// Get returns the value for given path, returns error if not found +// Get returns the http.Handler for given path, returns error if not found. +// It also returns the url params if given path contains any, e.g. if a handler is registered for +// "/:foo/bar", then calling Get with path "/xyz/bar" returns a param whose key is "foo" and value is "xyz". func (t *Trie) Get(path string) (http.Handler, []Param, error) { if path == "" || strings.Contains(path, "//") { return nil, nil, errPath @@ -107,64 +114,73 @@ func (t *Trie) Get(path string) (http.Handler, []Param, error) { } // ignore trailing slash path = strings.TrimSuffix(path, "/") - return t.root.get(path) + return t.root.get(path, false, false, false) } -func (t *tnode) set(path string, value http.Handler) error { +// set sets the handler for given path, creates new child node if necessary +// lastKeyCharSlash tracks whether the previous key char is a '/', used to decide it is a pattern or not +// when the current key char is ':'. lastPathCharSlash tracks whether the previous path char is a '/', +// used to decide it is a pattern or not when the current path char is ':'. +func (t *tnode) set(path string, value http.Handler, lastKeyCharSlash, lastPathCharSlash bool) error { // find the longest common prefix - var l, i int - m, n := len(t.key), len(path) - if m > n { - l = n + var shorterLength, i int + keyLength, pathLength := len(t.key), len(path) + if keyLength > pathLength { + shorterLength = pathLength } else { - l = m + shorterLength = keyLength } - for i < l && t.key[i] == path[i] { + for i < shorterLength && t.key[i] == path[i] { i++ } - // find index j, k in key and path to which they match, - // j and k is only useful to check if there is conflict, - // they are not where splits happen, splits happen at index i. - var j, k int - for j < m && k < n { - if t.key[j] == ':' || path[k] == ':' { - oj, ok := j, k - same := t.key[j] == path[k] - for j < m && t.key[j] != '/' { - j++ + // Find the first character that differs between "path" and this node's key, if it exists. + // If we encounter a colon wildcard, ensure that the wildcard in path matches the wildcard + // in this node's key for that segment. The segment is a colon wildcard only when the colon + // is immediately after slash, e.g. "/:foo", "/x/:y". "/a:b" is not a colon wildcard segment. + var keyMatchIdx, pathMatchIdx int + for keyMatchIdx < keyLength && pathMatchIdx < pathLength { + if (t.key[keyMatchIdx] == ':' && lastKeyCharSlash) || + (path[pathMatchIdx] == ':' && lastPathCharSlash) { + keyStartIdx, pathStartIdx := keyMatchIdx, pathMatchIdx + same := t.key[keyMatchIdx] == path[pathMatchIdx] + for keyMatchIdx < keyLength && t.key[keyMatchIdx] != '/' { + keyMatchIdx++ } - for k < n && path[k] != '/' { - k++ + for pathMatchIdx < pathLength && path[pathMatchIdx] != '/' { + pathMatchIdx++ } - if same && (j-oj) != (k-ok) { + if same && (keyMatchIdx-keyStartIdx) != (pathMatchIdx-pathStartIdx) { return ¶mMismatch{ - t.key[oj:j], - path[ok:k], + t.key[keyStartIdx:keyMatchIdx], + path[pathStartIdx:pathMatchIdx], t.key, } } - } else if t.key[j] == path[k] { - j++ - k++ + } else if t.key[keyMatchIdx] == path[pathMatchIdx] { + keyMatchIdx++ + pathMatchIdx++ } else { break } + lastKeyCharSlash = t.key[keyMatchIdx-1] == '/' + lastPathCharSlash = path[pathMatchIdx-1] == '/' } - // conflicts caused by ":" is only possible when j == m - if j == m { + // If the node key is fully matched, we match the rest path with children nodes to see if a value + // already exists for the path. + if keyMatchIdx == keyLength { for _, c := range t.children { - if c.key[0] == path[k] || c.key[0] == ':' || path[k] == ':' { - if _, _, err := c.get(path[k:]); err == nil { + if c.key[0] == path[pathMatchIdx] || c.key[0] == ':' || path[pathMatchIdx] == ':' { + if _, _, err := c.get(path[pathMatchIdx:], lastKeyCharSlash, lastPathCharSlash, true); err == nil { return errExist } } } } - // node ley is longer than longest common prefix - if i < m { + // node key is longer than longest common prefix + if i < keyLength { // key/path suffix being "*" means a conflict if path[i:] == "*" || t.key[i:] == "*" { return errExist @@ -182,7 +198,7 @@ func (t *tnode) set(path string, value http.Handler) error { // path is equal to longest common prefix // set value on current node after split - if i == n { + if i == pathLength { t.value = value } else { // path is longer than longest common prefix @@ -196,9 +212,9 @@ func (t *tnode) set(path string, value http.Handler) error { } // node key is equal to longest common prefix - if i == m { + if i == keyLength { // path is equal to longest common prefix - if i == n { + if i == pathLength { // node is guaranteed to have zero value, // otherwise it would have caused errExist earlier t.value = value @@ -206,7 +222,9 @@ func (t *tnode) set(path string, value http.Handler) error { // path is longer than node key, try to recurse on node children for _, c := range t.children { if c.key[0] == path[i] { - err := c.set(path[i:], value) + lastKeyCharSlash = i > 0 && t.key[i-1] == '/' + lastPathCharSlash = i > 0 && path[i-1] == '/' + err := c.set(path[i:], value, lastKeyCharSlash, lastPathCharSlash) if e, ok := err.(*paramMismatch); ok { e.existingPath = t.key + e.existingPath return e @@ -226,66 +244,67 @@ func (t *tnode) set(path string, value http.Handler) error { return nil } -func (t *tnode) get(path string) (http.Handler, []Param, error) { - m, n := len(t.key), len(path) +func (t *tnode) get(path string, lastKeyCharSlash, lastPathCharSlash, colonAsPattern bool) (http.Handler, []Param, error) { + keyLength, pathLength := len(t.key), len(path) var params []Param // find the longest matched prefix - var j, k int - for j < m && k < n { - if t.key[j] == ':' { - oj, ok := j+1, k - for j < m && t.key[j] != '/' { - j++ + var keyIdx, pathIdx int + for keyIdx < keyLength && pathIdx < pathLength { + if t.key[keyIdx] == ':' && lastKeyCharSlash { + // wildcard starts - match until next slash + keyStartIdx, pathStartIdx := keyIdx+1, pathIdx + for keyIdx < keyLength && t.key[keyIdx] != '/' { + keyIdx++ } - for k < n && path[k] != '/' { - k++ + for pathIdx < pathLength && path[pathIdx] != '/' { + pathIdx++ } - params = append(params, Param{t.key[oj:j], path[ok:k]}) - } else if path[k] == ':' { // necessary for conflict check used in set call - for j < m && t.key[j] != '/' { - j++ + params = append(params, Param{t.key[keyStartIdx:keyIdx], path[pathStartIdx:pathIdx]}) + } else if path[pathIdx] == ':' && lastPathCharSlash && colonAsPattern { + // necessary for conflict check used in set call + for keyIdx < keyLength && t.key[keyIdx] != '/' { + keyIdx++ } - for k < n && path[k] != '/' { - k++ + for pathIdx < pathLength && path[pathIdx] != '/' { + pathIdx++ } - } else if t.key[j] == path[k] { - j++ - k++ + } else if t.key[keyIdx] == path[pathIdx] { + keyIdx++ + pathIdx++ } else { break } + lastKeyCharSlash = t.key[keyIdx-1] == '/' + lastPathCharSlash = path[pathIdx-1] == '/' } - if j < m { + if keyIdx < keyLength { // path matches up to node key's second to last character, // the last char of node key is "*" and path is no shorter than longest matched prefix - if t.key[j:] == "*" && k < n { + if t.key[keyIdx:] == "*" && pathIdx < pathLength { return t.value, params, nil } return nil, nil, errNotFound } // ':' in path matches '*' in node key - if j > 0 && t.key[j-1] == '*' { + if keyIdx > 0 && t.key[keyIdx-1] == '*' { return t.value, params, nil } // longest matched prefix matches up to node key length and path length - if k == n { + if pathIdx == pathLength { if t.value != nil { return t.value, params, nil } return nil, nil, errNotFound } - // TODO: recursion to iteration for speed // longest matched prefix matches up to node key length but not path length for _, c := range t.children { - if c.key[0] == path[k] || c.key[0] == ':' || path[k] == ':' { - if v, ps, err := c.get(path[k:]); err == nil { - return v, append(params, ps...), nil - } + if v, ps, err := c.get(path[pathIdx:], lastKeyCharSlash, lastPathCharSlash, colonAsPattern); err == nil { + return v, append(params, ps...), nil } } diff --git a/runtime/router/trie_test.go b/runtime/router/trie_test.go index 5e5a375f8..dee131aeb 100644 --- a/runtime/router/trie_test.go +++ b/runtime/router/trie_test.go @@ -142,6 +142,27 @@ func TestTriePathsWithPatten(t *testing.T) { } runTrieTests(t, trie, tests) + trie = NewTrie() + tests = []ts{ + // test ":a" is not treated as a pattern when queried as a url + {op: set, path: "/a", value: "foo"}, + {op: get, path: "/:a", errMsg: errNotFound.Error()}, + + {op: set, path: "/a:b", value: "bar"}, + {op: set, path: "/a:c", value: "baz"}, + {op: get, path: "/a:b", expectedValue: "bar"}, + {op: get, path: "/ac", errMsg: errNotFound.Error()}, + {op: get, path: "/a:", errMsg: errNotFound.Error()}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + {op: set, path: "/:a", value: "foo"}, + {op: get, path: "/:a", expectedValue: "foo", expectedParams: []Param{{"a", ":a"}}}, + } + runTrieTests(t, trie, tests) + trie = NewTrie() tests = []ts{ // test "/a" does not collide with "/:a/b"