Skip to content

Commit

Permalink
fix: improve identity post-processing
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Jan 4, 2023
1 parent 4c098fe commit 5ae2be3
Show file tree
Hide file tree
Showing 15 changed files with 98 additions and 83 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.4.0
github.com/rs/cors v1.8.2
github.com/samber/lo v1.37.0
github.com/sirupsen/logrus v1.9.0
github.com/slack-go/slack v0.7.4
github.com/spf13/cobra v1.6.1
Expand Down Expand Up @@ -315,6 +316,7 @@ require (
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.17.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/mod v0.7.0 // indirect
golang.org/x/sys v0.3.0 // indirect
golang.org/x/term v0.3.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,8 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw=
github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA=
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/sassoftware/go-rpmutils v0.0.0-20190420191620-a8f1baeba37b/go.mod h1:am+Fp8Bt506lA3Rk3QCmSqmYmLMnPDhdDUcosQCAx+I=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
Expand Down Expand Up @@ -1598,6 +1600,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
Expand Down
3 changes: 3 additions & 0 deletions identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ func (c *Credentials) setCredentials() error {
c.Type = c.IdentityCredentialType.Name
c.Identifiers = make([]string, 0, len(c.CredentialIdentifiers))
for _, id := range c.CredentialIdentifiers {
if c.NID != id.NID {
continue
}
c.Identifiers = append(c.Identifiers, id.Identifier)
}
return nil
Expand Down
11 changes: 5 additions & 6 deletions credentialmigrate/migrate.go → identity/credentials_migrate.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package credentialmigrate
package identity

import (
"encoding/json"
"fmt"

"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/pkg/errors"

"github.com/ory/kratos/identity"
)

func UpgradeWebAuthnCredentials(i *identity.Identity, c *identity.Credentials) (err error) {
if c.Type != identity.CredentialsTypeWebAuthn {
func UpgradeWebAuthnCredentials(i *Identity, c *Credentials) (err error) {
if c.Type != CredentialsTypeWebAuthn {
return nil
}

Expand Down Expand Up @@ -59,7 +58,7 @@ func UpgradeWebAuthnCredentials(i *identity.Identity, c *identity.Credentials) (
}

// UpgradeCredentials migrates a set of older WebAuthn credentials to newer ones.
func UpgradeCredentials(i *identity.Identity) error {
func UpgradeCredentials(i *Identity) error {
for k := range i.Credentials {
c := i.Credentials[k]
if err := UpgradeWebAuthnCredentials(i, &c); err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package credentialmigrate
package identity

import (
_ "embed"
Expand All @@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/kratos/identity"
"github.com/ory/x/snapshotx"
)

Expand All @@ -23,51 +22,51 @@ var webAuthnV1 []byte

func TestUpgradeCredentials(t *testing.T) {
t.Run("empty credentials", func(t *testing.T) {
i := &identity.Identity{}
i := &Identity{}

err := UpgradeCredentials(i)
require.NoError(t, err)
wc := identity.WithCredentialsAndAdminMetadataInJSON(*i)
wc := WithCredentialsAndAdminMetadataInJSON(*i)
snapshotx.SnapshotTExcept(t, &wc, nil)
})

identityID := uuid.FromStringOrNil("4d64fa08-20fc-450d-bebd-ebd7c7b6e249")
t.Run("type=webauthn", func(t *testing.T) {
t.Run("from=v0", func(t *testing.T) {
i := &identity.Identity{
i := &Identity{
ID: identityID,
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypeWebAuthn: {
Credentials: map[CredentialsType]Credentials{
CredentialsTypeWebAuthn: {
Identifiers: []string{"4d64fa08-20fc-450d-bebd-ebd7c7b6e249"},
Type: identity.CredentialsTypeWebAuthn,
Type: CredentialsTypeWebAuthn,
Version: 0,
Config: webAuthnV0,
}},
}

require.NoError(t, UpgradeCredentials(i))
wc := identity.WithCredentialsAndAdminMetadataInJSON(*i)
wc := WithCredentialsAndAdminMetadataInJSON(*i)
snapshotx.SnapshotTExcept(t, &wc, nil)

assert.Equal(t, 1, i.Credentials[identity.CredentialsTypeWebAuthn].Version)
assert.Equal(t, 1, i.Credentials[CredentialsTypeWebAuthn].Version)
})

t.Run("from=v1", func(t *testing.T) {
i := &identity.Identity{
i := &Identity{
ID: identityID,
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypeWebAuthn: {
Type: identity.CredentialsTypeWebAuthn,
Credentials: map[CredentialsType]Credentials{
CredentialsTypeWebAuthn: {
Type: CredentialsTypeWebAuthn,
Version: 1,
Config: webAuthnV1,
}},
}

require.NoError(t, UpgradeCredentials(i))
wc := identity.WithCredentialsAndAdminMetadataInJSON(*i)
wc := WithCredentialsAndAdminMetadataInJSON(*i)
snapshotx.SnapshotTExcept(t, &wc, nil)

assert.Equal(t, 1, i.Credentials[identity.CredentialsTypeWebAuthn].Version)
assert.Equal(t, 1, i.Credentials[CredentialsTypeWebAuthn].Version)
})
})
}
37 changes: 22 additions & 15 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"sync"
"time"

"github.com/samber/lo"

"github.com/gobuffalo/pop/v6"

"github.com/tidwall/sjson"
Expand Down Expand Up @@ -137,14 +139,21 @@ func (i *Identity) AfterEagerFind(tx *pop.Connection) error {
return err
}

return i.ValidateNID()
if err := i.validate(); err != nil {
return err
}

return UpgradeCredentials(i)
}

func (i *Identity) setCredentials(tx *pop.Connection) error {
creds := i.InternalCredentials
i.Credentials = make(map[CredentialsType]Credentials, len(creds))
for k := range creds {
cred := &creds[k]
if cred.NID != i.NID {
continue
}
if err := cred.AfterEagerFind(tx); err != nil {
return err

Expand Down Expand Up @@ -368,27 +377,25 @@ func (i WithCredentialsMetadataAndAdminMetadataInJSON) MarshalJSON() ([]byte, er
return json.Marshal(localIdentity(i))
}

func (i *Identity) ValidateNID() error {
func (i *Identity) validate() error {
expected := i.NID
if expected == uuid.Nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReason("Received empty nid."))
}

for _, r := range i.RecoveryAddresses {
if r.NID != expected {
return errors.WithStack(herodot.ErrInternalServerError.WithReason("Mismatching nid for recovery addresses."))
}
}
i.RecoveryAddresses = lo.Filter(i.RecoveryAddresses, func(v RecoveryAddress, key int) bool {
return v.NID == expected
})

for _, r := range i.VerifiableAddresses {
if r.NID != expected {
return errors.WithStack(herodot.ErrInternalServerError.WithReason("Mismatching nid for verifiable addresses."))
}
}
i.VerifiableAddresses = lo.Filter(i.VerifiableAddresses, func(v VerifiableAddress, key int) bool {
return v.NID == expected
})

for _, r := range i.Credentials {
if r.NID != expected {
return errors.WithStack(herodot.ErrInternalServerError.WithReason("Mismatching nid for credentials."))
for k := range i.Credentials {
c := i.Credentials[k]
if c.NID != expected {
delete(i.Credentials, k)
continue
}
}

Expand Down
47 changes: 38 additions & 9 deletions identity/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,24 +229,53 @@ func TestValidateNID(t *testing.T) {
nid := x.NewUUID()
for k, tc := range []struct {
i *Identity
expect *Identity
expectedErr bool
}{
{i: &Identity{}, expectedErr: true},
{i: &Identity{NID: nid}},
{i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}}, expectedErr: true},
{i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}}, expectedErr: true},
{i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}}, expectedErr: true},
{i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}}, expectedErr: true},
{i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, expectedErr: true},
{i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, expectedErr: false},
{i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, expectedErr: true},
{i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}},
{
i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}},
expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{}},
},
{
i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}},
expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{}},
},
{
i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}},
expect: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{}, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{}},
},
{
i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}},
expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}},
},
{
i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}},
expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{}},
},
{
i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}},
expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}},
},
{
i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}},
expect: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}},
},
{
i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}},
expect: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}},
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
err := tc.i.ValidateNID()
err := tc.i.validate()
if tc.expectedErr {
require.Error(t, err)
} else {
require.NoError(t, err)
if tc.expect != nil {
assert.EqualValues(t, tc.expect, tc.i)
}
}
})
}
Expand Down
File renamed without changes.
File renamed without changes.
31 changes: 2 additions & 29 deletions persistence/sql/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"strings"
"time"

"github.com/ory/kratos/credentialmigrate"

"go.opentelemetry.io/otel/attribute"

"github.com/ory/x/otelx"
Expand Down Expand Up @@ -350,11 +348,11 @@ func (p *Persister) HydrateIdentityAssociations(ctx context.Context, i *identity
return err
}

if err := p.injectTraitsSchemaURL(ctx, i); err != nil {
if err := i.AfterEagerFind(con); err != nil {
return err
}

return p.afterFindIdentity(ctx, con, i, expand)
return p.injectTraitsSchemaURL(ctx, i)
}

func (p *Persister) ListIdentities(ctx context.Context, expand identity.Expandables, page, perPage int) (res []identity.Identity, err error) {
Expand Down Expand Up @@ -387,13 +385,6 @@ func (p *Persister) ListIdentities(ctx context.Context, expand identity.Expandab
schemaCache := map[string]string{}
for k := range is {
i := &is[k]
if err := p.afterFindIdentity(ctx, con, i, expand); err != nil {
return nil, err
}

if err := i.ValidateNID(); err != nil {
return nil, sqlcon.HandleError(err)
}

if u, ok := schemaCache[i.SchemaID]; ok {
i.SchemaURL = u
Expand Down Expand Up @@ -480,10 +471,6 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identi
return nil, sqlcon.HandleError(err)
}

if err := p.afterFindIdentity(ctx, con, &i, expand); err != nil {
return nil, err
}

if err := p.injectTraitsSchemaURL(ctx, &i); err != nil {
return nil, err
}
Expand Down Expand Up @@ -577,20 +564,6 @@ func (p *Persister) validateIdentity(ctx context.Context, i *identity.Identity)
return nil
}

func (p *Persister) afterFindIdentity(ctx context.Context, con *pop.Connection, i *identity.Identity, expand identity.Expandables) error {
if err := i.AfterEagerFind(con); err != nil {
return err
}

if expand.Has(identity.ExpandFieldCredentials) {
if err := credentialmigrate.UpgradeCredentials(i); err != nil {
return err
}
}

return nil
}

func (p *Persister) injectTraitsSchemaURL(ctx context.Context, i *identity.Identity) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.injectTraitsSchemaURL")
defer span.End()
Expand Down
9 changes: 5 additions & 4 deletions persistence/sql/persister_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"testing"
"time"

"github.com/ory/kratos/credentialmigrate"
"github.com/ory/kratos/driver"
"github.com/ory/kratos/identity"

Expand Down Expand Up @@ -63,7 +62,7 @@ func (suite *PersisterTestSuite) TestIdentityExpand() {
expected := identity.NewIdentity(expandSchema.ID)
expected.Traits = identity.Traits(`{"email":"` + uuid.Must(uuid.NewV4()).String() + "@ory.sh" + `","name":"john doe"}`)
require.NoError(t, reg.IdentityManager().Create(ctx, expected))
require.NoError(t, credentialmigrate.UpgradeCredentials(expected))
require.NoError(t, identity.UpgradeCredentials(expected))

assert.NotEmpty(t, expected.RecoveryAddresses)
assert.NotEmpty(t, expected.VerifiableAddresses)
Expand Down Expand Up @@ -1023,8 +1022,10 @@ func (suite *PersisterTestSuite) TestIdentity() {
_, _, err = p.FindByCredentialsIdentifier(ctx, m[0].Name, "nid2")
require.ErrorIs(t, err, sqlcon.ErrNoRows)

_, err = p.GetIdentityConfidential(ctx, iid)
require.NoError(t, err, "expect an error because the nids are mixed up")
i, err = p.GetIdentityConfidential(ctx, iid)
require.NoError(t, err)
require.Len(t, i.Credentials, 1)
assert.Equal(t, "nid1", i.Credentials[m[0].Name].Identifiers[0])
})
})
}

0 comments on commit 5ae2be3

Please sign in to comment.