Skip to content

Commit

Permalink
refactor: rename registration request to flow
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Aug 25, 2020
1 parent d7189a9 commit 8437ebc
Show file tree
Hide file tree
Showing 33 changed files with 197 additions and 168 deletions.
4 changes: 2 additions & 2 deletions cmd/daemon/serve.go
Expand Up @@ -131,8 +131,8 @@ func sqa(cmd *cobra.Command, d driver.Driver) *metricsx.Service {
login.RouteInitBrowserFlow,
login.RouteGetFlow,
logout.BrowserLogoutPath,
registration.BrowserRegistrationPath,
registration.BrowserRegistrationRequestsPath,
registration.RouteInitBrowserFlow,
registration.RouteGetFlow,
session.RouteWhoami,
identity.IdentitiesPath,
profile.PublicSettingsProfilePath,
Expand Down
2 changes: 1 addition & 1 deletion driver/registry.go
Expand Up @@ -104,7 +104,7 @@ type Registry interface {

logout.HandlerProvider

registration.RequestPersistenceProvider
registration.FlowPersistenceProvider
registration.ErrorHandlerProvider
registration.HooksProvider
registration.HookExecutorProvider
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default.go
Expand Up @@ -482,7 +482,7 @@ func (m *RegistryDefault) PrivilegedIdentityPool() identity.PrivilegedPool {
return m.persister
}

func (m *RegistryDefault) RegistrationRequestPersister() registration.RequestPersister {
func (m *RegistryDefault) RegistrationFlowPersister() registration.FlowPersister {
return m.persister
}

Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default_registration.go
Expand Up @@ -65,7 +65,7 @@ func (m *RegistryDefault) RegistrationHandler() *registration.Handler {
return m.selfserviceRegistrationHandler
}

func (m *RegistryDefault) RegistrationRequestErrorHandler() *registration.ErrorHandler {
func (m *RegistryDefault) RegistrationFlowErrorHandler() *registration.ErrorHandler {
if m.selfserviceRegistrationRequestErrorHandler == nil {
m.selfserviceRegistrationRequestErrorHandler = registration.NewErrorHandler(m, m.c)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/faker.go
Expand Up @@ -99,7 +99,7 @@ func RegisterFakes() {
panic(err)
}

if err := faker.AddProvider("registration_request_methods", func(v reflect.Value) (interface{}, error) {
if err := faker.AddProvider("registration_flow_methods", func(v reflect.Value) (interface{}, error) {
var methods = make(map[identity.CredentialsType]*registration.RequestMethod)
for _, ct := range []identity.CredentialsType{identity.CredentialsTypePassword, identity.CredentialsTypeOIDC} {
var f form.HTMLForm
Expand Down
2 changes: 1 addition & 1 deletion internal/testhelpers/sql.go
Expand Up @@ -26,7 +26,7 @@ func CleanSQL(t *testing.T, c *pop.Connection) {
new(login.FlowMethods).TableName(),
new(login.Flow).TableName(),
new(registration.RequestMethods).TableName(),
new(registration.Request).TableName(),
new(registration.Flow).TableName(),
new(settings.RequestMethods).TableName(),
new(settings.Request).TableName(),
new(recoverytoken.Token).TableName(),
Expand Down
2 changes: 1 addition & 1 deletion persistence/reference.go
Expand Up @@ -26,7 +26,7 @@ type Provider interface {
type Persister interface {
continuity.Persister
identity.PrivilegedPool
registration.RequestPersister
registration.FlowPersister
login.FlowPersister
settings.RequestPersister
courier.Persister
Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/migratest/migration_test.go
Expand Up @@ -135,11 +135,11 @@ func TestMigrations(t *testing.T) {
}
})
t.Run("case=registration", func(t *testing.T) {
var ids []registration.Request
var ids []registration.Flow
require.NoError(t, c.Select("id").All(&ids))

for _, id := range ids {
actual, err := d.Registry().RegistrationRequestPersister().GetRegistrationRequest(context.Background(), id.ID)
actual, err := d.Registry().RegistrationFlowPersister().GetRegistrationFlow(context.Background(), id.ID)
require.NoError(t, err)
compareWithFixture(t, actual, "registration_request", id.ID.String())
}
Expand Down
16 changes: 8 additions & 8 deletions persistence/sql/persister_registration.go
Expand Up @@ -12,14 +12,14 @@ import (
"github.com/ory/kratos/selfservice/flow/registration"
)

func (p *Persister) CreateRegistrationRequest(ctx context.Context, r *registration.Flow) error {
func (p *Persister) CreateRegistrationFlow(ctx context.Context, r *registration.Flow) error {
return p.GetConnection(ctx).Eager().Create(r)
}

func (p *Persister) UpdateRegistrationRequest(ctx context.Context, r *registration.Flow) error {
func (p *Persister) UpdateRegistrationFlow(ctx context.Context, r *registration.Flow) error {
return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {

rr, err := p.GetRegistrationRequest(ctx, r.ID)
rr, err := p.GetRegistrationFlow(ctx, r.ID)
if err != nil {
return err
}
Expand All @@ -31,7 +31,7 @@ func (p *Persister) UpdateRegistrationRequest(ctx context.Context, r *registrati
}

for _, of := range r.Methods {
of.RequestID = r.ID
of.FlowID = r.ID
if err := tx.Save(of); err != nil {
return sqlcon.HandleError(err)
}
Expand All @@ -41,7 +41,7 @@ func (p *Persister) UpdateRegistrationRequest(ctx context.Context, r *registrati
})
}

func (p *Persister) GetRegistrationRequest(ctx context.Context, id uuid.UUID) (*registration.Flow, error) {
func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*registration.Flow, error) {
var r registration.Flow
if err := p.GetConnection(ctx).Eager().Find(&r, id); err != nil {
return nil, sqlcon.HandleError(err)
Expand All @@ -54,17 +54,17 @@ func (p *Persister) GetRegistrationRequest(ctx context.Context, id uuid.UUID) (*
return &r, nil
}

func (p *Persister) UpdateRegistrationRequestMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *registration.RequestMethod) error {
func (p *Persister) UpdateRegistrationFlowMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *registration.RequestMethod) error {
return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {

rr, err := p.GetRegistrationRequest(ctx, id)
rr, err := p.GetRegistrationFlow(ctx, id)
if err != nil {
return err
}

method, ok := rr.Methods[ct]
if !ok {
rm.RequestID = rr.ID
rm.FlowID = rr.ID
rm.Method = ct
return tx.Save(rm)
}
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_test.go
Expand Up @@ -140,7 +140,7 @@ func TestPersister(t *testing.T) {
})
t.Run("contract=registration.TestFlowPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
registration.TestRequestPersister(p)(t)
registration.TestFlowPersister(p)(t)
})
t.Run("contract=errorx.TestPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/flow_method.go
Expand Up @@ -31,7 +31,7 @@ type FlowMethod struct {
// FlowID is a helper struct field for gobuffalo.pop.
FlowID uuid.UUID `json:"-" db:"selfservice_login_flow_id"`

// Request is a helper struct field for gobuffalo.pop.
// Flow is a helper struct field for gobuffalo.pop.
Flow *Flow `json:"-" belongs_to:"selfservice_login_flow" fk_id:"FlowID"`

// CreatedAt is a helper struct field for gobuffalo.pop.
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/flow_test.go
Expand Up @@ -39,7 +39,7 @@ func TestNewFlow(t *testing.T) {
URL: urlx.ParseOrPanic("/"),
Host: "ory.sh", TLS: &tls.ConnectionState{},
}, flow.TypeBrowser)
assert.Equal(t, r.IssuedAt, r.ExpiresAt)
assert.EqualValues(t, r.IssuedAt, r.ExpiresAt)
assert.Equal(t, flow.TypeBrowser, r.Type)
assert.False(t, r.Forced)
assert.Equal(t, "https://ory.sh/", r.RequestURL)
Expand Down
43 changes: 31 additions & 12 deletions selfservice/flow/registration/error.go
Expand Up @@ -6,6 +6,7 @@ import (
"net/url"
"time"

"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/text"

"github.com/pkg/errors"
Expand All @@ -29,25 +30,25 @@ type (
x.WriterProvider
x.LoggingProvider

RequestPersistenceProvider
FlowPersistenceProvider
HandlerProvider
}

ErrorHandlerProvider interface{ RegistrationRequestErrorHandler() *ErrorHandler }
ErrorHandlerProvider interface{ RegistrationFlowErrorHandler() *ErrorHandler }

ErrorHandler struct {
d errorHandlerDependencies
c configuration.Provider
}

requestExpiredError struct {
FlowExpiredError struct {
*herodot.DefaultError
ago time.Duration
}
)

func newRequestExpiredError(ago time.Duration) *requestExpiredError {
return &requestExpiredError{
func NewFlowExpiredError(ago time.Duration) *FlowExpiredError {
return &FlowExpiredError{
ago: ago,
DefaultError: herodot.ErrBadRequest.
WithError("registration request expired").
Expand All @@ -63,11 +64,11 @@ func NewErrorHandler(d errorHandlerDependencies, c configuration.Provider) *Erro
}
}

func (s *ErrorHandler) HandleRegistrationError(
func (s *ErrorHandler) WriteFlowError(
w http.ResponseWriter,
r *http.Request,
ct identity.CredentialsType,
rr *Request,
rr *Flow,
err error,
) {
s.d.Audit().
Expand All @@ -76,20 +77,20 @@ func (s *ErrorHandler) HandleRegistrationError(
WithField("registration_request", rr).
Info("Encountered self-service request error.")

if e := new(requestExpiredError); errors.As(err, &e) {
if e := new(FlowExpiredError); errors.As(err, &e) {
// create new request because the old one is not valid
a, err := s.d.RegistrationHandler().NewRegistrationRequest(w, r)
if err != nil {
// failed to create a new session and redirect to it, handle that error as a new one
s.HandleRegistrationError(w, r, ct, rr, err)
s.WriteFlowError(w, r, ct, rr, err)
return
}

a.Messages.Add(text.NewErrorValidationRegistrationRequestExpired(e.ago))
if err := s.d.RegistrationRequestPersister().UpdateRegistrationRequest(context.TODO(), a); err != nil {
if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlow(context.TODO(), a); err != nil {
redirTo, err := s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err)
if err != nil {
s.HandleRegistrationError(w, r, ct, rr, err)
s.WriteFlowError(w, r, ct, rr, err)
return
}
http.Redirect(w, r, redirTo, http.StatusFound)
Expand Down Expand Up @@ -119,7 +120,7 @@ func (s *ErrorHandler) HandleRegistrationError(
return
}

if err := s.d.RegistrationRequestPersister().UpdateRegistrationRequestMethod(r.Context(), rr.ID, ct, method); err != nil {
if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlowMethod(r.Context(), rr.ID, ct, method); err != nil {
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
return
}
Expand All @@ -129,3 +130,21 @@ func (s *ErrorHandler) HandleRegistrationError(
http.StatusFound,
)
}

func (s *ErrorHandler) forward(w http.ResponseWriter, r *http.Request, rr *Flow, err error) {
if rr == nil {
if x.IsJSONRequest(r) {
s.d.Writer().WriteError(w, r, err)
return
}
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
return
}

if rr.Type == flow.TypeAPI {
s.d.Writer().WriteErrorCode(w, r, x.RecoverStatusCode(err, http.StatusBadRequest), err)
} else {
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
}
}

Expand Up @@ -11,18 +11,22 @@ import (
"github.com/ory/x/urlx"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/text"
"github.com/ory/kratos/x"
)

// swagger:model registrationRequest
type Request struct {
type Flow struct {
// ID represents the request's unique ID. When performing the registration flow, this
// represents the id in the registration ui's query parameter: http://<selfservice.flows.registration.ui_url>/?request=<id>
//
// required: true
ID uuid.UUID `json:"id" faker:"-" db:"id"`

// Type represents the flow's type which can be either "api" or "browser", depending on the flow interaction.
Type flow.Type `json:"type" db:"type" faker:"flow_type"`

// ExpiresAt is the time (UTC) when the request expires. If the user still wishes to log in,
// a new request has to be initiated.
//
Expand Down Expand Up @@ -54,10 +58,10 @@ type Request struct {
// processed, but for example the password is incorrect, this will contain error messages.
//
// required: true
Methods map[identity.CredentialsType]*RequestMethod `json:"methods" faker:"registration_request_methods" db:"-"`
Methods map[identity.CredentialsType]*RequestMethod `json:"methods" faker:"registration_flow_methods" db:"-"`

// MethodsRaw is a helper struct field for gobuffalo.pop.
MethodsRaw RequestMethodsRaw `json:"-" faker:"-" has_many:"selfservice_registration_request_methods" fk_id:"selfservice_registration_request_id"`
MethodsRaw RequestMethodsRaw `json:"-" faker:"-" has_many:"selfservice_registration_flow_methods" fk_id:"selfservice_registration_flow_id"`

// CreatedAt is a helper struct field for gobuffalo.pop.
CreatedAt time.Time `json:"-" faker:"-" db:"created_at"`
Expand All @@ -69,7 +73,7 @@ type Request struct {
CSRFToken string `json:"-" db:"csrf_token"`
}

func NewRequest(exp time.Duration, csrf string, r *http.Request) *Request {
func NewFlow(exp time.Duration, csrf string, r *http.Request, ft flow.Type) *Flow {
source := urlx.Copy(r.URL)
source.Host = r.Host

Expand All @@ -78,17 +82,18 @@ func NewRequest(exp time.Duration, csrf string, r *http.Request) *Request {
source.Scheme = "https"
}

return &Request{
return &Flow{
ID: x.NewUUID(),
ExpiresAt: time.Now().UTC().Add(exp),
IssuedAt: time.Now().UTC(),
RequestURL: source.String(),
Methods: map[identity.CredentialsType]*RequestMethod{},
CSRFToken: csrf,
Type: ft,
}
}

func (r *Request) BeforeSave(_ *pop.Connection) error {
func (r *Flow) BeforeSave(_ *pop.Connection) error {
r.MethodsRaw = make([]RequestMethod, 0, len(r.Methods))
for _, m := range r.Methods {
r.MethodsRaw = append(r.MethodsRaw, *m)
Expand All @@ -97,15 +102,15 @@ func (r *Request) BeforeSave(_ *pop.Connection) error {
return nil
}

func (r *Request) AfterCreate(c *pop.Connection) error {
func (r *Flow) AfterCreate(c *pop.Connection) error {
return r.AfterFind(c)
}

func (r *Request) AfterUpdate(c *pop.Connection) error {
func (r *Flow) AfterUpdate(c *pop.Connection) error {
return r.AfterFind(c)
}

func (r *Request) AfterFind(_ *pop.Connection) error {
func (r *Flow) AfterFind(_ *pop.Connection) error {
r.Methods = make(RequestMethods)
for key := range r.MethodsRaw {
m := r.MethodsRaw[key] // required for pointer dereference
Expand All @@ -115,18 +120,18 @@ func (r *Request) AfterFind(_ *pop.Connection) error {
return nil
}

func (r Request) TableName() string {
func (r Flow) TableName() string {
// This must be stay a value receiver, using a pointer receiver will cause issues with pop.
return "selfservice_registration_requests"
return "selfservice_registration_flows"
}

func (r *Request) GetID() uuid.UUID {
func (r *Flow) GetID() uuid.UUID {
return r.ID
}

func (r *Request) Valid() error {
func (r *Flow) Valid() error {
if r.ExpiresAt.Before(time.Now()) {
return errors.WithStack(newRequestExpiredError(time.Since(r.ExpiresAt)))
return errors.WithStack(NewFlowExpiredError(time.Since(r.ExpiresAt)))
}
return nil
}

0 comments on commit 8437ebc

Please sign in to comment.