Skip to content

Commit

Permalink
Handle cases where colon is in the middle of a url segment
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuntao Lu committed Jul 3, 2019
1 parent 133a174 commit 201baaa
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 76 deletions.
19 changes: 10 additions & 9 deletions runtime/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ package router
import (
"context"
"net/http"
"sort"
"strings"
)

// Router dispatches http requests to a registered http.Handler.
// It implements a similar interface to the one in github.com/julienschmidt/httprouter,
// the main differences are:
// 1. this router does not treat "/a/:b" and "/a/b/c" as conflicts (https://github.com/julienschmidt/httprouter/issues/175)
// 2. this router does not treat "/a/:b" and "/a/:c" as different routes (https://github.com/julienschmidt/httprouter/issues/6)
// 3. this router does tno treat "/a" and "/a/" as different routes
// 3. this router does not treat "/a" and "/a/" as different routes
type Router struct {
tries map[string]*Trie

Expand Down Expand Up @@ -132,21 +134,20 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

func (r *Router) allowed(path, reqMethod string) string {
allow := ""
var allow []string

for method, trie := range r.tries {
if method == reqMethod || method == "OPTIONS" {
if method == reqMethod || method == http.MethodOptions {
continue
}

if _, _, err := trie.Get(path); err == nil {
if len(allow) == 0 {
allow = method
} else {
allow += ", " + method
}
allow = append(allow, method)
}
}
sort.Slice(allow, func(i, j int) bool {
return allow[i] < allow[j]
})

return allow
return strings.Join(allow, ", ")
}
153 changes: 86 additions & 67 deletions runtime/router/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,12 @@ func NewTrie() *Trie {
}
}

// Set sets the value for given path, returns error if path already set
// Set sets the value for given path, returns error if path already set.
// When a http.Handler is registered for a given path, a subsequent Get returns the registered
// handler if the url passed to Get call matches the set path. Match in this context could mean either
// equality (e.g. url is "/foo" and path is "/foo") or url matches path pattern, which has two forms:
// - path ends with "/*", e.g. url "/foo" and "/foo/bar" both matches path "/*"
// - path contains colon wildcard ("/:"), e.g. url "/a/b" and "/a/c" bot matches path "/a/:var"
func (t *Trie) Set(path string, value http.Handler) error {
if path == "" || strings.Contains(path, "//") {
return errPath
Expand All @@ -90,14 +95,16 @@ func (t *Trie) Set(path string, value http.Handler) error {
return errors.New("path can not contain more than one *")
}

err := t.root.set(path, value)
err := t.root.set(path, value, false, false)
if e, ok := err.(*paramMismatch); ok {
return fmt.Errorf("path %q has a different param key %q, it should be the same key %q as in existing path %q", path, e.actual, e.expected, e.existingPath)
}
return err
}

// Get returns the value for given path, returns error if not found
// Get returns the http.Handler for given path, returns error if not found.
// It also returns the url params if given path contains any, e.g. if a handler is registered for
// "/:foo/bar", then calling Get with path "/xyz/bar" returns a param whose key is "foo" and value is "xyz".
func (t *Trie) Get(path string) (http.Handler, []Param, error) {
if path == "" || strings.Contains(path, "//") {
return nil, nil, errPath
Expand All @@ -107,64 +114,73 @@ func (t *Trie) Get(path string) (http.Handler, []Param, error) {
}
// ignore trailing slash
path = strings.TrimSuffix(path, "/")
return t.root.get(path)
return t.root.get(path, false, false, false)
}

func (t *tnode) set(path string, value http.Handler) error {
// set sets the handler for given path, creates new child node if necessary
// lastKeyCharSlash tracks whether the previous key char is a '/', used to decide it is a pattern or not
// when the current key char is ':'. lastPathCharSlash tracks whether the previous path char is a '/',
// used to decide it is a pattern or not when the current path char is ':'.
func (t *tnode) set(path string, value http.Handler, lastKeyCharSlash, lastPathCharSlash bool) error {
// find the longest common prefix
var l, i int
m, n := len(t.key), len(path)
if m > n {
l = n
var shorterLength, i int
keyLength, pathLength := len(t.key), len(path)
if keyLength > pathLength {
shorterLength = pathLength
} else {
l = m
shorterLength = keyLength
}
for i < l && t.key[i] == path[i] {
for i < shorterLength && t.key[i] == path[i] {
i++
}

// find index j, k in key and path to which they match,
// j and k is only useful to check if there is conflict,
// they are not where splits happen, splits happen at index i.
var j, k int
for j < m && k < n {
if t.key[j] == ':' || path[k] == ':' {
oj, ok := j, k
same := t.key[j] == path[k]
for j < m && t.key[j] != '/' {
j++
// Find the first character that differs between "path" and this node's key, if it exists.
// If we encounter a colon wildcard, ensure that the wildcard in path matches the wildcard
// in this node's key for that segment. The segment is a colon wildcard only when the colon
// is immediately after slash, e.g. "/:foo", "/x/:y". "/a:b" is not a colon wildcard segment.
var keyMatchIdx, pathMatchIdx int
for keyMatchIdx < keyLength && pathMatchIdx < pathLength {
if (t.key[keyMatchIdx] == ':' && lastKeyCharSlash) ||
(path[pathMatchIdx] == ':' && lastPathCharSlash) {
keyStartIdx, pathStartIdx := keyMatchIdx, pathMatchIdx
same := t.key[keyMatchIdx] == path[pathMatchIdx]
for keyMatchIdx < keyLength && t.key[keyMatchIdx] != '/' {
keyMatchIdx++
}
for k < n && path[k] != '/' {
k++
for pathMatchIdx < pathLength && path[pathMatchIdx] != '/' {
pathMatchIdx++
}
if same && (j-oj) != (k-ok) {
if same && (keyMatchIdx-keyStartIdx) != (pathMatchIdx-pathStartIdx) {
return &paramMismatch{
t.key[oj:j],
path[ok:k],
t.key[keyStartIdx:keyMatchIdx],
path[pathStartIdx:pathMatchIdx],
t.key,
}
}
} else if t.key[j] == path[k] {
j++
k++
} else if t.key[keyMatchIdx] == path[pathMatchIdx] {
keyMatchIdx++
pathMatchIdx++
} else {
break
}
lastKeyCharSlash = t.key[keyMatchIdx-1] == '/'
lastPathCharSlash = path[pathMatchIdx-1] == '/'
}

// conflicts caused by ":" is only possible when j == m
if j == m {
// If the node key is fully matched, we match the rest path with children nodes to see if a value
// already exists for the path.
if keyMatchIdx == keyLength {
for _, c := range t.children {
if c.key[0] == path[k] || c.key[0] == ':' || path[k] == ':' {
if _, _, err := c.get(path[k:]); err == nil {
if c.key[0] == path[pathMatchIdx] || c.key[0] == ':' || path[pathMatchIdx] == ':' {
if _, _, err := c.get(path[pathMatchIdx:], lastKeyCharSlash, lastPathCharSlash, true); err == nil {
return errExist
}
}
}
}

// node ley is longer than longest common prefix
if i < m {
// node key is longer than longest common prefix
if i < keyLength {
// key/path suffix being "*" means a conflict
if path[i:] == "*" || t.key[i:] == "*" {
return errExist
Expand All @@ -182,7 +198,7 @@ func (t *tnode) set(path string, value http.Handler) error {

// path is equal to longest common prefix
// set value on current node after split
if i == n {
if i == pathLength {
t.value = value
} else {
// path is longer than longest common prefix
Expand All @@ -196,17 +212,19 @@ func (t *tnode) set(path string, value http.Handler) error {
}

// node key is equal to longest common prefix
if i == m {
if i == keyLength {
// path is equal to longest common prefix
if i == n {
if i == pathLength {
// node is guaranteed to have zero value,
// otherwise it would have caused errExist earlier
t.value = value
} else {
// path is longer than node key, try to recurse on node children
for _, c := range t.children {
if c.key[0] == path[i] {
err := c.set(path[i:], value)
lastKeyCharSlash = i > 0 && t.key[i-1] == '/'
lastPathCharSlash = i > 0 && path[i-1] == '/'
err := c.set(path[i:], value, lastKeyCharSlash, lastPathCharSlash)
if e, ok := err.(*paramMismatch); ok {
e.existingPath = t.key + e.existingPath
return e
Expand All @@ -226,66 +244,67 @@ func (t *tnode) set(path string, value http.Handler) error {
return nil
}

func (t *tnode) get(path string) (http.Handler, []Param, error) {
m, n := len(t.key), len(path)
func (t *tnode) get(path string, lastKeyCharSlash, lastPathCharSlash, colonAsPattern bool) (http.Handler, []Param, error) {
keyLength, pathLength := len(t.key), len(path)
var params []Param

// find the longest matched prefix
var j, k int
for j < m && k < n {
if t.key[j] == ':' {
oj, ok := j+1, k
for j < m && t.key[j] != '/' {
j++
var keyIdx, pathIdx int
for keyIdx < keyLength && pathIdx < pathLength {
if t.key[keyIdx] == ':' && lastKeyCharSlash {
// wildcard starts - match until next slash
keyStartIdx, pathStartIdx := keyIdx+1, pathIdx
for keyIdx < keyLength && t.key[keyIdx] != '/' {
keyIdx++
}
for k < n && path[k] != '/' {
k++
for pathIdx < pathLength && path[pathIdx] != '/' {
pathIdx++
}
params = append(params, Param{t.key[oj:j], path[ok:k]})
} else if path[k] == ':' { // necessary for conflict check used in set call
for j < m && t.key[j] != '/' {
j++
params = append(params, Param{t.key[keyStartIdx:keyIdx], path[pathStartIdx:pathIdx]})
} else if path[pathIdx] == ':' && lastPathCharSlash && colonAsPattern {
// necessary for conflict check used in set call
for keyIdx < keyLength && t.key[keyIdx] != '/' {
keyIdx++
}
for k < n && path[k] != '/' {
k++
for pathIdx < pathLength && path[pathIdx] != '/' {
pathIdx++
}
} else if t.key[j] == path[k] {
j++
k++
} else if t.key[keyIdx] == path[pathIdx] {
keyIdx++
pathIdx++
} else {
break
}
lastKeyCharSlash = t.key[keyIdx-1] == '/'
lastPathCharSlash = path[pathIdx-1] == '/'
}

if j < m {
if keyIdx < keyLength {
// path matches up to node key's second to last character,
// the last char of node key is "*" and path is no shorter than longest matched prefix
if t.key[j:] == "*" && k < n {
if t.key[keyIdx:] == "*" && pathIdx < pathLength {
return t.value, params, nil
}
return nil, nil, errNotFound
}

// ':' in path matches '*' in node key
if j > 0 && t.key[j-1] == '*' {
if keyIdx > 0 && t.key[keyIdx-1] == '*' {
return t.value, params, nil
}

// longest matched prefix matches up to node key length and path length
if k == n {
if pathIdx == pathLength {
if t.value != nil {
return t.value, params, nil
}
return nil, nil, errNotFound
}

// TODO: recursion to iteration for speed
// longest matched prefix matches up to node key length but not path length
for _, c := range t.children {
if c.key[0] == path[k] || c.key[0] == ':' || path[k] == ':' {
if v, ps, err := c.get(path[k:]); err == nil {
return v, append(params, ps...), nil
}
if v, ps, err := c.get(path[pathIdx:], lastKeyCharSlash, lastPathCharSlash, colonAsPattern); err == nil {
return v, append(params, ps...), nil
}
}

Expand Down
21 changes: 21 additions & 0 deletions runtime/router/trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,27 @@ func TestTriePathsWithPatten(t *testing.T) {
}
runTrieTests(t, trie, tests)

trie = NewTrie()
tests = []ts{
// test ":a" is not treated as a pattern when queried as a url
{op: set, path: "/a", value: "foo"},
{op: get, path: "/:a", errMsg: errNotFound.Error()},

{op: set, path: "/a:b", value: "bar"},
{op: set, path: "/a:c", value: "baz"},
{op: get, path: "/a:b", expectedValue: "bar"},
{op: get, path: "/ac", errMsg: errNotFound.Error()},
{op: get, path: "/a:", errMsg: errNotFound.Error()},
}
runTrieTests(t, trie, tests)

trie = NewTrie()
tests = []ts{
{op: set, path: "/:a", value: "foo"},
{op: get, path: "/:a", expectedValue: "foo", expectedParams: []Param{{"a", ":a"}}},
}
runTrieTests(t, trie, tests)

trie = NewTrie()
tests = []ts{
// test "/a" does not collide with "/:a/b"
Expand Down

0 comments on commit 201baaa

Please sign in to comment.