diff --git a/filterconfig.go b/filterconfig.go index d6ca42694..7b72078ee 100644 --- a/filterconfig.go +++ b/filterconfig.go @@ -1,6 +1,9 @@ package revel -import "reflect" +import ( + "reflect" + "strings" +) // Map from "Controller" or "Controller.Method" to the Filter chain var filterOverrides = make(map[string][]Filter) @@ -35,6 +38,14 @@ var filterOverrides = make(map[string][]Filter) // // => RouterFilter, FilterConfiguringFilter, OtherFilter, SessionFilter, ActionInvoker // +// Filter modifications may be combined between Controller and Action. For example: +// FilterController(App{}). +// Add(Filter1) +// FilterAction(App.Action). +// Add(Filter2) +// +// .. would result in App.Action being filtered by both Filter1 and Filter2. +// // Note: the last filter stage is not subject to the configurator. In // particular, Add() adds a filter to the second-to-last place. type FilterConfigurator struct { @@ -89,21 +100,31 @@ func FilterAction(methodRef interface{}) FilterConfigurator { // Add the given filter in the second-to-last position in the filter chain. // (Second-to-last so that it is before ActionInvoker) func (conf FilterConfigurator) Add(f Filter) FilterConfigurator { - fc := conf.getOverrideFilters() - filterOverrides[conf.key] = append(fc[:len(fc)-1], f, fc[len(fc)-1]) + conf.apply(func(fc []Filter) []Filter { + return conf.addFilter(f, fc) + }) return conf } +func (conf FilterConfigurator) addFilter(f Filter, fc []Filter) []Filter { + return append(fc[:len(fc)-1], f, fc[len(fc)-1]) +} + // Remove a filter from the filter chain. func (conf FilterConfigurator) Remove(target Filter) FilterConfigurator { - filters := conf.getOverrideFilters() - for i, f := range filters { + conf.apply(func(fc []Filter) []Filter { + return conf.rmFilter(target, fc) + }) + return conf +} + +func (conf FilterConfigurator) rmFilter(target Filter, fc []Filter) []Filter { + for i, f := range fc { if FilterEq(f, target) { - filterOverrides[conf.key] = append(filters[:i], filters[i+1:]...) - return conf + return append(fc[:i], fc[i+1:]...) } } - panic("Did not find target filter to remove") + return fc } // Insert a filter into the filter chain before or after another. @@ -115,44 +136,63 @@ func (conf FilterConfigurator) Insert(insert Filter, where When, target Filter) if where != BEFORE && where != AFTER { panic("where must be BEFORE or AFTER") } - filters := conf.getOverrideFilters() - for i, f := range filters { + conf.apply(func(fc []Filter) []Filter { + return conf.insertFilter(insert, where, target, fc) + }) + return conf +} + +func (conf FilterConfigurator) insertFilter(insert Filter, where When, target Filter, fc []Filter) []Filter { + for i, f := range fc { if FilterEq(f, target) { - filterOverrides[conf.key] = append(filters[:i], append([]Filter{insert}, filters[i:]...)...) - return conf + if where == BEFORE { + return append(fc[:i], append([]Filter{insert}, fc[i:]...)...) + } else { + return append(fc[:i+1], append([]Filter{insert}, fc[i+1:]...)...) + } } } - panic("Did not find target filter for insert") + return fc } -// getOverrideFilters returns the filter chain that applies to the given -// controller or action. If no overrides are configured, then a copy of the -// default filter chain is returned. -func (conf FilterConfigurator) getOverrideFilters() []Filter { - var ( - filters []Filter - ok bool - ) - filters, ok = filterOverrides[conf.key] - if !ok { - filters, ok = filterOverrides[conf.controllerName] - if !ok { - // The override starts with all filters after FilterConfiguringFilter - for i, f := range Filters { - if FilterEq(f, FilterConfiguringFilter) { - filters = make([]Filter, len(Filters)-i-1) - copy(filters, Filters[i+1:]) - break - } - } - if filters == nil { - panic("FilterConfiguringFilter not found in revel.Filters.") +// getChain returns the filter chain that applies to the given controller or +// action. If no overrides are configured, then a copy of the default filter +// chain is returned. +func (conf FilterConfigurator) getChain() []Filter { + var filters []Filter + if filters = getOverrideChain(conf.controllerName, conf.key); filters == nil { + // The override starts with all filters after FilterConfiguringFilter + for i, f := range Filters { + if FilterEq(f, FilterConfiguringFilter) { + filters = make([]Filter, len(Filters)-i-1) + copy(filters, Filters[i+1:]) + break } } + if filters == nil { + panic("FilterConfiguringFilter not found in revel.Filters.") + } } return filters } +// apply applies the given functional change to the filter overrides. +// No other function modifies the filterOverrides map. +func (conf FilterConfigurator) apply(f func([]Filter) []Filter) { + // Updates any actions that have had their filters overridden, if this is a + // Controller configurator. + if conf.controllerName == conf.key { + for k, v := range filterOverrides { + if strings.HasPrefix(k, conf.controllerName+".") { + filterOverrides[k] = f(v) + } + } + } + + // Update the Controller or Action overrides. + filterOverrides[conf.key] = f(conf.getChain()) +} + // FilterEq returns true if the two filters reference the same filter. func FilterEq(a, b Filter) bool { return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer() @@ -161,15 +201,20 @@ func FilterEq(a, b Filter) bool { // FilterConfiguringFilter is a filter stage that customizes the remaining // filter chain for the action being invoked. var FilterConfiguringFilter = func(c *Controller, fc []Filter) { - if newChain, ok := filterOverrides[c.Name+"."+c.Action]; ok { + if newChain := getOverrideChain(c.Name, c.Action); newChain != nil { newChain[0](c, newChain[1:]) return } + fc[0](c, fc[1:]) +} - if newChain, ok := filterOverrides[c.Name]; ok { - newChain[0](c, newChain[1:]) - return +// getOverrideChain retrieves the overrides for the action that is set +func getOverrideChain(controllerName, action string) []Filter { + if newChain, ok := filterOverrides[action]; ok { + return newChain } - - fc[0](c, fc[1:]) + if newChain, ok := filterOverrides[controllerName]; ok { + return newChain + } + return nil } diff --git a/filterconfig_test.go b/filterconfig_test.go index 1ee405c8a..5b47cf5b0 100644 --- a/filterconfig_test.go +++ b/filterconfig_test.go @@ -4,8 +4,8 @@ import "testing" type FakeController struct{} -func (c FakeController) FakeAction() {} -func (c *FakeController) FakeAction2() {} +func (c FakeController) Foo() {} +func (c *FakeController) Bar() {} func TestFilterConfiguratorKey(t *testing.T) { conf := FilterController(FakeController{}) @@ -18,18 +18,18 @@ func TestFilterConfiguratorKey(t *testing.T) { t.Errorf("Expected key 'FakeController', was %s", conf.key) } - conf = FilterAction(FakeController.FakeAction) - if conf.key != "FakeController.FakeAction" { - t.Errorf("Expected key 'FakeController.FakeAction', was %s", conf.key) + conf = FilterAction(FakeController.Foo) + if conf.key != "FakeController.Foo" { + t.Errorf("Expected key 'FakeController.Foo', was %s", conf.key) } - conf = FilterAction((*FakeController).FakeAction2) - if conf.key != "FakeController.FakeAction2" { - t.Errorf("Expected key 'FakeController.FakeAction2', was %s", conf.key) + conf = FilterAction((*FakeController).Bar) + if conf.key != "FakeController.Bar" { + t.Errorf("Expected key 'FakeController.Bar', was %s", conf.key) } } -func TestFilterConfiguratorOps(t *testing.T) { +func TestFilterConfigurator(t *testing.T) { // Filters is global state. Restore it after this test. oldFilters := make([]Filter, len(Filters)) copy(oldFilters, Filters) @@ -45,32 +45,81 @@ func TestFilterConfiguratorOps(t *testing.T) { ActionInvoker, } - // First, verify getOverrideFilters returns just the filters after - // FilterConfiguringFilter - conf := FilterAction(FakeController.FakeAction) + // Do one of each operation. + conf := FilterAction(FakeController.Foo). + Add(NilFilter). + Remove(FlashFilter). + Insert(ValidationFilter, BEFORE, NilFilter). + Insert(I18nFilter, AFTER, NilFilter) expected := []Filter{ + SessionFilter, + ValidationFilter, + NilFilter, + I18nFilter, + ActionInvoker, + } + actual := getOverride("Foo") + if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { + t.Errorf("Ops failed.\nActual: %#v\nExpect: %#v\nConf:%v", actual, expected, conf) + } + + // Action2 should be unchanged + if getOverride("Bar") != nil { + t.Errorf("Filtering Action should not affect Action2.") + } + + // Test that combining overrides on both the Controller and Action works. + FilterController(FakeController{}). + Add(PanicFilter) + expected = []Filter{ + SessionFilter, + ValidationFilter, + NilFilter, + I18nFilter, + PanicFilter, + ActionInvoker, + } + actual = getOverride("Foo") + if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { + t.Errorf("Expected PanicFilter added to Foo.\nActual: %#v\nExpect: %#v", actual, expected) + } + + expected = []Filter{ SessionFilter, FlashFilter, + PanicFilter, ActionInvoker, } - actual := conf.getOverrideFilters() + actual = getOverride("Bar") if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { - t.Errorf("getOverrideFilter failed.\nActual: %#v\nExpect: %#v", actual, expected) + t.Errorf("Expected PanicFilter added to Bar.\nActual: %#v\nExpect: %#v", actual, expected) } - // Now do one of each operation. - conf.Add(NilFilter). - Remove(FlashFilter). - Insert(ValidationFilter, BEFORE, NilFilter) + FilterAction((*FakeController).Bar). + Add(NilFilter) expected = []Filter{ SessionFilter, ValidationFilter, NilFilter, + I18nFilter, + PanicFilter, + ActionInvoker, + } + actual = getOverride("Foo") + if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { + t.Errorf("Expected no change to Foo.\nActual: %#v\nExpect: %#v", actual, expected) + } + + expected = []Filter{ + SessionFilter, + FlashFilter, + PanicFilter, + NilFilter, ActionInvoker, } - actual = filterOverrides[conf.key] + actual = getOverride("Bar") if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { - t.Errorf("Ops failed.\nActual: %#v\nExpect: %#v", actual, expected) + t.Errorf("Expected NilFilter added to Bar.\nActual: %#v\nExpect: %#v", actual, expected) } } @@ -82,3 +131,7 @@ func filterSliceEqual(a, e []Filter) bool { } return true } + +func getOverride(methodName string) []Filter { + return getOverrideChain("FakeController", "FakeController."+methodName) +}