Skip to content

Commit

Permalink
fix: respect the after recovery return to URL from config (#3141)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas committed Mar 9, 2023
1 parent 90977ca commit 3467fd3
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 76 deletions.
4 changes: 2 additions & 2 deletions driver/config/config.go
Expand Up @@ -1192,8 +1192,8 @@ func (p *Config) SelfServiceFlowVerificationAfterHooks(ctx context.Context, stra
return p.selfServiceHooks(ctx, HookStrategyKey(ViperKeySelfServiceVerificationAfter, strategy))
}

func (p *Config) SelfServiceFlowRecoveryReturnTo(ctx context.Context) *url.URL {
return p.GetProvider(ctx).RequestURIF(ViperKeySelfServiceRecoveryBrowserDefaultReturnTo, p.SelfServiceBrowserDefaultReturnTo(ctx))
func (p *Config) SelfServiceFlowRecoveryReturnTo(ctx context.Context, defaultReturnTo *url.URL) *url.URL {
return p.GetProvider(ctx).RequestURIF(ViperKeySelfServiceRecoveryBrowserDefaultReturnTo, defaultReturnTo)
}

func (p *Config) SelfServiceFlowRecoveryRequestLifespan(ctx context.Context) time.Duration {
Expand Down
4 changes: 2 additions & 2 deletions driver/config/config_test.go
Expand Up @@ -579,10 +579,10 @@ func TestViperProvider_ReturnTo(t *testing.T) {

p.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh/")
assert.Equal(t, "https://www.ory.sh/", p.SelfServiceFlowVerificationReturnTo(ctx, urlx.ParseOrPanic("https://www.ory.sh/")).String())
assert.Equal(t, "https://www.ory.sh/", p.SelfServiceFlowRecoveryReturnTo(ctx).String())
assert.Equal(t, "https://www.ory.sh/", p.SelfServiceFlowRecoveryReturnTo(ctx, urlx.ParseOrPanic("https://www.ory.sh/")).String())

p.MustSet(ctx, config.ViperKeySelfServiceRecoveryBrowserDefaultReturnTo, "https://www.ory.sh/recovery")
assert.Equal(t, "https://www.ory.sh/recovery", p.SelfServiceFlowRecoveryReturnTo(ctx).String())
assert.Equal(t, "https://www.ory.sh/recovery", p.SelfServiceFlowRecoveryReturnTo(ctx, urlx.ParseOrPanic("https://www.ory.sh/")).String())

p.MustSet(ctx, config.ViperKeySelfServiceVerificationBrowserDefaultReturnTo, "https://www.ory.sh/verification")
assert.Equal(t, "https://www.ory.sh/verification", p.SelfServiceFlowVerificationReturnTo(ctx, urlx.ParseOrPanic("https://www.ory.sh/")).String())
Expand Down
8 changes: 6 additions & 2 deletions internal/testhelpers/fake.go
Expand Up @@ -3,8 +3,12 @@

package testhelpers

import "github.com/ory/x/randx"
import (
"strings"

"github.com/ory/x/randx"
)

func RandomEmail() string {
return randx.MustString(16, randx.Alpha) + "@ory.sh"
return strings.ToLower(randx.MustString(16, randx.Alpha) + "@ory.sh")
}
18 changes: 6 additions & 12 deletions selfservice/strategy/code/strategy_recovery.go
Expand Up @@ -380,22 +380,16 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request,
return s.retryRecoveryFlowWithError(w, r, f.Type, err)
}

// Carry `return_to` parameter over from recovery flow
sfRequestURL, err := url.Parse(sf.RequestURL)
if err != nil {
return s.retryRecoveryFlowWithError(w, r, f.Type, err)
returnToURL := s.deps.Config().SelfServiceFlowRecoveryReturnTo(r.Context(), nil)
returnTo := ""
if returnToURL != nil {
returnTo = returnToURL.String()
}

fRequestURL, err := url.Parse(f.RequestURL)
sf.RequestURL, err = x.TakeOverReturnToParameter(f.RequestURL, sf.RequestURL, returnTo)
if err != nil {
return s.retryRecoveryFlowWithError(w, r, f.Type, err)
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
}

sfQuery := sfRequestURL.Query()
sfQuery.Set("return_to", fRequestURL.Query().Get("return_to"))
sfRequestURL.RawQuery = sfQuery.Encode()
sf.RequestURL = sfRequestURL.String()

if err := s.deps.RecoveryExecutor().PostRecoveryHook(w, r, f, sess); err != nil {
return s.retryRecoveryFlowWithError(w, r, f.Type, err)
}
Expand Down
110 changes: 76 additions & 34 deletions selfservice/strategy/code/strategy_recovery_test.go
Expand Up @@ -12,11 +12,11 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/davecgh/go-spew/spew"
"github.com/gofrs/uuid"
errors "github.com/pkg/errors"

"github.com/ory/kratos/driver"
Expand Down Expand Up @@ -202,7 +202,7 @@ func TestAdminStrategy(t *testing.T) {
})

t.Run("case=should not be able to use code from different flow", func(t *testing.T) {
email := strings.ToLower(testhelpers.RandomEmail())
email := testhelpers.RandomEmail()
i := createIdentityToRecover(t, reg, email)

c1, _, err := createCode(i.ID.String(), pointerx.String("1h"))
Expand All @@ -218,7 +218,7 @@ func TestAdminStrategy(t *testing.T) {
})

t.Run("case=form should not contain email field when creating recovery code", func(t *testing.T) {
email := strings.ToLower(testhelpers.RandomEmail())
email := testhelpers.RandomEmail()
i := createIdentityToRecover(t, reg, email)

c1, _, err := createCode(i.ID.String(), pointerx.String("1h"))
Expand Down Expand Up @@ -401,7 +401,7 @@ func TestRecovery(t *testing.T) {
}

t.Run("description=should recover an account", func(t *testing.T) {
var checkRecovery = func(t *testing.T, client *http.Client, flowType, recoveryEmail, recoverySubmissionResponse, returnTo string) string {
var checkRecovery = func(t *testing.T, client *http.Client, flowType, recoveryEmail, recoverySubmissionResponse string) string {

ExpectVerfiableAddressStatus(t, recoveryEmail, identity.VerifiableAddressStatusPending)

Expand All @@ -427,7 +427,7 @@ func TestRecovery(t *testing.T) {
recoverySubmissionResponse := submitRecovery(t, client, RecoveryFlowTypeBrowser, func(v url.Values) {
v.Set("email", email)
}, http.StatusOK)
body := checkRecovery(t, client, RecoveryFlowTypeBrowser, email, recoverySubmissionResponse, "")
body := checkRecovery(t, client, RecoveryFlowTypeBrowser, email, recoverySubmissionResponse)

assert.Equal(t, text.NewRecoverySuccessful(time.Now().Add(time.Hour)).Text,
gjson.Get(body, "ui.messages.0.text").String())
Expand All @@ -447,7 +447,7 @@ func TestRecovery(t *testing.T) {
recoverySubmissionResponse := submitRecovery(t, client, RecoveryFlowTypeSPA, func(v url.Values) {
v.Set("email", email)
}, http.StatusOK)
body := checkRecovery(t, client, RecoveryFlowTypeSPA, email, recoverySubmissionResponse, "")
body := checkRecovery(t, client, RecoveryFlowTypeSPA, email, recoverySubmissionResponse)
assert.Equal(t, "browser_location_change_required", gjson.Get(body, "error.id").String())
assert.Contains(t, gjson.Get(body, "redirect_browser_to").String(), "settings-ts?")
})
Expand All @@ -459,40 +459,82 @@ func TestRecovery(t *testing.T) {
recoverySubmissionResponse := submitRecovery(t, client, RecoveryFlowTypeAPI, func(v url.Values) {
v.Set("email", email)
}, http.StatusOK)
body := checkRecovery(t, client, RecoveryFlowTypeAPI, email, recoverySubmissionResponse, "")
body := checkRecovery(t, client, RecoveryFlowTypeAPI, email, recoverySubmissionResponse)
assert.Equal(t, "browser_location_change_required", gjson.Get(body, "error.id").String())
assert.Contains(t, gjson.Get(body, "redirect_browser_to").String(), "settings-ts?")
})

t.Run("description=should return browser to return url", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
email := "recoverme@ory.sh"
returnTo := "https://www.ory.sh"
createIdentityToRecover(t, reg, email)
returnTo := public.URL + "/return-to"
conf.Set(ctx, config.ViperKeyURLsAllowedReturnToDomains, []string{returnTo})
for _, tc := range []struct {
desc string
returnTo string
f func(t *testing.T, client *http.Client) *kratos.RecoveryFlow
}{
{
desc: "should use return_to from recovery flow",
returnTo: returnTo,
f: func(t *testing.T, client *http.Client) *kratos.RecoveryFlow {
return testhelpers.InitializeRecoveryFlowViaBrowser(t, client, false, public, url.Values{"return_to": []string{returnTo}})
},
},
{
desc: "should use return_to from config",
returnTo: returnTo,
f: func(t *testing.T, client *http.Client) *kratos.RecoveryFlow {
conf.Set(ctx, config.ViperKeySelfServiceRecoveryBrowserDefaultReturnTo, returnTo)
t.Cleanup(func() {
conf.Set(ctx, config.ViperKeySelfServiceRecoveryBrowserDefaultReturnTo, "")
})
return testhelpers.InitializeRecoveryFlowViaBrowser(t, client, false, public, nil)
},
},
{
desc: "no return to",
returnTo: "",
f: func(t *testing.T, client *http.Client) *kratos.RecoveryFlow {
return testhelpers.InitializeRecoveryFlowViaBrowser(t, client, false, public, nil)
},
},
} {
t.Run(fmt.Sprintf("%s", tc.desc), func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
email := testhelpers.RandomEmail()
createIdentityToRecover(t, reg, email)

client.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper
client.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper

f := testhelpers.InitializeRecoveryFlowViaBrowser(t, client, false, public, url.Values{"return_to": []string{returnTo}})
f := tc.f(t, client)

formPayload := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes)
formPayload.Set("email", email)
formPayload := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes)
formPayload.Set("email", email)

body, res := testhelpers.RecoveryMakeRequest(t, false, f, client, formPayload.Encode())
assert.EqualValues(t, http.StatusOK, res.StatusCode, "%s", body)
expectedURL := testhelpers.ExpectURL(false, public.URL+recovery.RouteSubmitFlow, conf.SelfServiceFlowRecoveryUI(ctx).String())
assert.Contains(t, res.Request.URL.String(), expectedURL, "%+v\n\t%s", res.Request, body)
body, res := testhelpers.RecoveryMakeRequest(t, false, f, client, formPayload.Encode())
assert.EqualValues(t, http.StatusOK, res.StatusCode, "%s", body)
expectedURL := testhelpers.ExpectURL(false, public.URL+recovery.RouteSubmitFlow, conf.SelfServiceFlowRecoveryUI(ctx).String())
assert.Contains(t, res.Request.URL.String(), expectedURL, "%+v\n\t%s", res.Request, body)

body = checkRecovery(t, client, RecoveryFlowTypeBrowser, email, body, returnTo)
body = checkRecovery(t, client, RecoveryFlowTypeBrowser, email, body)

assert.Equal(t, text.NewRecoverySuccessful(time.Now().Add(time.Hour)).Text,
gjson.Get(body, "ui.messages.0.text").String())
assert.Equal(t, text.NewRecoverySuccessful(time.Now().Add(time.Hour)).Text,
gjson.Get(body, "ui.messages.0.text").String())

res, err := client.Get(public.URL + session.RouteWhoami)
require.NoError(t, err)
body = string(x.MustReadAll(res.Body))
require.NoError(t, res.Body.Close())
assert.Equal(t, "code_recovery", gjson.Get(body, "authentication_methods.0.method").String(), "%s", body)
assert.Equal(t, "aal1", gjson.Get(body, "authenticator_assurance_level").String(), "%s", body)
settingsId := gjson.Get(body, "id").String()

sf, err := reg.SettingsFlowPersister().GetSettingsFlow(ctx, uuid.Must(uuid.FromString(settingsId)))
require.NoError(t, err)

require.Equal(t, tc.returnTo, sf.ReturnTo)

res, err = client.Get(public.URL + session.RouteWhoami)
require.NoError(t, err)
body = string(x.MustReadAll(res.Body))
require.NoError(t, res.Body.Close())
assert.Equal(t, "code_recovery", gjson.Get(body, "authentication_methods.0.method").String(), "%s", body)
assert.Equal(t, "aal1", gjson.Get(body, "authenticator_assurance_level").String(), "%s", body)
})
}
})
})

Expand Down Expand Up @@ -671,7 +713,7 @@ func TestRecovery(t *testing.T) {
conf.MustSet(ctx, config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypePassword.String()), nil)
})

email := strings.ToLower(testhelpers.RandomEmail())
email := testhelpers.RandomEmail()
id := createIdentityToRecover(t, reg, email)

req := httptest.NewRequest("GET", "/sessions/whoami", nil)
Expand Down Expand Up @@ -711,7 +753,7 @@ func TestRecovery(t *testing.T) {
})

t.Run("description=should not be able to use an invalid code more than 5 times", func(t *testing.T) {
email := strings.ToLower(testhelpers.RandomEmail())
email := testhelpers.RandomEmail()
createIdentityToRecover(t, reg, email)
c := testhelpers.NewClientWithCookies(t)
body := submitRecovery(t, c, RecoveryFlowTypeBrowser, func(v url.Values) {
Expand Down Expand Up @@ -741,7 +783,7 @@ func TestRecovery(t *testing.T) {
for _, testCase := range flowTypeCases {
t.Run("type="+testCase.FlowType, func(t *testing.T) {
c := testCase.GetClient(t)
recoveryEmail := strings.ToLower(testhelpers.RandomEmail())
recoveryEmail := testhelpers.RandomEmail()
_ = createIdentityToRecover(t, reg, recoveryEmail)

actual := submitRecovery(t, c, testCase.FlowType, func(v url.Values) {
Expand Down Expand Up @@ -898,7 +940,7 @@ func TestRecovery(t *testing.T) {
})

t.Run("description=should be able to re-send the recovery code", func(t *testing.T) {
recoveryEmail := strings.ToLower(testhelpers.RandomEmail())
recoveryEmail := testhelpers.RandomEmail()
createIdentityToRecover(t, reg, recoveryEmail)

c := testhelpers.NewClientWithCookies(t)
Expand All @@ -921,7 +963,7 @@ func TestRecovery(t *testing.T) {
})

t.Run("description=should not be able to use first code after re-sending email", func(t *testing.T) {
recoveryEmail := strings.ToLower(testhelpers.RandomEmail())
recoveryEmail := testhelpers.RandomEmail()
createIdentityToRecover(t, reg, recoveryEmail)

c := testhelpers.NewClientWithCookies(t)
Expand Down Expand Up @@ -952,7 +994,7 @@ func TestRecovery(t *testing.T) {
})

t.Run("description=should not show outdated validation message if newer message appears #2799", func(t *testing.T) {
recoveryEmail := strings.ToLower(testhelpers.RandomEmail())
recoveryEmail := testhelpers.RandomEmail()
createIdentityToRecover(t, reg, recoveryEmail)

c := testhelpers.NewClientWithCookies(t)
Expand Down
8 changes: 7 additions & 1 deletion selfservice/strategy/link/strategy_recovery.go
Expand Up @@ -303,7 +303,13 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request,
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
}

sf.RequestURL, err = x.TakeOverReturnToParameter(f.RequestURL, sf.RequestURL)
returnToURL := s.d.Config().SelfServiceFlowRecoveryReturnTo(r.Context(), nil)
returnTo := ""
if returnToURL != nil {
returnTo = returnToURL.String()
}

sf.RequestURL, err = x.TakeOverReturnToParameter(f.RequestURL, sf.RequestURL, returnTo)
if err != nil {
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
}
Expand Down
77 changes: 56 additions & 21 deletions selfservice/strategy/link/strategy_recovery_test.go
Expand Up @@ -558,27 +558,62 @@ func TestRecovery(t *testing.T) {
}), email, "")
})

t.Run("type=browser set return_to", func(t *testing.T) {
email := "recoverme2@ory.sh"
returnTo := "https://www.ory.sh"
createIdentityToRecover(t, reg, email)

hc := testhelpers.NewClientWithCookies(t)
hc.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper

f := testhelpers.InitializeRecoveryFlowViaBrowser(t, hc, false, public, url.Values{"return_to": []string{returnTo}})

time.Sleep(time.Millisecond) // add a bit of delay to allow `1ns` to time out.

formPayload := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes)
formPayload.Set("email", email)

b, res := testhelpers.RecoveryMakeRequest(t, false, f, hc, testhelpers.EncodeFormAsJSON(t, false, formPayload))
assert.EqualValues(t, http.StatusOK, res.StatusCode, "%s", b)
expectedURL := testhelpers.ExpectURL(false, public.URL+recovery.RouteSubmitFlow, conf.SelfServiceFlowRecoveryUI(ctx).String())
assert.Contains(t, res.Request.URL.String(), expectedURL, "%+v\n\t%s", res.Request, b)

check(t, b, email, returnTo)
t.Run("description=should return browser to return url", func(t *testing.T) {
returnTo := public.URL + "/return-to"
conf.Set(ctx, config.ViperKeyURLsAllowedReturnToDomains, []string{returnTo})
for _, tc := range []struct {
desc string
returnTo string
f func(t *testing.T, client *http.Client) *kratos.RecoveryFlow
}{
{
desc: "should use return_to from recovery flow",
returnTo: returnTo,
f: func(t *testing.T, client *http.Client) *kratos.RecoveryFlow {
return testhelpers.InitializeRecoveryFlowViaBrowser(t, client, false, public, url.Values{"return_to": []string{returnTo}})
},
},
{
desc: "should use return_to from config",
returnTo: returnTo,
f: func(t *testing.T, client *http.Client) *kratos.RecoveryFlow {
conf.Set(ctx, config.ViperKeySelfServiceRecoveryBrowserDefaultReturnTo, returnTo)
t.Cleanup(func() {
conf.Set(ctx, config.ViperKeySelfServiceRecoveryBrowserDefaultReturnTo, "")
})
return testhelpers.InitializeRecoveryFlowViaBrowser(t, client, false, public, nil)
},
},
{
desc: "no return to",
returnTo: "",
f: func(t *testing.T, client *http.Client) *kratos.RecoveryFlow {
return testhelpers.InitializeRecoveryFlowViaBrowser(t, client, false, public, nil)
},
},
} {
t.Run(fmt.Sprintf("%s", tc.desc), func(t *testing.T) {
email := testhelpers.RandomEmail()
createIdentityToRecover(t, reg, email)

hc := testhelpers.NewClientWithCookies(t)
hc.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper

f := tc.f(t, hc)

time.Sleep(time.Millisecond) // add a bit of delay to allow `1ns` to time out.

formPayload := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes)
formPayload.Set("email", email)

b, res := testhelpers.RecoveryMakeRequest(t, false, f, hc, testhelpers.EncodeFormAsJSON(t, false, formPayload))
assert.EqualValues(t, http.StatusOK, res.StatusCode, "%s", b)
expectedURL := testhelpers.ExpectURL(false, public.URL+recovery.RouteSubmitFlow, conf.SelfServiceFlowRecoveryUI(ctx).String())
assert.Contains(t, res.Request.URL.String(), expectedURL, "%+v\n\t%s", res.Request, b)

check(t, b, email, tc.returnTo)
})
}
})

t.Run("type=spa", func(t *testing.T) {
Expand Down

0 comments on commit 3467fd3

Please sign in to comment.