Skip to content

Commit

Permalink
Allow user filter registration and setup on http
Browse files Browse the repository at this point in the history
  • Loading branch information
anuptalwalkar committed Jan 18, 2017
1 parent 42bdff1 commit 0268778
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 47 deletions.
12 changes: 12 additions & 0 deletions examples/simple/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package main

import (
"context"
"fmt"
"io"
"net/http"
Expand All @@ -47,3 +48,14 @@ func registerHTTPers(service service.Host) []uhttp.RouteHandler {
uhttp.NewRouteHandler("/", handler),
}
}

func simpleFilter() uhttp.FilterFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, next uhttp.Handler) {
fxctx, ok := ctx.(fx.Context)
if ok {
next.ServeHTTP(fxctx, w, r)
} else {
uhttp.Wrap(service.NopHost(), next).ServeHTTP(w, r)
}
}
}
2 changes: 1 addition & 1 deletion examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (

func main() {
svc, err := service.WithModules(
uhttp.New(registerHTTPers),
uhttp.New(registerHTTPers, []uhttp.Filter{simpleFilter()}),
).Build()

if err != nil {
Expand Down
97 changes: 97 additions & 0 deletions modules/uhttp/filterchain_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package uhttp

import (
"net/http"

"go.uber.org/fx"
"go.uber.org/fx/service"
)

type filterChain struct {
currentFilter int
finalHandler Handler
filters []Filter
}

func newFilterChain(filters []Filter, finalHandler Handler) filterChain {
return filterChain{
finalHandler: finalHandler,
filters: filters,
}
}

func (fc filterChain) ServeHTTP(ctx fx.Context, w http.ResponseWriter, r *http.Request) {
if fc.currentFilter == len(fc.filters) {
fc.finalHandler.ServeHTTP(ctx, w, r)
} else {
filter := fc.filters[fc.currentFilter]
fc.currentFilter++
filter.Apply(ctx, w, r, fc)
}
}

// FilterChainBuilder builds a filterChain object with added filters
type FilterChainBuilder interface {
// AddFilter is used to add the next filter to the chain during construction time.
// The calls to AddFilter can be chained.
AddFilter(filter Filter) FilterChainBuilder

// Build creates an immutable FilterChain.
Build(finalHandler Handler) filterChain
}

type filterChainBuilder struct {
service.Host

finalHandler Handler
filters []Filter
}

func defaultFilterChainBuilder(host service.Host) FilterChainBuilder {
fcb := NewFilterChainBuilder(host)
return fcb.AddFilter(contextFilter(host)).
AddFilter(tracingServerFilter(host)).
AddFilter(authorizationFilter(host)).
AddFilter(panicFilter(host))
}

// NewFilterChainBuilder creates an empty filterChainBuilder for setup
func NewFilterChainBuilder(host service.Host) FilterChainBuilder {
return &filterChainBuilder{
Host: host,
}
}

func (f filterChainBuilder) AddFilter(filter Filter) FilterChainBuilder {
f.filters = append(f.filters, filter)
return f
}

func (f filterChainBuilder) Build(finalHandler Handler) filterChain {
fc := filterChain{}
for _, ff := range f.filters {
fc.filters = append(fc.filters, ff)
}
fc.finalHandler = finalHandler
return fc
}
27 changes: 2 additions & 25 deletions modules/uhttp/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ import (
// Filter applies filters on requests, request contexts or responses such as
// adding tracing to the context
type Filter interface {
Apply(ctx context.Context, w http.ResponseWriter, r *http.Request, next Handler)
Apply(ctx fx.Context, w http.ResponseWriter, r *http.Request, next Handler)
}

// FilterFunc is an adaptor to call normal functions to apply filters
type FilterFunc func(ctx context.Context, w http.ResponseWriter, r *http.Request, next Handler)

// Apply implements Apply from the Filter interface and simply delegates to the function
func (f FilterFunc) Apply(ctx context.Context, w http.ResponseWriter, r *http.Request, next Handler) {
func (f FilterFunc) Apply(ctx fx.Context, w http.ResponseWriter, r *http.Request, next Handler) {
f(ctx, w, r, next)
}

Expand Down Expand Up @@ -117,26 +117,3 @@ func panicFilter(host service.Host) FilterFunc {
next.ServeHTTP(fxctx, w, r)
}
}

func newFilterChain(filters []Filter, finalHandler Handler) filterChain {
return filterChain{
filters: filters,
finalHandler: finalHandler,
}
}

type filterChain struct {
currentFilter int
filters []Filter
finalHandler Handler
}

func (ec filterChain) ServeHTTP(ctx fx.Context, w http.ResponseWriter, req *http.Request) {
if ec.currentFilter < len(ec.filters) {
filter := ec.filters[ec.currentFilter]
ec.currentFilter++
filter.Apply(ctx, w, req, ec)
} else {
ec.finalHandler.ServeHTTP(ctx, w, req)
}
}
21 changes: 11 additions & 10 deletions modules/uhttp/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type Module struct {
listener net.Listener
handlers []RouteHandler
listenMu sync.RWMutex
filters []Filter
fcb FilterChainBuilder
}

var _ service.Module = &Module{}
Expand All @@ -96,9 +96,9 @@ type Config struct {
type GetHandlersFunc func(service service.Host) []RouteHandler

// New returns a new HTTP module
func New(hookup GetHandlersFunc, options ...modules.Option) service.ModuleCreateFunc {
func New(hookup GetHandlersFunc, filters []Filter, options ...modules.Option) service.ModuleCreateFunc {
return func(mi service.ModuleCreateInfo) ([]service.Module, error) {
mod, err := newModule(mi, hookup, options...)
mod, err := newModule(mi, hookup, filters, options...)
if err != nil {
return nil, errors.Wrap(err, "unable to instantiate HTTP module")
}
Expand All @@ -109,6 +109,7 @@ func New(hookup GetHandlersFunc, options ...modules.Option) service.ModuleCreate
func newModule(
mi service.ModuleCreateInfo,
getHandlers GetHandlersFunc,
filters []Filter,
options ...modules.Option,
) (*Module, error) {
// setup config defaults
Expand All @@ -122,16 +123,16 @@ func newModule(
}

handlers := addHealth(getHandlers(mi.Host))

// TODO (madhu): Add other middleware - logging, metrics.
module := &Module{
ModuleBase: *modules.NewModuleBase(ModuleType, mi.Name, mi.Host, []string{}),
handlers: handlers,
filters: []Filter{
contextFilter(mi.Host),
tracingServerFilter(mi.Host),
authorizationFilter(mi.Host),
panicFilter(mi.Host),
},
fcb: defaultFilterChainBuilder(mi.Host),
}

for _, filter := range filters {
module.fcb = module.fcb.AddFilter(filter)
}

err := module.Host().Config().Get(getConfigKey(mi.Name)).PopulateStruct(cfg)
Expand Down Expand Up @@ -161,7 +162,7 @@ func (m *Module) Start(ready chan<- struct{}) <-chan error {
mux.Handle("/", router)

for _, h := range m.handlers {
router.Handle(h.Path, newFilterChain(m.filters, h.Handler))
router.Handle(h.Path, m.fcb.Build(h.Handler))
}

if m.config.Debug == nil || *m.config.Debug {
Expand Down
42 changes: 31 additions & 11 deletions modules/uhttp/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package uhttp

import (
"context"
"errors"
"fmt"
"io"
Expand All @@ -45,7 +46,7 @@ import (
var _defaultHTTPClient = &http.Client{Timeout: 2 * time.Second}

func TestNew_OK(t *testing.T) {
WithService(New(registerNothing), nil, []service.Option{configOption()}, func(s service.Owner) {
WithService(New(registerNothing, nil), nil, []service.Option{configOption()}, func(s service.Owner) {
assert.NotNil(t, s, "Should create a module")
})
}
Expand All @@ -55,13 +56,24 @@ func TestNew_WithOptions(t *testing.T) {
modules.WithRoles("testing"),
}

withModule(t, registerPanic, options, false, func(m *Module) {
withModule(t, registerPanic, nil, options, false, func(m *Module) {
assert.NotNil(t, m, "Expected OK with options")
})
}

func TestHTTPModule_WithFilter(t *testing.T) {
withModule(t, registerPanic, []Filter{fakeFilter()}, nil, false, func(m *Module) {
assert.NotNil(t, m)
makeRequest(m, "GET", "/", nil, func(r *http.Response) {
body, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Contains(t, string(body), "filter is executed")
})
})
}

func TestHTTPModule_Panic_OK(t *testing.T) {
withModule(t, registerPanic, nil, false, func(m *Module) {
withModule(t, registerPanic, nil, nil, false, func(m *Module) {
assert.NotNil(t, m)
makeRequest(m, "GET", "/", nil, func(r *http.Response) {
assert.Equal(t, http.StatusInternalServerError, r.StatusCode, "Expected 500 with panic wrapper")
Expand All @@ -70,7 +82,7 @@ func TestHTTPModule_Panic_OK(t *testing.T) {
}

func TestHTTPModule_Tracer(t *testing.T) {
withModule(t, registerTracerCheckHandler, nil, false, func(m *Module) {
withModule(t, registerTracerCheckHandler, nil, nil, false, func(m *Module) {
assert.NotNil(t, m)
makeRequest(m, "GET", "/", nil, func(r *http.Response) {
assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 with tracer check")
Expand All @@ -79,13 +91,13 @@ func TestHTTPModule_Tracer(t *testing.T) {
}

func TestHTTPModule_StartsAndStops(t *testing.T) {
withModule(t, registerPanic, nil, false, func(m *Module) {
withModule(t, registerPanic, nil, nil, false, func(m *Module) {
assert.True(t, m.IsRunning(), "Start should be successful")
})
}

func TestBuiltinHealth_OK(t *testing.T) {
withModule(t, registerNothing, nil, false, func(m *Module) {
withModule(t, registerNothing, nil, nil, false, func(m *Module) {
assert.NotNil(t, m)
makeRequest(m, "GET", "/health", nil, func(r *http.Response) {
assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 with default health handler")
Expand All @@ -94,7 +106,7 @@ func TestBuiltinHealth_OK(t *testing.T) {
}

func TestOverrideHealth_OK(t *testing.T) {
withModule(t, registerCustomHealth, nil, false, func(m *Module) {
withModule(t, registerCustomHealth, nil, nil, false, func(m *Module) {
assert.NotNil(t, m)
makeRequest(m, "GET", "/health", nil, func(r *http.Response) {
assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 with default health handler")
Expand All @@ -106,7 +118,7 @@ func TestOverrideHealth_OK(t *testing.T) {
}

func TestPProf_Registered(t *testing.T) {
withModule(t, registerNothing, nil, false, func(m *Module) {
withModule(t, registerNothing, nil, nil, false, func(m *Module) {
assert.NotNil(t, m)
makeRequest(m, "GET", "/debug/pprof", nil, func(r *http.Response) {
assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 from pprof handler")
Expand All @@ -119,7 +131,7 @@ func TestHookupOptions(t *testing.T) {
modules.WithName("an optional name"),
}

withModule(t, registerNothing, options, false, func(m *Module) {
withModule(t, registerNothing, nil, options, false, func(m *Module) {
assert.NotNil(t, m)
})
}
Expand All @@ -131,7 +143,7 @@ func TestHookupOptions_Error(t *testing.T) {
},
}

withModule(t, registerNothing, options, true, func(m *Module) {
withModule(t, registerNothing, nil, options, true, func(m *Module) {
assert.Nil(t, m)
})
}
Expand All @@ -145,14 +157,15 @@ func configOption() service.Option {
func withModule(
t testing.TB,
hookup GetHandlersFunc,
filters []Filter,
options []modules.Option,
expectError bool,
fn func(*Module),
) {
mi := service.ModuleCreateInfo{
Host: service.NopHost(),
}
mod, err := newModule(mi, hookup, options...)
mod, err := newModule(mi, hookup, filters, options...)
if expectError {
require.Error(t, err, "Expected error instantiating module")
fn(nil)
Expand Down Expand Up @@ -252,3 +265,10 @@ func registerPanic(_ service.Host) []RouteHandler {
panic("Intentional panic for:" + r.URL.Path)
})
}

func fakeFilter() FilterFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, next Handler) {
io.WriteString(w, "filter is executed")
Wrap(service.NopHost(), next).ServeHTTP(w, r)
}
}

0 comments on commit 0268778

Please sign in to comment.