Skip to content

Commit

Permalink
Split interface, match route and middleware separately [WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
vardius committed Jan 26, 2020
1 parent e6c790a commit 532090c
Show file tree
Hide file tree
Showing 15 changed files with 230 additions and 191 deletions.
24 changes: 14 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,23 @@ import (
"github.com/vardius/gorouter/v4/context"
)

func Index(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Welcome!\n")
func index(w http.ResponseWriter, _ *http.Request) {
if _, err := fmt.Fprint(w, "Welcome!\n"); err != nil {
panic(err)
}
}

func Hello(w http.ResponseWriter, r *http.Request) {
func hello(w http.ResponseWriter, r *http.Request) {
params, _ := context.Parameters(r.Context())
fmt.Fprintf(w, "hello, %s!\n", params.Value("name"))
if _, err := fmt.Fprintf(w, "hello, %s!\n", params.Value("name")); err != nil {
panic(err)
}
}

func main() {
router := gorouter.New()
router.GET("/", http.HandlerFunc(Index))
router.GET("/hello/{name}", http.HandlerFunc(Hello))
router.GET("/", http.HandlerFunc(index))
router.GET("/hello/{name}", http.HandlerFunc(hello))

log.Fatal(http.ListenAndServe(":8080", router))
}
Expand All @@ -71,19 +75,19 @@ import (
"github.com/vardius/gorouter/v4"
)

func Index(ctx *fasthttp.RequestCtx) {
func index(_ *fasthttp.RequestCtx) {
fmt.Print("Welcome!\n")
}

func Hello(ctx *fasthttp.RequestCtx) {
func hello(ctx *fasthttp.RequestCtx) {
params := ctx.UserValue("params").(context.Params)
fmt.Printf("Hello, %s!\n", params.Value("name"))
}

func main() {
router := gorouter.NewFastHTTPRouter()
router.GET("/", Index)
router.GET("/hello/{name}", Hello)
router.GET("/", index)
router.GET("/hello/{name}", hello)

log.Fatal(fasthttp.ListenAndServe(":8080", router.HandleFastHTTP))
}
Expand Down
2 changes: 1 addition & 1 deletion doc.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Package gorouter provide request router with middleware
Package gorouter provide request router with globalMiddleware
Router
Expand Down
22 changes: 11 additions & 11 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func Example_second() {
}

func ExampleMiddlewareFunc() {
// Global middleware example
// Global globalMiddleware example
// applies to all routes
hello := func(w http.ResponseWriter, r *http.Request) {
params, _ := context.Parameters(r.Context())
Expand All @@ -81,7 +81,7 @@ func ExampleMiddlewareFunc() {
return http.HandlerFunc(fn)
}

// apply middleware to all routes
// apply globalMiddleware to all routes
// can pass as many as you want
router := gorouter.New(logger)
router.GET("/hello/{name}", http.HandlerFunc(hello))
Expand All @@ -95,7 +95,7 @@ func ExampleMiddlewareFunc() {
}

func ExampleMiddlewareFunc_second() {
// Route level middleware example
// Route level globalMiddleware example
// applies to route and its lower tree
hello := func(w http.ResponseWriter, r *http.Request) {
params, _ := context.Parameters(r.Context())
Expand All @@ -114,7 +114,7 @@ func ExampleMiddlewareFunc_second() {
router := gorouter.New()
router.GET("/hello/{name}", http.HandlerFunc(hello))

// apply middleware to route and all it children
// apply globalMiddleware to route and all it children
// can pass as many as you want
router.USE("GET", "/hello/{name}", logger)

Expand All @@ -127,7 +127,7 @@ func ExampleMiddlewareFunc_second() {
}

func ExampleMiddlewareFunc_third() {
// Http method middleware example
// Http method globalMiddleware example
// applies to all routes under this method
hello := func(w http.ResponseWriter, r *http.Request) {
params, _ := context.Parameters(r.Context())
Expand All @@ -146,7 +146,7 @@ func ExampleMiddlewareFunc_third() {
router := gorouter.New()
router.GET("/hello/{name}", http.HandlerFunc(hello))

// apply middleware to all routes with GET method
// apply globalMiddleware to all routes with GET method
// can pass as many as you want
router.USE("GET", "", logger)

Expand All @@ -159,7 +159,7 @@ func ExampleMiddlewareFunc_third() {
}

func ExampleFastHTTPMiddlewareFunc() {
// Global middleware example
// Global globalMiddleware example
// applies to all routes
hello := func(ctx *fasthttp.RequestCtx) {
params := ctx.UserValue("params").(context.Params)
Expand Down Expand Up @@ -187,7 +187,7 @@ func ExampleFastHTTPMiddlewareFunc() {
}

func ExampleFastHTTPMiddlewareFunc_second() {
// Route level middleware example
// Route level globalMiddleware example
// applies to route and its lower tree
hello := func(ctx *fasthttp.RequestCtx) {
params := ctx.UserValue("params").(context.Params)
Expand All @@ -206,7 +206,7 @@ func ExampleFastHTTPMiddlewareFunc_second() {
router := gorouter.NewFastHTTPRouter()
router.GET("/hello/{name}", hello)

// apply middleware to route and all it children
// apply globalMiddleware to route and all it children
// can pass as many as you want
router.USE("GET", "/hello/{name}", logger)

Expand All @@ -219,7 +219,7 @@ func ExampleFastHTTPMiddlewareFunc_second() {
}

func ExampleFastHTTPMiddlewareFunc_third() {
// Http method middleware example
// Http method globalMiddleware example
// applies to all routes under this method
hello := func(ctx *fasthttp.RequestCtx) {
params := ctx.UserValue("params").(context.Params)
Expand All @@ -238,7 +238,7 @@ func ExampleFastHTTPMiddlewareFunc_third() {
router := gorouter.NewFastHTTPRouter()
router.GET("/hello/{name}", hello)

// apply middleware to all routes with GET method
// apply globalMiddleware to all routes with GET method
// can pass as many as you want
router.USE("GET", "", logger)

Expand Down
37 changes: 23 additions & 14 deletions fasthttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ import (
// NewFastHTTPRouter creates new Router instance, returns pointer
func NewFastHTTPRouter(fs ...FastHTTPMiddlewareFunc) FastHTTPRouter {
return &fastHTTPRouter{
routes: mux.NewTree(),
middleware: transformFastHTTPMiddlewareFunc(fs...),
routes: mux.NewTree(),
middleware: mux.NewTree(),
globalMiddleware: transformFastHTTPMiddlewareFunc(fs...),
}
}

type fastHTTPRouter struct {
routes mux.Tree
middleware middleware.Middleware
fileServer fasthttp.RequestHandler
notFound fasthttp.RequestHandler
notAllowed fasthttp.RequestHandler
routes mux.Tree // mux.RouteAware tree
middleware mux.Tree // mux.MiddlewareAware tree
globalMiddleware middleware.Middleware
fileServer fasthttp.RequestHandler
notFound fasthttp.RequestHandler
notAllowed fasthttp.RequestHandler
}

func (r *fastHTTPRouter) PrettyPrint() string {
Expand Down Expand Up @@ -68,7 +70,7 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) {
func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) {
m := transformFastHTTPMiddlewareFunc(fs...)

r.routes = r.routes.WithMiddleware(method+path, m, 0)
r.middleware = r.middleware.WithMiddleware(method+path, m)
}

func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) {
Expand Down Expand Up @@ -123,11 +125,13 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
path := pathutils.TrimSlash(pathAsString)

if root := r.routes.Find(method); root != nil {
if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil {
route := node.Route()
handler := route.Handler()
allMiddleware := r.middleware.Merge(root.Middleware().Merge(treeMiddleware))
computedHandler := allMiddleware.Compose(handler)
if route, params, subPath := root.Tree().MatchRoute(path); route != nil {
allMiddleware := r.globalMiddleware
if treeMiddleware := r.middleware.MatchMiddleware(method + path); treeMiddleware != nil {
allMiddleware = allMiddleware.Merge(treeMiddleware)
}

computedHandler := allMiddleware.Compose(route.Handler())

h := computedHandler.(fasthttp.RequestHandler)

Expand All @@ -144,7 +148,12 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
}

if pathAsString == "/" && root.Route() != nil {
root.Route().Handler().(fasthttp.RequestHandler)(ctx)
rootMiddleware := r.globalMiddleware
if root.Middleware() != nil {
rootMiddleware = rootMiddleware.Merge(root.Middleware())
}
rootHandler := rootMiddleware.Compose(root.Route().Handler())
rootHandler.(fasthttp.RequestHandler)(ctx)
return
}
}
Expand Down
18 changes: 9 additions & 9 deletions fasthttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ func TestFastHTTPNilMiddleware(t *testing.T) {
router.HandleFastHTTP(ctx)

if string(ctx.Response.Body()) != "test" {
t.Error("Nil middleware works")
t.Error("Nil globalMiddleware works")
}
}

Expand Down Expand Up @@ -399,15 +399,15 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) {
router.HandleFastHTTP(ctx)

if string(ctx.Response.Body()) != "m1y" {
t.Errorf("Use middleware error %s", string(ctx.Response.Body()))
t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body()))
}

ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x")

router.HandleFastHTTP(ctx)

if string(ctx.Response.Body()) != "m1m2x" {
t.Errorf("Use middleware error %s", string(ctx.Response.Body()))
t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body()))
}
}

Expand All @@ -422,10 +422,10 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) {
}
})

// Method global middleware
// Method global globalMiddleware
router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m1->"))
router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m2->"))
// Path middleware
// Path globalMiddleware
router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx1->"))
router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx2->"))
router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy1->"))
Expand All @@ -440,7 +440,7 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) {
router.HandleFastHTTP(ctx)

if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mparam1->mparam2->mxy1->mxy2->mxy3->mxy4->handler" {
t.Errorf("Use middleware error %s", string(ctx.Response.Body()))
t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body()))
}
}

Expand All @@ -462,7 +462,7 @@ func TestFastHTTPNodeApplyMiddlewareStatic(t *testing.T) {
router.HandleFastHTTP(ctx)

if string(ctx.Response.Body()) != "m1x" {
t.Errorf("Use middleware error %s", string(ctx.Response.Body()))
t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body()))
}
}

Expand All @@ -485,7 +485,7 @@ func TestFastHTTPNodeApplyMiddlewareInvalidNodeReference(t *testing.T) {
router.HandleFastHTTP(ctx)

if string(ctx.Response.Body()) != "y" {
t.Errorf("Use middleware error %s", string(ctx.Response.Body()))
t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body()))
}
}

Expand Down Expand Up @@ -631,6 +631,6 @@ func TestFastHTTPMountSubRouter(t *testing.T) {
mainRouter.HandleFastHTTP(ctx)

if string(ctx.Response.Body()) != "[rg1][rg2][r1][r2][sg1][sg2][s1][s2][s]" {
t.Errorf("Router mount sub router middleware error: %s", string(ctx.Response.Body()))
t.Errorf("Router mount sub router globalMiddleware error: %s", string(ctx.Response.Body()))
}
}
8 changes: 6 additions & 2 deletions mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func (mfs *mockFileSystem) Open(_ string) (http.File, error) {
func mockMiddleware(body string) MiddlewareFunc {
fn := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(body))
if _, err := w.Write([]byte(body)); err != nil {
panic(err)
}
h.ServeHTTP(w, r)
})
}
Expand All @@ -65,7 +67,9 @@ func mockServeHTTP(h http.Handler, method, path string) error {
func mockFastHTTPMiddleware(body string) FastHTTPMiddlewareFunc {
fn := func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, body)
if _, err := fmt.Fprintf(ctx, body); err != nil {
panic(err)
}

h(ctx)
}
Expand Down
10 changes: 5 additions & 5 deletions mux/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ func BenchmarkMux(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
n, _, _, _ := root.Tree().Match("pl/blog/comments/123/new")
route, _, _ := root.Tree().MatchRoute("pl/blog/comments/123/new")

if n == nil {
b.Fatalf("%v", n)
if route == nil {
b.Fatalf("%v", route)
}

if n.Name() != commentNew.Name() {
b.Fatalf("%s != %s", n.Name(), commentNew.Name())
if route != commentNew.Route() {
b.Fatalf("%s != %s (%s)", route, commentNew.Route(), commentNew.Name())
}
}
})
Expand Down

0 comments on commit 532090c

Please sign in to comment.