diff --git a/pkg/s3-proxy/server/hostrouter.go b/pkg/s3-proxy/server/hostrouter.go index 7c99687e..b947ffee 100644 --- a/pkg/s3-proxy/server/hostrouter.go +++ b/pkg/s3-proxy/server/hostrouter.go @@ -12,8 +12,10 @@ import ( // Fork dead project https://github.com/go-chi/hostrouter/ // Add wildcard support, not found handler and internal server handler // Remove not necessary parts +// Update to ensure that all wildcard domains will be tested in the injection order type HostRouter struct { + domainList []string routes map[string]chi.Router notFoundHandler http.HandlerFunc internalServerHandler func(err error) http.HandlerFunc @@ -21,18 +23,21 @@ type HostRouter struct { func NewHostRouter(notFoundHandler http.HandlerFunc, internalServerHandler func(err error) http.HandlerFunc) HostRouter { return HostRouter{ + domainList: []string{}, routes: map[string]chi.Router{}, notFoundHandler: notFoundHandler, internalServerHandler: internalServerHandler, } } -func (hr HostRouter) Get(domain string) chi.Router { - return hr.routes[domain] +func (hr *HostRouter) Get(domain string) chi.Router { + return hr.routes[strings.ToLower(domain)] } -func (hr HostRouter) Map(host string, h chi.Router) { - hr.routes[strings.ToLower(host)] = h +func (hr *HostRouter) Map(host string, h chi.Router) { + lowercaseHost := strings.ToLower(host) + hr.domainList = append(hr.domainList, lowercaseHost) + hr.routes[lowercaseHost] = h } func (hr HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -64,8 +69,8 @@ func (hr HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { hr.notFoundHandler(w, r) } -func (hr HostRouter) getRouterWithWildcard(host string) (chi.Router, error) { - for wh, rt := range hr.routes { +func (hr *HostRouter) getRouterWithWildcard(host string) (chi.Router, error) { + for _, wh := range hr.domainList { g, err := glob.Compile(wh) // Check if error exists if err != nil { @@ -73,7 +78,7 @@ func (hr HostRouter) getRouterWithWildcard(host string) (chi.Router, error) { } // Check if wildcard host match current host if g.Match(host) { - return rt, nil + return hr.routes[wh], nil } } diff --git a/pkg/s3-proxy/server/hostrouter_test.go b/pkg/s3-proxy/server/hostrouter_test.go index 6382f040..6905661c 100644 --- a/pkg/s3-proxy/server/hostrouter_test.go +++ b/pkg/s3-proxy/server/hostrouter_test.go @@ -28,20 +28,24 @@ func TestHostRouter_ServeHTTP(t *testing.T) { w.Write([]byte("localhost")) }) + type routeInput struct { + domain string + router chi.Router + } tests := []struct { name string inputURL string - routes map[string]chi.Router + routes []*routeInput expectedStatus int expectedBody string }{ { name: "should match the star glob", inputURL: "http://fake/", - routes: map[string]chi.Router{ - "localhost": localhostRouter, - "*.localhost": starLocalhostRouter, - "*": starRouter, + routes: []*routeInput{ + {"localhost", localhostRouter}, + {"*.localhost", starLocalhostRouter}, + {"*", starRouter}, }, expectedStatus: 200, expectedBody: "star", @@ -49,10 +53,10 @@ func TestHostRouter_ServeHTTP(t *testing.T) { { name: "should match the perfect host", inputURL: "http://localhost/", - routes: map[string]chi.Router{ - "localhost": localhostRouter, - "*.localhost": starLocalhostRouter, - "*": starRouter, + routes: []*routeInput{ + {"localhost", localhostRouter}, + {"*.localhost", starLocalhostRouter}, + {"*", starRouter}, }, expectedStatus: 200, expectedBody: "localhost", @@ -60,10 +64,10 @@ func TestHostRouter_ServeHTTP(t *testing.T) { { name: "should match the glob host", inputURL: "http://api.localhost/", - routes: map[string]chi.Router{ - "localhost": localhostRouter, - "*.localhost": starLocalhostRouter, - "*": starRouter, + routes: []*routeInput{ + {"localhost", localhostRouter}, + {"*.localhost", starLocalhostRouter}, + {"*", starRouter}, }, expectedStatus: 200, expectedBody: "starLocalhost", @@ -71,10 +75,10 @@ func TestHostRouter_ServeHTTP(t *testing.T) { { name: "should match the glob host (2)", inputURL: "http://ui.localhost/", - routes: map[string]chi.Router{ - "localhost": localhostRouter, - "*.localhost": starLocalhostRouter, - "*": starRouter, + routes: []*routeInput{ + {"localhost", localhostRouter}, + {"*.localhost", starLocalhostRouter}, + {"*", starRouter}, }, expectedStatus: 200, expectedBody: "starLocalhost", @@ -82,8 +86,8 @@ func TestHostRouter_ServeHTTP(t *testing.T) { { name: "should return a not found error", inputURL: "http://ui.localhost/", - routes: map[string]chi.Router{ - "localhost": localhostRouter, + routes: []*routeInput{ + {"localhost", localhostRouter}, }, expectedStatus: 404, expectedBody: "hostrouter not found", @@ -91,18 +95,21 @@ func TestHostRouter_ServeHTTP(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hr := HostRouter{ - routes: tt.routes, - notFoundHandler: func(w http.ResponseWriter, r *http.Request) { + hr := NewHostRouter( + func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) w.Write([]byte("hostrouter not found")) }, - internalServerHandler: func(err error) http.HandlerFunc { + func(err error) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(500) w.Write([]byte("hostrouter internal server error")) } }, + ) + + for _, it := range tt.routes { + hr.Map(it.domain, it.router) } w := httptest.NewRecorder()