Skip to content

Commit

Permalink
feat: add tests and helpers to test recovery/verifiable addresses (#579)
Browse files Browse the repository at this point in the history
Closes #576
  • Loading branch information
aeneasr committed Jul 16, 2020
1 parent 05e55f3 commit 29979e6
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 34 deletions.
3 changes: 2 additions & 1 deletion driver/driver_default_test.go
Expand Up @@ -4,9 +4,10 @@ import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"

driver "github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/configuration"
"github.com/stretchr/testify/assert"
)

func TestDriverDefault_SQLiteMemoryMode(t *testing.T) {
Expand Down
107 changes: 79 additions & 28 deletions identity/manager_test.go
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/ory/kratos/driver/configuration"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/x"
)

func TestManager(t *testing.T) {
Expand All @@ -28,12 +29,20 @@ func TestManager(t *testing.T) {
require.Error(t, reg.IdentityManager().Create(context.Background(), i))
})

newTraits := func(email string, unprotected string) identity.Traits {
return identity.Traits(fmt.Sprintf(`{"email":"%[1]s","email_verify":"%[1]s","email_recovery":"%[1]s","email_creds":"%[1]s","unprotected": "%[2]s"}`, email, unprotected))
}

checkExtensionFields := func(i *identity.Identity, expected string) func(*testing.T) {
return func(t *testing.T) {
require.Len(t, i.VerifiableAddresses, 1)
assert.EqualValues(t, expected, i.VerifiableAddresses[0].Value)
assert.EqualValues(t, identity.VerifiableAddressTypeEmail, i.VerifiableAddresses[0].Via)

require.Len(t, i.RecoveryAddresses, 1)
assert.EqualValues(t, expected, i.RecoveryAddresses[0].Value)
assert.EqualValues(t, identity.VerifiableAddressTypeEmail, i.RecoveryAddresses[0].Via)

require.NotNil(t, i.Credentials[identity.CredentialsTypePassword])
assert.Equal(t, []string{expected}, i.Credentials[identity.CredentialsTypePassword].Identifiers)
}
Expand All @@ -50,7 +59,7 @@ func TestManager(t *testing.T) {
t.Run("method=Create", func(t *testing.T) {
t.Run("case=should create identity and track extension fields", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"foo@ory.sh"}`)
original.Traits = newTraits("foo@ory.sh", "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
checkExtensionFieldsForIdentities(t, "foo@ory.sh", original)
})
Expand All @@ -75,39 +84,21 @@ func TestManager(t *testing.T) {
t.Run("method=Update", func(t *testing.T) {
t.Run("case=should update identity and update extension fields", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"baz@ory.sh"}`)
original.Traits = newTraits("baz@ory.sh", "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

original.Traits = identity.Traits(`{"email":"bar@ory.sh"}`)
original.Traits = newTraits("bar@ory.sh", "")
require.NoError(t, reg.IdentityManager().Update(context.Background(), original, identity.ManagerAllowWriteProtectedTraits))

checkExtensionFieldsForIdentities(t, "bar@ory.sh", original)
})

t.Run("case=should update identity and update extension fields", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "foo"}`)
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

// These should all fail because they modify existing keys
require.Error(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"not-baz@ory.sh","email_verify":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "foo"}`)))
require.Error(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"baz@ory.sh","email_verify":"not-baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "foo"}`)))
require.Error(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_creds":"not-baz@ory.sh","unprotected": "foo"}`)))

require.NoError(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "bar"}`)))
checkExtensionFieldsForIdentities(t, "baz@ory.sh", original)

actual, err := reg.IdentityPool().GetIdentity(context.Background(), original.ID)
require.NoError(t, err)
assert.JSONEq(t, `{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "bar"}`, string(actual.Traits))
})

t.Run("case=should not update protected traits without option", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"email-update-1@ory.sh"}`)
original.Traits = newTraits("email-update-1@ory.sh", "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

original.Traits = identity.Traits(`{"email":"email-update-2@ory.sh"}`)
original.Traits = newTraits("email-update-2@ory.sh", "")
err := reg.IdentityManager().Update(context.Background(), original)
require.Error(t, err)
assert.Equal(t, identity.ErrProtectedFieldModified, errors.Cause(err))
Expand All @@ -118,16 +109,58 @@ func TestManager(t *testing.T) {
// That is why we only check the identity in the store.
checkExtensionFields(fromStore, "email-update-1@ory.sh")(t)
})

t.Run("case=changing recovery address removes it from the store", func(t *testing.T) {
originalEmail := x.NewUUID().String() + "@ory.sh"
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(originalEmail, "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

fromStore, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), original.ID)
require.NoError(t, err)
checkExtensionFields(fromStore, originalEmail)(t)

newEmail := x.NewUUID().String() + "@ory.sh"
original.Traits = newTraits(newEmail, "")
require.NoError(t, reg.IdentityManager().Update(context.Background(), original, identity.ManagerAllowWriteProtectedTraits))

fromStore, err = reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), original.ID)
require.NoError(t, err)
checkExtensionFields(fromStore, newEmail)(t)

recoveryAddresses, err := reg.PrivilegedIdentityPool().ListRecoveryAddresses(context.Background(), 0, 500)
require.NoError(t, err)

var foundRecoveryAddress bool
for _, a := range recoveryAddresses {
assert.NotEqual(t, a.Value, originalEmail)
if a.Value == newEmail {
foundRecoveryAddress = true
}
}
require.True(t, foundRecoveryAddress)

verifiableAddresses, err := reg.PrivilegedIdentityPool().ListVerifiableAddresses(context.Background(), 0, 500)
require.NoError(t, err)
var foundVerifiableAddress bool
for _, a := range verifiableAddresses {
assert.NotEqual(t, a.Value, originalEmail)
if a.Value == newEmail {
foundVerifiableAddress = true
}
}
require.True(t, foundVerifiableAddress)
})
})

t.Run("method=UpdateTraits", func(t *testing.T) {
t.Run("case=should update protected traits with option", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"email-updatetraits-1@ory.sh"}`)
original.Traits = newTraits("email-updatetraits-1@ory.sh", "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

require.NoError(t, reg.IdentityManager().UpdateTraits(
context.Background(), original.ID, identity.Traits(`{"email":"email-updatetraits-2@ory.sh"}`),
context.Background(), original.ID, newTraits("email-updatetraits-2@ory.sh", ""),
identity.ManagerAllowWriteProtectedTraits))

fromStore, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), original.ID)
Expand All @@ -137,13 +170,31 @@ func TestManager(t *testing.T) {
checkExtensionFields(fromStore, "email-updatetraits-2@ory.sh")(t)
})

t.Run("case=should update identity and update extension fields", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_recovery":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "foo"}`)
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

// These should all fail because they modify existing keys
require.Error(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"not-baz@ory.sh","email_verify":"baz@ory.sh","email_recovery":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "foo"}`)))
require.Error(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"baz@ory.sh","email_verify":"not-baz@ory.sh","email_recovery":"not-baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "foo"}`)))
require.Error(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_recovery":"baz@ory.sh","email_creds":"not-baz@ory.sh","unprotected": "foo"}`)))

require.NoError(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_recovery":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "bar"}`)))
checkExtensionFieldsForIdentities(t, "baz@ory.sh", original)

actual, err := reg.IdentityPool().GetIdentity(context.Background(), original.ID)
require.NoError(t, err)
assert.JSONEq(t, `{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_recovery":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "bar"}`, string(actual.Traits))
})

t.Run("case=should not update protected traits without option", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"email-updatetraits-1@ory.sh"}`)
original.Traits = newTraits("email-updatetraits-1@ory.sh", "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

err := reg.IdentityManager().UpdateTraits(
context.Background(), original.ID, identity.Traits(`{"email":"email-updatetraits-2@ory.sh"}`))
context.Background(), original.ID, newTraits("email-updatetraits-2@ory.sh", ""))
require.Error(t, err)
assert.Equal(t, identity.ErrProtectedFieldModified, errors.Cause(err))

Expand All @@ -157,7 +208,7 @@ func TestManager(t *testing.T) {

t.Run("method=RefreshVerifyAddress", func(t *testing.T) {
original := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
original.Traits = identity.Traits(`{"email":"verifyme@ory.sh"}`)
original.Traits = newTraits("verifyme@ory.sh", "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

address, err := reg.IdentityPool().FindVerifiableAddressByValue(context.Background(), identity.VerifiableAddressTypeEmail, "verifyme@ory.sh")
Expand Down
9 changes: 8 additions & 1 deletion identity/pool.go
Expand Up @@ -74,8 +74,15 @@ type (
// UpdateIdentity updates an identity including its confidential / privileged / protected data.
UpdateIdentity(context.Context, *Identity) error

// GetClassified returns the identity including it's raw credentials. This should only be used internally.
// GetIdentityConfidential returns the identity including it's raw credentials. This should only be used internally.
GetIdentityConfidential(context.Context, uuid.UUID) (*Identity, error)

// ListVerifiableAddresses lists all tracked verifiable addresses, regardless of whether they are already verified
// or not.
ListVerifiableAddresses(ctx context.Context, page, itemsPerPage int) ([]VerifiableAddress, error)

// ListRecoveryAddresses lists all tracked recovery addresses.
ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) ([]RecoveryAddress, error)
}
)

Expand Down
12 changes: 9 additions & 3 deletions identity/stub/manager.schema.json
Expand Up @@ -15,9 +15,6 @@
"password": {
"identifier": true
}
},
"verification": {
"via": "email"
}
}
},
Expand All @@ -41,6 +38,15 @@
}
}
},
"email_recovery": {
"type": "string",
"format": "email",
"ory.sh/kratos": {
"recovery": {
"via": "email"
}
}
},
"unprotected": {
"type": "string"
}
Expand Down
17 changes: 17 additions & 0 deletions persistence/sql/persister_identity.go
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/ory/kratos/driver/configuration"
"github.com/ory/kratos/otp"
"github.com/ory/kratos/x"

"github.com/gobuffalo/pop/v5"
"github.com/gofrs/uuid"
Expand All @@ -27,6 +28,22 @@ import (
var _ identity.Pool = new(Persister)
var _ identity.PrivilegedPool = new(Persister)

func (p *Persister) ListVerifiableAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.VerifiableAddress, err error) {
if err := p.GetConnection(ctx).Order("id desc").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
return nil, sqlcon.HandleError(err)
}

return a, err
}

func (p *Persister) ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.RecoveryAddress, err error) {
if err := p.GetConnection(ctx).Order("id desc").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
return nil, sqlcon.HandleError(err)
}

return a, err
}

func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity.CredentialsType, match string) (*identity.Identity, *identity.Credentials, error) {
var cts []identity.CredentialsTypeTable
if err := p.GetConnection(ctx).All(&cts); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_microsoft.go
Expand Up @@ -9,9 +9,10 @@ import (
"github.com/gofrs/uuid"

gooidc "github.com/coreos/go-oidc"
"github.com/ory/herodot"
"github.com/pkg/errors"
"golang.org/x/oauth2"

"github.com/ory/herodot"
)

type ProviderMicrosoft struct {
Expand Down
9 changes: 9 additions & 0 deletions x/maxitems.go
@@ -0,0 +1,9 @@
package x

// MaxItemsPerPage is used to prevent DoS attacks against large lists by limiting the items per page to 500.
func MaxItemsPerPage(is int) int {
if is > 500 {
return 500
}
return is
}

0 comments on commit 29979e6

Please sign in to comment.