diff --git a/.schema/api.swagger.json b/.schema/api.swagger.json index 3cfe676851e..440787ed5a4 100755 --- a/.schema/api.swagger.json +++ b/.schema/api.swagger.json @@ -1289,7 +1289,7 @@ { "type": "string", "description": "The Registration Flow ID\n\nThe value for this parameter comes from `flow` URL Query parameter sent to your\napplication (e.g. `/registration?flow=abcde`).", - "name": "flow", + "name": "id", "in": "query", "required": true } diff --git a/driver/configuration/provider.go b/driver/configuration/provider.go index fa40b3c89a3..f448ca3e5a6 100644 --- a/driver/configuration/provider.go +++ b/driver/configuration/provider.go @@ -78,7 +78,7 @@ type Provider interface { SelfServiceFlowLoginReturnTo(strategy string) *url.URL SelfServiceFlowLoginRequestLifespan() time.Duration - SelfServiceFlowRegisterUI() *url.URL + SelfServiceFlowRegistrationUI() *url.URL SelfServiceFlowRegistrationBeforeHooks() []SelfServiceHook SelfServiceFlowRegistrationAfterHooks(strategy string) []SelfServiceHook SelfServiceFlowRegistrationReturnTo(strategy string) *url.URL diff --git a/driver/configuration/provider_viper.go b/driver/configuration/provider_viper.go index aa54adf8ced..c0744985602 100644 --- a/driver/configuration/provider_viper.go +++ b/driver/configuration/provider_viper.go @@ -376,7 +376,7 @@ func (p *ViperProvider) SelfServiceFlowErrorURL() *url.URL { return mustParseURLFromViper(p.l, ViperKeySelfServiceErrorUI) } -func (p *ViperProvider) SelfServiceFlowRegisterUI() *url.URL { +func (p *ViperProvider) SelfServiceFlowRegistrationUI() *url.URL { return mustParseURLFromViper(p.l, ViperKeySelfServiceRegistrationUI) } diff --git a/driver/configuration/provider_viper_test.go b/driver/configuration/provider_viper_test.go index a84c57edb31..41f5e7d61c6 100644 --- a/driver/configuration/provider_viper_test.go +++ b/driver/configuration/provider_viper_test.go @@ -54,7 +54,7 @@ func TestViperProvider(t *testing.T) { t.Run("group=urls", func(t *testing.T) { assert.Equal(t, "http://test.kratos.ory.sh/login", p.SelfServiceFlowLoginUI().String()) assert.Equal(t, "http://test.kratos.ory.sh/settings", p.SelfServiceFlowSettingsUI().String()) - assert.Equal(t, "http://test.kratos.ory.sh/register", p.SelfServiceFlowRegisterUI().String()) + assert.Equal(t, "http://test.kratos.ory.sh/register", p.SelfServiceFlowRegistrationUI().String()) assert.Equal(t, "http://test.kratos.ory.sh/error", p.SelfServiceFlowErrorURL().String()) assert.Equal(t, "http://admin.kratos.ory.sh", p.SelfAdminURL().String()) @@ -400,7 +400,7 @@ func TestViperProvider_Defaults(t *testing.T) { p := configuration.NewViperProvider(logrusx.New("", ""), false) assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/login", p.SelfServiceFlowLoginUI().String()) assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/settings", p.SelfServiceFlowSettingsUI().String()) - assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/registration", p.SelfServiceFlowRegisterUI().String()) + assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/registration", p.SelfServiceFlowRegistrationUI().String()) assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/recovery", p.SelfServiceFlowRecoveryUI().String()) assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/verification", p.SelfServiceFlowVerificationUI().String()) }) diff --git a/internal/httpclient/client/common/get_self_service_registration_flow_parameters.go b/internal/httpclient/client/common/get_self_service_registration_flow_parameters.go index 1aa1dd305f9..039d3fca779 100644 --- a/internal/httpclient/client/common/get_self_service_registration_flow_parameters.go +++ b/internal/httpclient/client/common/get_self_service_registration_flow_parameters.go @@ -60,14 +60,14 @@ for the get self service registration flow operation typically these are written */ type GetSelfServiceRegistrationFlowParams struct { - /*Flow + /*ID The Registration Flow ID The value for this parameter comes from `flow` URL Query parameter sent to your application (e.g. `/registration?flow=abcde`). */ - Flow string + ID string timeout time.Duration Context context.Context @@ -107,15 +107,15 @@ func (o *GetSelfServiceRegistrationFlowParams) SetHTTPClient(client *http.Client o.HTTPClient = client } -// WithFlow adds the flow to the get self service registration flow params -func (o *GetSelfServiceRegistrationFlowParams) WithFlow(flow string) *GetSelfServiceRegistrationFlowParams { - o.SetFlow(flow) +// WithID adds the id to the get self service registration flow params +func (o *GetSelfServiceRegistrationFlowParams) WithID(id string) *GetSelfServiceRegistrationFlowParams { + o.SetID(id) return o } -// SetFlow adds the flow to the get self service registration flow params -func (o *GetSelfServiceRegistrationFlowParams) SetFlow(flow string) { - o.Flow = flow +// SetID adds the id to the get self service registration flow params +func (o *GetSelfServiceRegistrationFlowParams) SetID(id string) { + o.ID = id } // WriteToRequest writes these params to a swagger request @@ -126,11 +126,11 @@ func (o *GetSelfServiceRegistrationFlowParams) WriteToRequest(r runtime.ClientRe } var res []error - // query param flow - qrFlow := o.Flow - qFlow := qrFlow - if qFlow != "" { - if err := r.SetQueryParam("flow", qFlow); err != nil { + // query param id + qrID := o.ID + qID := qrID + if qID != "" { + if err := r.SetQueryParam("id", qID); err != nil { return err } } diff --git a/selfservice/flow/login/handler_test.go b/selfservice/flow/login/handler_test.go index 222bebce1d3..bc1064bf4be 100644 --- a/selfservice/flow/login/handler_test.go +++ b/selfservice/flow/login/handler_test.go @@ -176,7 +176,7 @@ func TestGetFlow(t *testing.T) { IssuedAt: time.Now().Add(-time.Minute * 2), RequestURL: public.URL + login.RouteInitBrowserFlow, CSRFToken: x.FakeCSRFToken, - Type: flow.TypeBrowser, + Type: flow.TypeBrowser, } } diff --git a/selfservice/flow/registration/error.go b/selfservice/flow/registration/error.go index 9e9efa50338..e06f4e06584 100644 --- a/selfservice/flow/registration/error.go +++ b/selfservice/flow/registration/error.go @@ -1,7 +1,6 @@ package registration import ( - "context" "net/http" "net/url" "time" @@ -69,67 +68,72 @@ func (s *ErrorHandler) WriteFlowError( w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, - rr *Flow, + f *Flow, err error, ) { s.d.Audit(). WithError(err). WithRequest(r). - WithField("registration_flow", rr). + WithField("registration_flow", f). Info("Encountered self-service flow error.") + if f == nil { + s.forward(w, r, nil, err) + return + } + if e := new(FlowExpiredError); errors.As(err, &e) { // create new flow because the old one is not valid - a, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r, rr.Type) + a, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r, f.Type) if err != nil { // failed to create a new session and redirect to it, handle that error as a new one - s.WriteFlowError(w, r, ct, rr, err) + s.WriteFlowError(w, r, ct, f, err) return } a.Messages.Add(text.NewErrorValidationRegistrationFlowExpired(e.ago)) - if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlow(context.TODO(), a); err != nil { - redirTo, err := s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err) - if err != nil { - s.WriteFlowError(w, r, ct, rr, err) - return - } - http.Redirect(w, r, redirTo, http.StatusFound) + if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlow(r.Context(), a); err != nil { + s.forward(w, r, a, err) return } - http.Redirect(w, r, urlx.CopyWithQuery(s.c.SelfServiceFlowRegisterUI(), url.Values{"request": {a.ID.String()}}).String(), http.StatusFound) + if f.Type == flow.TypeAPI { + http.Redirect(w, r, urlx.CopyWithQuery(urlx.AppendPaths(s.c.SelfPublicURL(), + RouteGetFlow), url.Values{"id": {a.ID.String()}}).String(), http.StatusFound) + } else { + http.Redirect(w, r, a.AppendTo(s.c.SelfServiceFlowRegistrationUI()).String(), http.StatusFound) + } return } - if rr == nil { - s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) - return - } else if x.IsJSONRequest(r) { - s.d.Writer().WriteError(w, r, err) + method, ok := f.Methods[ct] + if !ok { + s.forward(w, r, f, errors.WithStack(herodot.ErrInternalServerError. + WithErrorf(`Expected registration method "%s" to exist in flow. This is a bug in the code and should be reported on GitHub.`, ct))) return } - method, ok := rr.Methods[ct] - if !ok { - s.d.Writer().WriteError(w, r, errors.WithStack(herodot.ErrInternalServerError.WithDebugf("Methods: %+v", rr.Methods).WithErrorf(`Expected registration method "%s" to exist in request. This is a bug in the code and should be reported on GitHub.`, ct))) + if err := method.Config.ParseError(err); err != nil { + s.forward(w, r, f, err) return } - if err := method.Config.ParseError(err); err != nil { - s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) + if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlowMethod(r.Context(), f.ID, ct, method); err != nil { + s.forward(w, r, f, err) return } - if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlowMethod(r.Context(), rr.ID, ct, method); err != nil { - s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) + if f.Type == flow.TypeBrowser { + http.Redirect(w, r, f.AppendTo(s.c.SelfServiceFlowRegistrationUI()).String(), http.StatusFound) return } - http.Redirect(w, r, - urlx.CopyWithQuery(s.c.SelfServiceFlowRegisterUI(), url.Values{"request": {rr.ID.String()}}).String(), - http.StatusFound, - ) + innerRegistrationFlow, innerErr := s.d.RegistrationFlowPersister().GetRegistrationFlow(r.Context(), f.ID) + if innerErr != nil { + s.forward(w, r, innerRegistrationFlow, innerErr) + } + + s.d.Writer().WriteCode(w, r, x.RecoverStatusCode(err, http.StatusBadRequest), innerRegistrationFlow) } func (s *ErrorHandler) forward(w http.ResponseWriter, r *http.Request, rr *Flow, err error) { diff --git a/selfservice/flow/registration/error_test.go b/selfservice/flow/registration/error_test.go new file mode 100644 index 00000000000..441cf49f3da --- /dev/null +++ b/selfservice/flow/registration/error_test.go @@ -0,0 +1,250 @@ +package registration_test + +import ( + "context" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/gobuffalo/httptest" + "github.com/julienschmidt/httprouter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + + "github.com/ory/viper" + + "github.com/ory/x/assertx" + "github.com/ory/x/urlx" + + "github.com/ory/herodot" + + "github.com/ory/kratos/driver/configuration" + "github.com/ory/kratos/identity" + "github.com/ory/kratos/internal" + "github.com/ory/kratos/internal/httpclient/client/common" + "github.com/ory/kratos/internal/httpclient/models" + "github.com/ory/kratos/internal/testhelpers" + "github.com/ory/kratos/schema" + "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/text" + "github.com/ory/kratos/x" +) + +func TestHandleError(t *testing.T) { + conf, reg := internal.NewFastRegistryWithMocks(t) + viper.Set(configuration.ViperKeyDefaultIdentitySchemaURL, "file://./stub/login.schema.json") + + public, admin := testhelpers.NewKratosServer(t, reg) + + router := httprouter.New() + ts := httptest.NewServer(router) + t.Cleanup(ts.Close) + + testhelpers.NewRegistrationUIFlowEchoServer(t, reg) + testhelpers.NewErrorTestServer(t, reg) + + h := reg.RegistrationFlowErrorHandler() + sdk := testhelpers.NewSDKClient(admin) + + var registrationFlow *registration.Flow + var flowError error + var ct identity.CredentialsType + router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + h.WriteFlowError(w, r, ct, registrationFlow, flowError) + }) + + reset := func() { + registrationFlow = nil + flowError = nil + ct = "" + } + + newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *registration.Flow { + req := &http.Request{URL: urlx.ParseOrPanic("/")} + f := registration.NewFlow(ttl, "csrf_token", req, ft) + for _, s := range reg.RegistrationStrategies() { + require.NoError(t, s.PopulateRegistrationMethod(req, f)) + } + + require.NoError(t, reg.RegistrationFlowPersister().CreateRegistrationFlow(context.Background(), f)) + return f + } + + expectErrorUI := func(t *testing.T) (interface{}, *http.Response) { + res, err := ts.Client().Get(ts.URL + "/error") + require.NoError(t, err) + defer res.Body.Close() + require.Contains(t, res.Request.URL.String(), conf.SelfServiceFlowErrorURL().String()+"?error=") + + sse, err := sdk.Common.GetSelfServiceError(common.NewGetSelfServiceErrorParams(). + WithError(res.Request.URL.Query().Get("error"))) + require.NoError(t, err) + + return sse.Payload.Errors, nil + } + + t.Run("case=error with nil flow defaults to error ui redirect", func(t *testing.T) { + t.Cleanup(reset) + + flowError = herodot.ErrInternalServerError.WithReason("system error") + ct = identity.CredentialsTypePassword + + sse, _ := expectErrorUI(t) + assertx.EqualAsJSON(t, []interface{}{flowError}, sse) + }) + + t.Run("case=error with nil flow detects application/json", func(t *testing.T) { + t.Cleanup(reset) + + flowError = herodot.ErrInternalServerError.WithReason("system error") + ct = identity.CredentialsTypePassword + + res, err := ts.Client().Do(testhelpers.NewHTTPGetJSONRequest(t, ts.URL+"/error")) + require.NoError(t, err) + defer res.Body.Close() + assert.Contains(t, res.Header.Get("Content-Type"), "application/json") + assert.NotContains(t, res.Request.URL.String(), conf.SelfServiceFlowErrorURL().String()+"?error=") + + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "system error") + }) + + t.Run("flow=api", func(t *testing.T) { + t.Run("case=expired error", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = newFlow(t, time.Minute, flow.TypeAPI) + flowError = registration.NewFlowExpiredError(time.Hour) + ct = identity.CredentialsTypePassword + + res, err := ts.Client().Do(testhelpers.NewHTTPGetJSONRequest(t, ts.URL+"/error")) + require.NoError(t, err) + defer res.Body.Close() + require.Contains(t, res.Request.URL.String(), public.URL+registration.RouteGetFlow) + require.Equal(t, http.StatusOK, res.StatusCode) + + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, int(text.ErrorValidationRegistrationFlowExpired), int(gjson.GetBytes(body, "messages.0.id").Int())) + assert.NotEqual(t, registrationFlow.ID.String(), gjson.GetBytes(body, "id").String()) + }) + + t.Run("case=validation error", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = newFlow(t, time.Minute, flow.TypeAPI) + flowError = schema.NewInvalidCredentialsError() + ct = identity.CredentialsTypePassword + + res, err := ts.Client().Do(testhelpers.NewHTTPGetJSONRequest(t, ts.URL+"/error")) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) + + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, int(text.ErrorValidationInvalidCredentials), int(gjson.GetBytes(body, "methods.password.config.messages.0.id").Int()), "%s", body) + assert.Equal(t, registrationFlow.ID.String(), gjson.GetBytes(body, "id").String()) + }) + + t.Run("case=generic error", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = newFlow(t, time.Minute, flow.TypeAPI) + flowError = herodot.ErrInternalServerError.WithReason("system error") + ct = identity.CredentialsTypePassword + + res, err := ts.Client().Do(testhelpers.NewHTTPGetJSONRequest(t, ts.URL+"/error")) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusInternalServerError, res.StatusCode) + + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, x.MustEncodeJSON(t, flowError), gjson.GetBytes(body, "error").Raw) + }) + + t.Run("case=method is unknown", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = newFlow(t, time.Minute, flow.TypeAPI) + flowError = herodot.ErrInternalServerError.WithReason("system error") + ct = "invalid-method" + + res, err := ts.Client().Do(testhelpers.NewHTTPGetJSONRequest(t, ts.URL+"/error")) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusInternalServerError, res.StatusCode) + + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + assert.Contains(t, gjson.GetBytes(body, "error.message").String(), "invalid-method", "%s", body) + }) + }) + + t.Run("flow=browser", func(t *testing.T) { + expectRegistrationUI := func(t *testing.T) (*models.RegistrationFlow, *http.Response) { + res, err := ts.Client().Get(ts.URL + "/error") + require.NoError(t, err) + defer res.Body.Close() + assert.Contains(t, res.Request.URL.String(), conf.SelfServiceFlowRegistrationUI().String()+"?flow=") + + lf, err := sdk.Common.GetSelfServiceRegistrationFlow(common.NewGetSelfServiceRegistrationFlowParams(). + WithID(res.Request.URL.Query().Get("flow"))) + require.NoError(t, err) + return lf.Payload, res + } + + t.Run("case=expired error", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = ®istration.Flow{Type: flow.TypeBrowser} + flowError = registration.NewFlowExpiredError(time.Hour) + ct = identity.CredentialsTypePassword + + lf, _ := expectRegistrationUI(t) + require.Len(t, lf.Messages, 1) + assert.Equal(t, int(text.ErrorValidationRegistrationFlowExpired), int(lf.Messages[0].ID)) + }) + + t.Run("case=validation error", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = newFlow(t, time.Minute, flow.TypeBrowser) + flowError = schema.NewInvalidCredentialsError() + ct = identity.CredentialsTypePassword + + lf, _ := expectRegistrationUI(t) + require.NotEmpty(t, lf.Methods[string(ct)], x.MustEncodeJSON(t, lf)) + require.Len(t, lf.Methods[string(ct)].Config.Messages, 1, x.MustEncodeJSON(t, lf)) + assert.Equal(t, int(text.ErrorValidationInvalidCredentials), int(lf.Methods[string(ct)].Config.Messages[0].ID), x.MustEncodeJSON(t, lf)) + }) + + t.Run("case=generic error", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = newFlow(t, time.Minute, flow.TypeBrowser) + flowError = herodot.ErrInternalServerError.WithReason("system error") + ct = identity.CredentialsTypePassword + + sse, _ := expectErrorUI(t) + assertx.EqualAsJSON(t, []interface{}{flowError}, sse) + }) + + t.Run("case=method is unknown", func(t *testing.T) { + t.Cleanup(reset) + + registrationFlow = newFlow(t, time.Minute, flow.TypeBrowser) + flowError = herodot.ErrInternalServerError.WithReason("system error") + ct = "invalid-method" + + sse, _ := expectErrorUI(t) + body := x.MustEncodeJSON(t, sse) + assert.Contains(t, gjson.Get(body, "0.message").String(), "invalid-method", "%s", body) + }) + }) +} diff --git a/selfservice/flow/registration/handler.go b/selfservice/flow/registration/handler.go index 81e5ca26b68..922ccca2de3 100644 --- a/selfservice/flow/registration/handler.go +++ b/selfservice/flow/registration/handler.go @@ -149,7 +149,7 @@ func (h *Handler) initBrowserFlow(w http.ResponseWriter, r *http.Request, ps htt return } - redirTo := a.AppendTo(h.c.SelfServiceFlowRegisterUI()).String() + redirTo := a.AppendTo(h.c.SelfServiceFlowRegistrationUI()).String() if _, err := h.d.SessionManager().FetchFromRequest(r.Context(), r); err == nil { redirTo = h.c.SelfServiceBrowserDefaultReturnTo().String() } diff --git a/selfservice/flow/registration/handler_test.go b/selfservice/flow/registration/handler_test.go index 4334e90fe2b..3eccfd33257 100644 --- a/selfservice/flow/registration/handler_test.go +++ b/selfservice/flow/registration/handler_test.go @@ -9,11 +9,12 @@ import ( "testing" "time" - "github.com/ory/x/assertx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + "github.com/ory/x/assertx" + "github.com/ory/viper" "github.com/ory/kratos/driver/configuration" @@ -97,7 +98,7 @@ func TestInitFlow(t *testing.T) { t.Run("flow=api", func(t *testing.T) { t.Run("case=creates a new flow on unauthenticated request", func(t *testing.T) { - res, body := initFlow(t,true) + res, body := initFlow(t, true) assert.Contains(t, res.Request.URL.String(), registration.RouteInitAPIFlow) assertion(body, false, true) })