diff --git a/x/http_secure_redirect.go b/x/http_secure_redirect.go index d4048476d6f..a318f7ab6e0 100644 --- a/x/http_secure_redirect.go +++ b/x/http_secure_redirect.go @@ -131,25 +131,20 @@ func SecureRedirectTo(r *http.Request, defaultReturnTo *url.URL, opts ...SecureR returnTo.Host = stringsx.Coalesce(returnTo.Host, o.defaultReturnTo.Host) returnTo.Scheme = stringsx.Coalesce(returnTo.Scheme, o.defaultReturnTo.Scheme) - var found bool for _, allowed := range o.allowlist { if strings.EqualFold(allowed.Scheme, returnTo.Scheme) && SecureRedirectToIsAllowedHost(returnTo, allowed) && strings.HasPrefix( stringsx.Coalesce(returnTo.Path, "/"), stringsx.Coalesce(allowed.Path, "/")) { - found = true + return returnTo, nil } } - if !found { - return nil, errors.WithStack(herodot.ErrBadRequest. - WithID(text.ErrIDRedirectURLNotAllowed). - WithReasonf("Requested return_to URL \"%s\" is not allowed.", returnTo). - WithDebugf("Allowed domains are: %v", o.allowlist)) - } - - return returnTo, nil + return nil, errors.WithStack(herodot.ErrBadRequest. + WithID(text.ErrIDRedirectURLNotAllowed). + WithReasonf("Requested return_to URL \"%s\" is not allowed.", returnTo). + WithDebugf("Allowed domains are: %v", o.allowlist)) } func SecureContentNegotiationRedirection( diff --git a/x/http_secure_redirect_test.go b/x/http_secure_redirect_test.go index 728fa982b88..f94c474e244 100644 --- a/x/http_secure_redirect_test.go +++ b/x/http_secure_redirect_test.go @@ -126,8 +126,7 @@ func TestTakeOverReturnToParameter(t *testing.T) { } func TestSecureRedirectTo(t *testing.T) { - - var newServer = func(t *testing.T, isTLS bool, isRelative bool, expectErr bool, opts func(ts *httptest.Server) []x.SecureRedirectOption) *httptest.Server { + newServer := func(t *testing.T, isTLS bool, isRelative bool, expectErr bool, opts func(ts *httptest.Server) []x.SecureRedirectOption) *httptest.Server { var ts *httptest.Server f := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if opts == nil { @@ -160,7 +159,7 @@ func TestSecureRedirectTo(t *testing.T) { return ts } - var makeRequest = func(t *testing.T, ts *httptest.Server, path string) (*http.Response, string) { + makeRequest := func(t *testing.T, ts *httptest.Server, path string) (*http.Response, string) { res, err := ts.Client().Get(ts.URL + "/" + path) require.NoError(t, err) @@ -266,4 +265,39 @@ func TestSecureRedirectTo(t *testing.T) { _, body := makeRequest(t, s, "?return_to=/original") assert.Equal(t, body, s.URL+"/override") }) + + t.Run("case=should work with subdomain wildcard", func(t *testing.T) { + s := newServer(t, false, false, false, func(ts *httptest.Server) []x.SecureRedirectOption { + return []x.SecureRedirectOption{x.SecureRedirectAllowURLs([]url.URL{*urlx.ParseOrPanic("https://*.ory.sh/")})} + }) + _, body := makeRequest(t, s, "?return_to=https://www.ory.sh/kratos") + assert.Equal(t, body, "https://www.ory.sh/kratos") + _, body = makeRequest(t, s, "?return_to=https://even.deeper.nested.ory.sh/kratos") + assert.Equal(t, body, "https://even.deeper.nested.ory.sh/kratos") + }) + + t.Run("case=should fallback to default return_to scheme", func(t *testing.T) { + s := newServer(t, false, false, false, func(ts *httptest.Server) []x.SecureRedirectOption { + return []x.SecureRedirectOption{ + x.SecureRedirectAllowURLs([]url.URL{*urlx.ParseOrPanic("https://www.ory.sh")}), + x.SecureRedirectOverrideDefaultReturnTo(urlx.ParseOrPanic("https://www.ory.sh/docs")), + } + }) + _, body := makeRequest(t, s, "?return_to=//www.ory.sh/kratos") + assert.Equal(t, body, "https://www.ory.sh/kratos") + }) + + t.Run("case=should fallback to default return_to host", func(t *testing.T) { + s := newServer(t, false, false, false, func(ts *httptest.Server) []x.SecureRedirectOption { + return []x.SecureRedirectOption{ + x.SecureRedirectAllowURLs([]url.URL{ + *urlx.ParseOrPanic("https://www.ory.sh"), + *urlx.ParseOrPanic("http://www.ory.sh"), + }), + x.SecureRedirectOverrideDefaultReturnTo(urlx.ParseOrPanic("https://www.ory.sh/docs")), + } + }) + _, body := makeRequest(t, s, "?return_to=http:///kratos") + assert.Equal(t, body, "http://www.ory.sh/kratos") + }) }