Skip to content

Commit

Permalink
Special router logic for whitelisted paths (#692)
Browse files Browse the repository at this point in the history
* Special router logic for whitelisted paths

* Fixed cyclic imports

* Refactored some code

* Made WhitelistedPaths param optional

* Added test cases

* Added test cases

* Fixed test case

* Addressed CR feedback

* Added test case

* Addressed CR feedback
  • Loading branch information
sandeepboys committed Mar 26, 2020
1 parent ab8c042 commit 00caa1b
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 56 deletions.
2 changes: 2 additions & 0 deletions examples/example-gateway/config/production.yaml
Expand Up @@ -88,3 +88,5 @@ clients.baz.alternates:
port: 8114
grpc.clientServiceNameMapping:
echo: echo
router.whitelistedPaths:
- /path/whitelisted
12 changes: 11 additions & 1 deletion runtime/router.go
Expand Up @@ -24,7 +24,6 @@ import (
"context"
"fmt"
"net/http"

"net/url"

"github.com/opentracing/opentracing-go"
Expand Down Expand Up @@ -150,6 +149,7 @@ func NewHTTPRouter(gateway *Gateway) HTTPRouter {
ContextLogger: gateway.ContextLogger,
Scope: gateway.RootScope,
Tracer: gateway.Tracer,
Config: gateway.Config,
}

router := &httpRouter{
Expand All @@ -173,6 +173,7 @@ func NewHTTPRouter(gateway *Gateway) HTTPRouter {
NotFound: http.HandlerFunc(router.handleNotFound),
MethodNotAllowed: http.HandlerFunc(router.handleMethodNotAllowed),
PanicHandler: router.handlePanic,
WhitelistedPaths: router.getWhitelistedPaths(),
}
return router
}
Expand Down Expand Up @@ -263,3 +264,12 @@ func (router *httpRouter) handleMethodNotAllowed(
req.res.StatusCode = http.StatusMethodNotAllowed
req.res.finish(ctx)
}

func (router *httpRouter) getWhitelistedPaths() []string {
var whitelistedPaths []string
if router.gateway.Config != nil &&
router.gateway.Config.ContainsKey("router.whitelistedPaths") {
router.gateway.Config.MustGetStruct("router.whitelistedPaths", &whitelistedPaths)
}
return whitelistedPaths
}
38 changes: 33 additions & 5 deletions runtime/router/router.go
Expand Up @@ -33,6 +33,7 @@ import (
// 1. this router does not treat "/a/:b" and "/a/b/c" as conflicts (https://github.com/julienschmidt/httprouter/issues/175)
// 2. this router does not treat "/a/:b" and "/a/:c" as different routes and therefore does not allow them to be registered at the same time (https://github.com/julienschmidt/httprouter/issues/6)
// 3. this router does not treat "/a" and "/a/" as different routes
// 4. this router treats "/a" and "/:b" as different paths for whitelisted paths
// Also the `*` pattern is greedy, if a handler is register for `/a/*`, then no handler
// can be further registered for any path that starts with `/a/`
type Router struct {
Expand Down Expand Up @@ -64,6 +65,10 @@ type Router struct {
// unrecovered panics.
PanicHandler func(http.ResponseWriter, *http.Request, interface{})

// Used for special behavior using which different handlers can configured
// for paths such as /a and /:b in router.
WhitelistedPaths []string

// TODO: (clu) maybe support OPTIONS
}

Expand All @@ -90,7 +95,7 @@ func (r *Router) Handle(method, path string, handler http.Handler) error {
trie = NewTrie()
r.tries[method] = trie
}
return trie.Set(path, handler)
return trie.Set(path, handler, r.isWhitelistedPath(path))
}

// ServeHTTP dispatches the request to a register handler to handle.
Expand All @@ -104,8 +109,9 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

reqPath := req.URL.Path
isWhitelisted := r.isWhitelistedPath(reqPath)
if trie, ok := r.tries[req.Method]; ok {
if handler, params, err := trie.Get(reqPath); err == nil {
if handler, params, err := trie.Get(reqPath, isWhitelisted); err == nil {
ctx := context.WithValue(req.Context(), urlParamsKey, params)
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
Expand All @@ -114,7 +120,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

if r.HandleMethodNotAllowed {
if allowed := r.allowed(reqPath, req.Method); allowed != "" {
if allowed := r.allowed(reqPath, req.Method, isWhitelisted); allowed != "" {
w.Header().Set("Allow", allowed)
if r.MethodNotAllowed != nil {
r.MethodNotAllowed.ServeHTTP(w, req)
Expand All @@ -135,15 +141,15 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
}

func (r *Router) allowed(path, reqMethod string) string {
func (r *Router) allowed(path, reqMethod string, isWhitelisted bool) string {
var allow []string

for method, trie := range r.tries {
if method == reqMethod || method == http.MethodOptions {
continue
}

if _, _, err := trie.Get(path); err == nil {
if _, _, err := trie.Get(path, isWhitelisted); err == nil {
allow = append(allow, method)
}
}
Expand All @@ -153,3 +159,25 @@ func (r *Router) allowed(path, reqMethod string) string {

return strings.Join(allow, ", ")
}

func (r *Router) isWhitelistedPath(path string) bool {
for _, whitelistedPath := range r.WhitelistedPaths {
whitelistedPathTokens := strings.Split(whitelistedPath, "/")
pathTokens := strings.Split(path, "/")
if len(whitelistedPathTokens) != len(pathTokens) {
continue
}

isMatched := true
for i, token := range whitelistedPathTokens {
if pathTokens[i] != token && token[0] != ':' {
isMatched = false
break
}
}
if isMatched {
return true
}
}
return false
}
86 changes: 78 additions & 8 deletions runtime/router/router_test.go
Expand Up @@ -47,21 +47,91 @@ func TestParamsFromContext(t *testing.T) {
r := &Router{}

handled := false
err := r.Handle("GET", "/:var",
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
params := ParamsFromContext(req.Context())
assert.Equal(t, 1, len(params))
assert.Equal(t, "var", params[0].Key)
assert.Equal(t, "foo", params[0].Value)
handled = true
}))
handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
params := ParamsFromContext(req.Context())
assert.Equal(t, 1, len(params))
assert.Equal(t, "var", params[0].Key)
assert.Equal(t, "foo", params[0].Value)
handled = true
})
err := r.Handle("GET", "/:var", handlerFunc)
assert.NoError(t, err, "unexpected error")

req, _ := http.NewRequest("GET", "/foo", nil)
r.ServeHTTP(nil, req)
assert.True(t, handled)
}

func TestParamsFromContextForWhitelistedPaths(t *testing.T) {
// Test case with no whitelisted paths
r := &Router{}

handled1 := false
handlerFunc1 := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
params := ParamsFromContext(req.Context())
assert.Equal(t, 1, len(params))
assert.Equal(t, "var", params[0].Key)
assert.Equal(t, "some", params[0].Value)
handled1 = true
})

handled2 := false
handlerFunc2 := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
handled2 = true
})

err := r.Handle("GET", "/bar/:var", handlerFunc1)
assert.NoError(t, err, "unexpected error")

err = r.Handle("GET", "/bar/foo", handlerFunc2)
assert.Error(t, err, "path value already set")

req, _ := http.NewRequest("GET", "/bar/some", nil)
r.ServeHTTP(nil, req)
assert.True(t, handled1)
assert.False(t, handled2)

// Test case for paths not in whitelisted paths
r = &Router{}
r.WhitelistedPaths = []string{"/test", "/bar/foo"}

handled1 = false
handled2 = false
err = r.Handle("GET", "/bar/foo", handlerFunc2)
assert.NoError(t, err, "unexpected error")

err = r.Handle("GET", "/bar/:var", handlerFunc1)
assert.Error(t, err, "path value already set")

req, _ = http.NewRequest("GET", "/bar/foo", nil)
r.ServeHTTP(nil, req)
assert.True(t, handled2)
assert.False(t, handled1)

// Test case with whitelisted paths
r = &Router{}
r.WhitelistedPaths = []string{"/test", "/bar/foo", "/bar/:var"}

handled1 = false
handled2 = false
err = r.Handle("GET", "/bar/:var", handlerFunc1)
assert.NoError(t, err, "unexpected error")

err = r.Handle("GET", "/bar/foo", handlerFunc2)
assert.NoError(t, err, "unexpected error")

req, _ = http.NewRequest("GET", "/bar/some", nil)
r.ServeHTTP(nil, req)
assert.True(t, handled1)
assert.False(t, handled2)

handled1 = false
req, _ = http.NewRequest("GET", "/bar/foo", nil)
r.ServeHTTP(nil, req)
assert.False(t, handled1)
assert.True(t, handled2)
}

func TestPanicHandler(t *testing.T) {
handled := false
r := &Router{
Expand Down

0 comments on commit 00caa1b

Please sign in to comment.