Skip to content

Commit

Permalink
refactor: replace all registration request occurrences with registrat…
Browse files Browse the repository at this point in the history
…ion flow
  • Loading branch information
aeneasr committed Aug 25, 2020
1 parent 8437ebc commit 308ef47
Show file tree
Hide file tree
Showing 15 changed files with 100 additions and 100 deletions.
6 changes: 3 additions & 3 deletions internal/faker.go
Expand Up @@ -100,15 +100,15 @@ func RegisterFakes() {
}

if err := faker.AddProvider("registration_flow_methods", func(v reflect.Value) (interface{}, error) {
var methods = make(map[identity.CredentialsType]*registration.RequestMethod)
var methods = make(map[identity.CredentialsType]*registration.FlowMethod)
for _, ct := range []identity.CredentialsType{identity.CredentialsTypePassword, identity.CredentialsTypeOIDC} {
var f form.HTMLForm
if err := faker.FakeData(&f); err != nil {
return nil, errors.WithStack(err)
}
methods[ct] = &registration.RequestMethod{
methods[ct] = &registration.FlowMethod{
Method: ct,
Config: &registration.RequestMethodConfig{RequestMethodConfigurator: &f},
Config: &registration.FlowMethodConfig{FlowMethodConfigurator: &f},
}
}
return methods, nil
Expand Down
2 changes: 1 addition & 1 deletion internal/testhelpers/sql.go
Expand Up @@ -25,7 +25,7 @@ func CleanSQL(t *testing.T, c *pop.Connection) {
new(courier.Message).TableName(),
new(login.FlowMethods).TableName(),
new(login.Flow).TableName(),
new(registration.RequestMethods).TableName(),
new(registration.FlowMethods).TableName(),
new(registration.Flow).TableName(),
new(settings.RequestMethods).TableName(),
new(settings.Request).TableName(),
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_registration.go
Expand Up @@ -54,7 +54,7 @@ func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*reg
return &r, nil
}

func (p *Persister) UpdateRegistrationFlowMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *registration.RequestMethod) error {
func (p *Persister) UpdateRegistrationFlowMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *registration.FlowMethod) error {
return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {

rr, err := p.GetRegistrationFlow(ctx, id)
Expand Down
16 changes: 8 additions & 8 deletions selfservice/flow/registration/error.go
Expand Up @@ -51,9 +51,9 @@ func NewFlowExpiredError(ago time.Duration) *FlowExpiredError {
return &FlowExpiredError{
ago: ago,
DefaultError: herodot.ErrBadRequest.
WithError("registration request expired").
WithReasonf(`The registration request has expired. Please restart the flow.`).
WithReasonf("The registration request expired %.2f minutes ago, please try again.", ago.Minutes()),
WithError("registration flow expired").
WithReasonf(`The registration flow has expired. Please restart the flow.`).
WithReasonf("The registration flow expired %.2f minutes ago, please try again.", ago.Minutes()),
}
}

Expand All @@ -74,19 +74,19 @@ func (s *ErrorHandler) WriteFlowError(
s.d.Audit().
WithError(err).
WithRequest(r).
WithField("registration_request", rr).
Info("Encountered self-service request error.")
WithField("registration_flow", rr).
Info("Encountered self-service flow error.")

if e := new(FlowExpiredError); errors.As(err, &e) {
// create new request because the old one is not valid
a, err := s.d.RegistrationHandler().NewRegistrationRequest(w, r)
// create new flow because the old one is not valid
a, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r)
if err != nil {
// failed to create a new session and redirect to it, handle that error as a new one
s.WriteFlowError(w, r, ct, rr, err)
return
}

a.Messages.Add(text.NewErrorValidationRegistrationRequestExpired(e.ago))
a.Messages.Add(text.NewErrorValidationRegistrationFlowExpired(e.ago))
if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlow(context.TODO(), a); err != nil {
redirTo, err := s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err)
if err != nil {
Expand Down
26 changes: 13 additions & 13 deletions selfservice/flow/registration/flow.go
Expand Up @@ -18,22 +18,22 @@ import (

// swagger:model registrationRequest
type Flow struct {
// ID represents the request's unique ID. When performing the registration flow, this
// represents the id in the registration ui's query parameter: http://<selfservice.flows.registration.ui_url>/?request=<id>
// ID represents the flow's unique ID. When performing the registration flow, this
// represents the id in the registration ui's query parameter: http://<selfservice.flows.registration.ui_url>/?flow=<id>
//
// required: true
ID uuid.UUID `json:"id" faker:"-" db:"id"`

// Type represents the flow's type which can be either "api" or "browser", depending on the flow interaction.
Type flow.Type `json:"type" db:"type" faker:"flow_type"`

// ExpiresAt is the time (UTC) when the request expires. If the user still wishes to log in,
// a new request has to be initiated.
// ExpiresAt is the time (UTC) when the flow expires. If the user still wishes to log in,
// a new flow has to be initiated.
//
// required: true
ExpiresAt time.Time `json:"expires_at" faker:"time_type" db:"expires_at"`

// IssuedAt is the time (UTC) when the request occurred.
// IssuedAt is the time (UTC) when the flow occurred.
//
// required: true
IssuedAt time.Time `json:"issued_at" faker:"time_type" db:"issued_at"`
Expand All @@ -54,23 +54,23 @@ type Flow struct {
// More documentation on messages can be found in the [User Interface Documentation](https://www.ory.sh/kratos/docs/concepts/ui-user-interface/).
Messages text.Messages `json:"messages" db:"messages" faker:"-"`

// Methods contains context for all enabled registration methods. If a registration request has been
// Methods contains context for all enabled registration methods. If a registration flow has been
// processed, but for example the password is incorrect, this will contain error messages.
//
// required: true
Methods map[identity.CredentialsType]*RequestMethod `json:"methods" faker:"registration_flow_methods" db:"-"`
Methods map[identity.CredentialsType]*FlowMethod `json:"methods" faker:"registration_flow_methods" db:"-"`

// MethodsRaw is a helper struct field for gobuffalo.pop.
MethodsRaw RequestMethodsRaw `json:"-" faker:"-" has_many:"selfservice_registration_flow_methods" fk_id:"selfservice_registration_flow_id"`
MethodsRaw FlowMethodsRaw `json:"-" faker:"-" has_many:"selfservice_registration_flow_methods" fk_id:"selfservice_registration_flow_id"`

// CreatedAt is a helper struct field for gobuffalo.pop.
CreatedAt time.Time `json:"-" faker:"-" db:"created_at"`

// UpdatedAt is a helper struct field for gobuffalo.pop.
UpdatedAt time.Time `json:"-" faker:"-" db:"updated_at"`

// CSRFToken contains the anti-csrf token associated with this request.
CSRFToken string `json:"-" db:"csrf_token"`
// CSRFToken contains the anti-csrf token associated with this flow. Only set for browser flows.
CSRFToken string `json:"-" db:"csrf_token,omitempty"`
}

func NewFlow(exp time.Duration, csrf string, r *http.Request, ft flow.Type) *Flow {
Expand All @@ -87,14 +87,14 @@ func NewFlow(exp time.Duration, csrf string, r *http.Request, ft flow.Type) *Flo
ExpiresAt: time.Now().UTC().Add(exp),
IssuedAt: time.Now().UTC(),
RequestURL: source.String(),
Methods: map[identity.CredentialsType]*RequestMethod{},
Methods: map[identity.CredentialsType]*FlowMethod{},
CSRFToken: csrf,
Type: ft,
}
}

func (r *Flow) BeforeSave(_ *pop.Connection) error {
r.MethodsRaw = make([]RequestMethod, 0, len(r.Methods))
r.MethodsRaw = make([]FlowMethod, 0, len(r.Methods))
for _, m := range r.Methods {
r.MethodsRaw = append(r.MethodsRaw, *m)
}
Expand All @@ -111,7 +111,7 @@ func (r *Flow) AfterUpdate(c *pop.Connection) error {
}

func (r *Flow) AfterFind(_ *pop.Connection) error {
r.Methods = make(RequestMethods)
r.Methods = make(FlowMethods)
for key := range r.MethodsRaw {
m := r.MethodsRaw[key] // required for pointer dereference
r.Methods[m.Method] = &m
Expand Down
42 changes: 21 additions & 21 deletions selfservice/flow/registration/flow_method.go
Expand Up @@ -14,12 +14,12 @@ import (
)

// swagger:model registrationRequestMethod
type RequestMethod struct {
// Method contains the request credentials type.
type FlowMethod struct {
// Method contains the flow method's credentials type.
Method identity.CredentialsType `json:"method" faker:"string" db:"method"`

// Config is the credential type's config.
Config *RequestMethodConfig `json:"config" db:"config"`
Config *FlowMethodConfig `json:"config" db:"config"`

// ID is a helper struct field for gobuffalo.pop.
ID uuid.UUID `json:"-" faker:"-" db:"id"`
Expand All @@ -37,25 +37,25 @@ type RequestMethod struct {
UpdatedAt time.Time `json:"-" faker:"-" db:"updated_at"`
}

func (u RequestMethod) TableName() string {
func (u FlowMethod) TableName() string {
return "selfservice_registration_flow_methods"
}

type RequestMethodsRaw []RequestMethod // workaround for https://github.com/gobuffalo/pop/pull/478
type RequestMethods map[identity.CredentialsType]*RequestMethod
type FlowMethodsRaw []FlowMethod // workaround for https://github.com/gobuffalo/pop/pull/478
type FlowMethods map[identity.CredentialsType]*FlowMethod

func (u RequestMethods) TableName() string {
func (u FlowMethods) TableName() string {
// This must be stay a value receiver, using a pointer receiver will cause issues with pop.
return "selfservice_registration_flow_methods"
}

func (u RequestMethodsRaw) TableName() string {
func (u FlowMethodsRaw) TableName() string {
// This must be stay a value receiver, using a pointer receiver will cause issues with pop.
return "selfservice_registration_flow_methods"
}

// swagger:ignore
type RequestMethodConfigurator interface {
type FlowMethodConfigurator interface {
form.ErrorParser
form.FieldSetter
form.FieldUnsetter
Expand All @@ -68,34 +68,34 @@ type RequestMethodConfigurator interface {
}

// swagger:model registrationRequestMethodConfig
type RequestMethodConfig struct {
type FlowMethodConfig struct {
// swagger:ignore
RequestMethodConfigurator
FlowMethodConfigurator

requestMethodConfigMock
flowMethodConfigMock
}

// swagger:model registrationRequestMethodConfigPayload
type requestMethodConfigMock struct {
type flowMethodConfigMock struct {
*form.HTMLForm

// Providers is set for the "oidc" request method.
// Providers is set for the "oidc" registration method.
Providers []form.Field `json:"providers" faker:"len=3"`
}

func (c *RequestMethodConfig) Scan(value interface{}) error {
func (c *FlowMethodConfig) Scan(value interface{}) error {
return sqlxx.JSONScan(c, value)
}

func (c *RequestMethodConfig) Value() (driver.Value, error) {
func (c *FlowMethodConfig) Value() (driver.Value, error) {
return sqlxx.JSONValue(c)
}

func (c *RequestMethodConfig) UnmarshalJSON(data []byte) error {
c.RequestMethodConfigurator = form.NewHTMLForm("")
return json.Unmarshal(data, c.RequestMethodConfigurator)
func (c *FlowMethodConfig) UnmarshalJSON(data []byte) error {
c.FlowMethodConfigurator = form.NewHTMLForm("")
return json.Unmarshal(data, c.FlowMethodConfigurator)
}

func (c *RequestMethodConfig) MarshalJSON() ([]byte, error) {
return json.Marshal(c.RequestMethodConfigurator)
func (c *FlowMethodConfig) MarshalJSON() ([]byte, error) {
return json.Marshal(c.FlowMethodConfigurator)
}
20 changes: 10 additions & 10 deletions selfservice/flow/registration/handler.go
Expand Up @@ -49,14 +49,14 @@ func NewHandler(d handlerDependencies, c configuration.Provider) *Handler {

func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsNotAuthenticated(h.initBrowserFlow, session.RedirectOnAuthenticated(h.c)))
public.GET(RouteGetFlow, h.publicFetchRegistrationRequest)
public.GET(RouteGetFlow, h.publicFetchRegistrationFlow)
}

func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
admin.GET(RouteGetFlow, h.adminFetchRegistrationRequest)
admin.GET(RouteGetFlow, h.adminFetchRegistrationFlow)
}

func (h *Handler) NewRegistrationRequest(w http.ResponseWriter, r *http.Request) (*Flow, error) {
func (h *Handler) NewRegistrationFlow(w http.ResponseWriter, r *http.Request) (*Flow, error) {
a := NewFlow(h.c.SelfServiceFlowRegistrationRequestLifespan(), h.d.GenerateCSRFToken(r), r, flow.TypeBrowser)
for _, s := range h.d.RegistrationStrategies() {
if err := s.PopulateRegistrationMethod(r, a); err != nil {
Expand All @@ -80,7 +80,7 @@ func (h *Handler) NewRegistrationRequest(w http.ResponseWriter, r *http.Request)
// Initialize browser-based registration user flow
//
// This endpoint initializes a browser-based user registration flow. Once initialized, the browser will be redirected to
// `selfservice.flows.registration.ui_url` with the request ID set as a query parameter. If a valid user session exists already, the browser will be
// `selfservice.flows.registration.ui_url` with the flow ID set as a query parameter. If a valid user session exists already, the browser will be
// redirected to `urls.default_redirect_url`.
//
// > This endpoint is NOT INTENDED for API clients and only works
Expand All @@ -94,7 +94,7 @@ func (h *Handler) NewRegistrationRequest(w http.ResponseWriter, r *http.Request)
// 302: emptyResponse
// 500: genericError
func (h *Handler) initBrowserFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
a, err := h.NewRegistrationRequest(w, r)
a, err := h.NewRegistrationFlow(w, r)
if err != nil {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
return
Expand Down Expand Up @@ -143,22 +143,22 @@ type getSelfServiceBrowserRegistrationRequestParameters struct {
// 404: genericError
// 410: genericError
// 500: genericError
func (h *Handler) publicFetchRegistrationRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := h.fetchRegistrationRequest(w, r, true); err != nil {
func (h *Handler) publicFetchRegistrationFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := h.fetchRegistrationFlow(w, r, true); err != nil {
h.d.Writer().WriteError(w, r, err)
return
}

}

func (h *Handler) adminFetchRegistrationRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := h.fetchRegistrationRequest(w, r, false); err != nil {
func (h *Handler) adminFetchRegistrationFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := h.fetchRegistrationFlow(w, r, false); err != nil {
h.d.Writer().WriteError(w, r, err)
return
}
}

func (h *Handler) fetchRegistrationRequest(w http.ResponseWriter, r *http.Request, isPublic bool) error {
func (h *Handler) fetchRegistrationFlow(w http.ResponseWriter, r *http.Request, isPublic bool) error {
ar, err := h.d.RegistrationFlowPersister().GetRegistrationFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("request")))
if err != nil {
if isPublic {
Expand Down
12 changes: 6 additions & 6 deletions selfservice/flow/registration/handler_test.go
Expand Up @@ -70,7 +70,7 @@ func TestRegistrationHandler(t *testing.T) {
}))
}

assertRequestPayload := func(t *testing.T, body []byte) {
assertFlowPayload := func(t *testing.T, body []byte) {
assert.Equal(t, "password", gjson.GetBytes(body, "methods.password.method").String(), "%s", body)
assert.NotEmpty(t, gjson.GetBytes(body, "methods.password.config.fields.#(name==csrf_token).value").String(), "%s", body)
assert.NotEmpty(t, gjson.GetBytes(body, "id").String(), "%s", body)
Expand All @@ -84,7 +84,7 @@ func TestRegistrationHandler(t *testing.T) {
assert.Equal(t, public.URL+registration.RouteInitBrowserFlow, gjson.GetBytes(body, "error.details.redirect_to").String(), "%s", body)
}

newExpiredRequest := func() *registration.Flow {
newExpiredFLow := func() *registration.Flow {
return &registration.Flow{
ID: x.NewUUID(),
ExpiresAt: time.Now().Add(-time.Minute),
Expand All @@ -109,11 +109,11 @@ func TestRegistrationHandler(t *testing.T) {
viper.Set(configuration.ViperKeySelfServiceRegistrationUI, regTS.URL)

t.Run("case=valid", func(t *testing.T) {
assertRequestPayload(t, x.EasyGetBody(t, public.Client(), public.URL+registration.RouteInitBrowserFlow))
assertFlowPayload(t, x.EasyGetBody(t, public.Client(), public.URL+registration.RouteInitBrowserFlow))
})

t.Run("case=expired", func(t *testing.T) {
rr := newExpiredRequest()
rr := newExpiredFLow()
require.NoError(t, reg.RegistrationFlowPersister().CreateRegistrationFlow(context.Background(), rr))
res, body := x.EasyGet(t, admin.Client(), admin.URL+registration.RouteGetFlow+"?request="+rr.ID.String())
assertExpiredPayload(t, res, body)
Expand All @@ -131,7 +131,7 @@ func TestRegistrationHandler(t *testing.T) {
viper.Set(configuration.ViperKeySelfServiceRegistrationUI, regTS.URL)

body := x.EasyGetBody(t, hc, public.URL+registration.RouteInitBrowserFlow)
assertRequestPayload(t, body)
assertFlowPayload(t, body)
})

t.Run("case=without_csrf", func(t *testing.T) {
Expand All @@ -158,7 +158,7 @@ func TestRegistrationHandler(t *testing.T) {
regTS := newRegistrationTS(t, public.URL, hc)
defer regTS.Close()

rr := newExpiredRequest()
rr := newExpiredFLow()
require.NoError(t, reg.RegistrationFlowPersister().CreateRegistrationFlow(context.Background(), rr))
res, body := x.EasyGet(t, admin.Client(), admin.URL+registration.RouteGetFlow+"?request="+rr.ID.String())
assertExpiredPayload(t, res, body)
Expand Down

0 comments on commit 308ef47

Please sign in to comment.