diff --git a/cmd/identities/delete_test.go b/cmd/identities/delete_test.go index 995a3030814..70b7282b365 100644 --- a/cmd/identities/delete_test.go +++ b/cmd/identities/delete_test.go @@ -37,7 +37,7 @@ func TestDeleteCmd(t *testing.T) { assert.Equal(t, i.ID.String(), gjson.Parse(stdOut).String()) // expect identity to be deleted - _, err := reg.Persister().GetIdentity(context.Background(), i.ID) + _, err := reg.Persister().GetIdentity(context.Background(), i.ID, identity.ExpandNothing) assert.True(t, errors.Is(err, sqlcon.ErrNoRows)) }) @@ -49,7 +49,7 @@ func TestDeleteCmd(t *testing.T) { assert.Equal(t, `["`+strings.Join(ids, "\",\"")+"\"]\n", stdOut) for _, i := range is { - _, err := reg.Persister().GetIdentity(context.Background(), i.ID) + _, err := reg.Persister().GetIdentity(context.Background(), i.ID, identity.ExpandNothing) assert.Error(t, err) } }) diff --git a/cmd/identities/import_test.go b/cmd/identities/import_test.go index d6768e59040..567cac25695 100644 --- a/cmd/identities/import_test.go +++ b/cmd/identities/import_test.go @@ -10,6 +10,8 @@ import ( "os" "testing" + "github.com/ory/kratos/identity" + "github.com/ory/kratos/cmd/identities" "github.com/gofrs/uuid" @@ -42,7 +44,7 @@ func TestImportCmd(t *testing.T) { id, err := uuid.FromString(gjson.Get(stdOut, "id").String()) require.NoError(t, err) - _, err = reg.Persister().GetIdentity(context.Background(), id) + _, err = reg.Persister().GetIdentity(context.Background(), id, identity.ExpandNothing) assert.NoError(t, err) }) @@ -69,12 +71,12 @@ func TestImportCmd(t *testing.T) { id, err := uuid.FromString(gjson.Get(stdOut, "0.id").String()) require.NoError(t, err) - _, err = reg.Persister().GetIdentity(context.Background(), id) + _, err = reg.Persister().GetIdentity(context.Background(), id, identity.ExpandNothing) assert.NoError(t, err) id, err = uuid.FromString(gjson.Get(stdOut, "1.id").String()) require.NoError(t, err) - _, err = reg.Persister().GetIdentity(context.Background(), id) + _, err = reg.Persister().GetIdentity(context.Background(), id, identity.ExpandNothing) assert.NoError(t, err) }) @@ -97,12 +99,12 @@ func TestImportCmd(t *testing.T) { id, err := uuid.FromString(gjson.Get(stdOut, "0.id").String()) require.NoError(t, err) - _, err = reg.Persister().GetIdentity(context.Background(), id) + _, err = reg.Persister().GetIdentity(context.Background(), id, identity.ExpandNothing) assert.NoError(t, err) id, err = uuid.FromString(gjson.Get(stdOut, "1.id").String()) require.NoError(t, err) - _, err = reg.Persister().GetIdentity(context.Background(), id) + _, err = reg.Persister().GetIdentity(context.Background(), id, identity.ExpandNothing) assert.NoError(t, err) }) @@ -119,7 +121,7 @@ func TestImportCmd(t *testing.T) { id, err := uuid.FromString(gjson.Get(stdOut, "id").String()) require.NoError(t, err) - _, err = reg.Persister().GetIdentity(context.Background(), id) + _, err = reg.Persister().GetIdentity(context.Background(), id, identity.ExpandNothing) assert.NoError(t, err) }) } diff --git a/credentialmigrate/migrate.go b/credentialmigrate/migrate.go deleted file mode 100644 index ef0035d9c28..00000000000 --- a/credentialmigrate/migrate.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package credentialmigrate - -import ( - "encoding/json" - - "github.com/pkg/errors" - - "github.com/ory/kratos/identity" - "github.com/ory/kratos/selfservice/strategy/webauthn" -) - -// UpgradeWebAuthnCredential migrates a webauthn credential from an older version to a newer version. -func UpgradeWebAuthnCredential(i *identity.Identity, ic *identity.Credentials, c *webauthn.CredentialsConfig) { - if ic.Version == 0 { - if len(c.UserHandle) == 0 { - c.UserHandle = i.ID[:] - } - - // We do not set c.IsPasswordless as it defaults to false anyways, which is the correct migration . - - ic.Version = 1 - } -} - -func UpgradeWebAuthnCredentials(i *identity.Identity, c *identity.Credentials) error { - if c.Type != identity.CredentialsTypeWebAuthn { - return nil - } - - var cred webauthn.CredentialsConfig - if err := json.Unmarshal(c.Config, &cred); err != nil { - return errors.WithStack(err) - } - - UpgradeWebAuthnCredential(i, c, &cred) - - updatedConf, err := json.Marshal(&cred) - if err != nil { - return errors.WithStack(err) - } - - c.Config = updatedConf - return nil -} - -// UpgradeCredentials migrates a set of older WebAuthn credentials to newer ones. -func UpgradeCredentials(i *identity.Identity) error { - for k := range i.Credentials { - c := i.Credentials[k] - if err := UpgradeWebAuthnCredentials(i, &c); err != nil { - return errors.WithStack(err) - } - i.Credentials[k] = c - } - return nil -} diff --git a/driver/registry_default.go b/driver/registry_default.go index 110e00d4f77..f6f1c3e1aa2 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -723,6 +723,7 @@ func (m *RegistryDefault) VerificationTokenPersister() link.VerificationTokenPer func (m *RegistryDefault) VerificationCodePersister() code.VerificationCodePersister { return m.Persister() } + func (m *RegistryDefault) Persister() persistence.Persister { return m.persister } diff --git a/go.mod b/go.mod index 957e6ce45d8..cd0b1a33109 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 replace ( github.com/bradleyjkemp/cupaloy/v2 => github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3 + github.com/gorilla/sessions => github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2 github.com/knadh/koanf => github.com/aeneasr/koanf v0.14.1-0.20211230115640-aa3902b3267a @@ -41,8 +42,8 @@ require ( github.com/go-swagger/go-swagger v0.30.3 github.com/gobuffalo/fizz v1.14.4 github.com/gobuffalo/httptest v1.5.2 - github.com/gobuffalo/pop/v6 v6.0.8 - github.com/gofrs/uuid v4.3.0+incompatible + github.com/gobuffalo/pop/v6 v6.1.1-0.20230102153939-35967190380a + github.com/gofrs/uuid v4.3.1+incompatible github.com/golang-jwt/jwt/v4 v4.1.0 github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2 github.com/golang/mock v1.6.0 @@ -77,11 +78,12 @@ require ( github.com/ory/jsonschema/v3 v3.0.7 github.com/ory/mail/v3 v3.0.0 github.com/ory/nosurf v1.2.7 - github.com/ory/x v0.0.523 + github.com/ory/x v0.0.527 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.4.0 github.com/rs/cors v1.8.2 + github.com/samber/lo v1.37.0 github.com/sirupsen/logrus v1.9.0 github.com/slack-go/slack v0.7.4 github.com/spf13/cobra v1.6.1 @@ -313,6 +315,7 @@ require ( go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.7.0 // indirect go.uber.org/zap v1.17.0 // indirect + golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/mod v0.7.0 // indirect golang.org/x/sys v0.3.0 // indirect golang.org/x/term v0.3.0 // indirect diff --git a/go.sum b/go.sum index 6d2cbbf528e..83d74c2ff8a 100644 --- a/go.sum +++ b/go.sum @@ -517,8 +517,8 @@ github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGt github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/plush/v4 v4.1.16 h1:Y6jVVTLdg1BxRXDIbTJz+J8QRzEAtv5ZwYpGdIFR7VU= github.com/gobuffalo/plush/v4 v4.1.16/go.mod h1:6t7swVsarJ8qSLw1qyAH/KbrcSTwdun2ASEQkOznakg= -github.com/gobuffalo/pop/v6 v6.0.8 h1:9+5ShHYh3x9NDFCITfm/gtKDDRSgOwiY7kA0Hf7N9aQ= -github.com/gobuffalo/pop/v6 v6.0.8/go.mod h1:f4JQ4Zvkffcevz+t+XAwBLStD7IQs19DiIGIDFYw1eA= +github.com/gobuffalo/pop/v6 v6.1.1-0.20230102153939-35967190380a h1:O6PLVjaR9oDlaU+WOwfwNLvMFik0zdmvAyExIbJhGcI= +github.com/gobuffalo/pop/v6 v6.1.1-0.20230102153939-35967190380a/go.mod h1:Y3nCI31Zx40ffCnpQsYCMOWvR6f17K+IukEg0EHNxaQ= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= github.com/gobuffalo/tags/v3 v3.1.4 h1:X/ydLLPhgXV4h04Hp2xlbI2oc5MDaa7eub6zw8oHjsM= github.com/gobuffalo/tags/v3 v3.1.4/go.mod h1:ArRNo3ErlHO8BtdA0REaZxijuWnWzF6PUXngmMXd2I0= @@ -533,8 +533,8 @@ github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/gofrs/uuid v4.3.0+incompatible h1:CaSVZxm5B+7o45rtab4jC2G37WGYX1zQfuU2i6DSvnc= -github.com/gofrs/uuid v4.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gofrs/uuid v4.3.1+incompatible h1:0/KbAdpx3UXAx1kEOWHJeOkpbgRFGHVgv+CFIY7dBJI= +github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -806,7 +806,6 @@ github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf h1:FtEj8sfIcaaBfAKrE1Cwb61YDtYq9JxChK1c7AKce7s= github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf/go.mod h1:yrqSXGoD/4EKfF26AOGzscPOgTTJcyAwM2rpixWT+t4= -github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65 h1:T25FL3WEzgmKB0m6XCJNZ65nw09/QIp3T1yXr487D+A= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -1141,8 +1140,8 @@ github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2 h1:zm6sDvHy/U9XrGpi github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/ory/viper v1.7.5 h1:+xVdq7SU3e1vNaCsk/ixsfxE4zylk1TJUiJrY647jUE= github.com/ory/viper v1.7.5/go.mod h1:ypOuyJmEUb3oENywQZRgeAMwqgOyDqwboO1tj3DjTaM= -github.com/ory/x v0.0.523 h1:vn8e+8tV3RqD8RlvoE6lLPUnjpjua1ExJDMFy3Z5TAQ= -github.com/ory/x v0.0.523/go.mod h1:ayJio5x/fK4RwTgfgzs3JetOaaOSxso9hQjc3mFY8z0= +github.com/ory/x v0.0.527 h1:K6MmsYqT1NMb8VQ4hhn9q6NnrnecwNQJXc1bEoixQ8Y= +github.com/ory/x v0.0.527/go.mod h1:XBqhPZRppPHTxtsE0l0oI/B2Onf1QJtMRGPh3CpEpA0= github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= @@ -1269,6 +1268,8 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw= +github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/sassoftware/go-rpmutils v0.0.0-20190420191620-a8f1baeba37b/go.mod h1:am+Fp8Bt506lA3Rk3QCmSqmYmLMnPDhdDUcosQCAx+I= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= @@ -1338,7 +1339,6 @@ github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tL github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE= github.com/spf13/cobra v1.1.1/go.mod h1:WnodtKOvamDL/PwE2M4iKs8aMDBZ5Q5klgD3qfVJQMI= github.com/spf13/cobra v1.1.3/go.mod h1:pGADOWyqRD/YMrPZigI/zbliZ2wVD/23d+is3pSWzOo= -github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM= github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= @@ -1600,6 +1600,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/identity/.snapshots/TestHandler-case=should_be_able_to_import_users-with_cleartext_password_and_oidc_credentials.json b/identity/.snapshots/TestHandler-case=should_be_able_to_import_users-with_cleartext_password_and_oidc_credentials.json index 156a9b518d4..9c2afdacbd1 100644 --- a/identity/.snapshots/TestHandler-case=should_be_able_to_import_users-with_cleartext_password_and_oidc_credentials.json +++ b/identity/.snapshots/TestHandler-case=should_be_able_to_import_users-with_cleartext_password_and_oidc_credentials.json @@ -2,10 +2,6 @@ "credentials": { "oidc": { "type": "oidc", - "identifiers": [ - "google:import-2", - "github:import-2" - ], "config": { "providers": [ { diff --git a/credentialmigrate/.snapshots/TestUpgradeCredentials-empty_credentials.json b/identity/.snapshots/TestUpgradeCredentials-empty_credentials.json similarity index 100% rename from credentialmigrate/.snapshots/TestUpgradeCredentials-empty_credentials.json rename to identity/.snapshots/TestUpgradeCredentials-empty_credentials.json diff --git a/credentialmigrate/.snapshots/TestUpgradeCredentials-type=webauthn-from=v0.json b/identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v0.json similarity index 66% rename from credentialmigrate/.snapshots/TestUpgradeCredentials-type=webauthn-from=v0.json rename to identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v0.json index 4c05bac5cb7..83c81edb6e0 100644 --- a/credentialmigrate/.snapshots/TestUpgradeCredentials-type=webauthn-from=v0.json +++ b/identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v0.json @@ -20,6 +20,19 @@ "display_name": "asdf", "added_at": "2022-02-28T16:40:39Z", "is_passwordless": false + }, + { + "id": "1Q4LaIJ9NiqS1r0CQpWY+K0gMvhOq4yk5BHuO/YlitcurSpBK7weDXOvBcuN4lvn6DAmjGfmj/J/6bpOmtdT8Q==", + "public_key": "pQECAyYgASFYILAYFLoH1T8bQMSbPrNBCMMS5U7OFWRwv2U+GkAoiBADIlggBv+8ni7XVZYBB8ufMbP/d9fDxbmOkVVHOgcJifnoOR4=", + "attestation_type": "none", + "authenticator": { + "aaguid": "AAAAAAAAAAAAAAAAAAAAAA==", + "sign_count": 4, + "clone_warning": false + }, + "display_name": "asdf", + "added_at": "2022-02-28T16:40:39Z", + "is_passwordless": false } ], "user_handle": "TWT6CCD8RQ2+vevXx7biSQ==" diff --git a/credentialmigrate/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json b/identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json similarity index 100% rename from credentialmigrate/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json rename to identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json diff --git a/identity/credentials.go b/identity/credentials.go index 98376142f3a..7b1c485579b 100644 --- a/identity/credentials.go +++ b/identity/credentials.go @@ -8,6 +8,8 @@ import ( "reflect" "time" + "github.com/gobuffalo/pop/v6" + "github.com/ory/kratos/ui/node" "github.com/gofrs/uuid" @@ -84,8 +86,6 @@ const ( type Credentials struct { ID uuid.UUID `json:"-" db:"id"` - CredentialTypeID uuid.UUID `json:"-" db:"identity_credential_type_id"` - // Type discriminates between different types of credentials. Type CredentialsType `json:"type" db:"-"` @@ -107,6 +107,30 @@ type Credentials struct { // UpdatedAt is a helper struct field for gobuffalo.pop. UpdatedAt time.Time `json:"updated_at" db:"updated_at"` NID uuid.UUID `json:"-" faker:"-" db:"nid"` + + IdentityCredentialTypeID uuid.UUID `json:"-" db:"identity_credential_type_id"` + IdentityCredentialType CredentialsTypeTable `json:"-" faker:"-" belongs_to:"identity_credential_types"` + CredentialIdentifiers CredentialIdentifierCollection `json:"-" faker:"-" has_many:"identity_credential_identifiers" fk_id:"identity_credential_id" order_by:"id asc"` +} + +func (c *Credentials) AfterEagerFind(tx *pop.Connection) error { + return c.setCredentials() +} + +func (c *Credentials) setCredentials() error { + c.Type = c.IdentityCredentialType.Name + c.Identifiers = make([]string, 0, len(c.CredentialIdentifiers)) + for _, id := range c.CredentialIdentifiers { + if c.NID != id.NID { + continue + } + c.Identifiers = append(c.Identifiers, id.Identifier) + } + return nil +} + +func (c Credentials) TableName(ctx context.Context) string { + return "identity_credentials" } type ( @@ -158,10 +182,6 @@ func (c CredentialsCollection) TableName(ctx context.Context) string { return "identity_credentials" } -func (c Credentials) TableName(ctx context.Context) string { - return "identity_credentials" -} - func (c CredentialIdentifierCollection) TableName(ctx context.Context) string { return "identity_credential_identifiers" } diff --git a/identity/credentials_migrate.go b/identity/credentials_migrate.go new file mode 100644 index 00000000000..4da2e14fd7b --- /dev/null +++ b/identity/credentials_migrate.go @@ -0,0 +1,66 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package identity + +import ( + "encoding/json" + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/pkg/errors" +) + +func UpgradeWebAuthnCredentials(i *Identity, c *Credentials) (err error) { + if c.Type != CredentialsTypeWebAuthn { + return nil + } + + version := c.Version + if version == 0 { + if gjson.GetBytes(c.Config, "user_handle").String() == "" { + id, err := json.Marshal(i.ID[:]) + if err != nil { + return errors.WithStack(err) + } + + c.Config, err = sjson.SetRawBytes(c.Config, "user_handle", id) + if err != nil { + return errors.WithStack(err) + } + } + + var index = -1 + var err error + gjson.GetBytes(c.Config, "credentials").ForEach(func(key, value gjson.Result) bool { + index++ + + if value.Get("is_passwordless").Exists() { + return true + } + + c.Config, err = sjson.SetBytes(c.Config, fmt.Sprintf("credentials.%d.is_passwordless", index), false) + return err == nil + }) + if err != nil { + return errors.WithStack(err) + } + + c.Version = 1 + } + return nil +} + +// UpgradeCredentials migrates a set of older WebAuthn credentials to newer ones. +func UpgradeCredentials(i *Identity) error { + for k := range i.Credentials { + c := i.Credentials[k] + if err := UpgradeWebAuthnCredentials(i, &c); err != nil { + return errors.WithStack(err) + } + i.Credentials[k] = c + } + return nil +} diff --git a/credentialmigrate/migrate_test.go b/identity/credentials_migrate_test.go similarity index 59% rename from credentialmigrate/migrate_test.go rename to identity/credentials_migrate_test.go index 61519a9adfb..2916bc892bd 100644 --- a/credentialmigrate/migrate_test.go +++ b/identity/credentials_migrate_test.go @@ -1,7 +1,7 @@ // Copyright © 2023 Ory Corp // SPDX-License-Identifier: Apache-2.0 -package credentialmigrate +package identity import ( _ "embed" @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/identity" "github.com/ory/x/snapshotx" ) @@ -23,51 +22,51 @@ var webAuthnV1 []byte func TestUpgradeCredentials(t *testing.T) { t.Run("empty credentials", func(t *testing.T) { - i := &identity.Identity{} + i := &Identity{} err := UpgradeCredentials(i) require.NoError(t, err) - wc := identity.WithCredentialsAndAdminMetadataInJSON(*i) + wc := WithCredentialsAndAdminMetadataInJSON(*i) snapshotx.SnapshotTExcept(t, &wc, nil) }) identityID := uuid.FromStringOrNil("4d64fa08-20fc-450d-bebd-ebd7c7b6e249") t.Run("type=webauthn", func(t *testing.T) { t.Run("from=v0", func(t *testing.T) { - i := &identity.Identity{ + i := &Identity{ ID: identityID, - Credentials: map[identity.CredentialsType]identity.Credentials{ - identity.CredentialsTypeWebAuthn: { + Credentials: map[CredentialsType]Credentials{ + CredentialsTypeWebAuthn: { Identifiers: []string{"4d64fa08-20fc-450d-bebd-ebd7c7b6e249"}, - Type: identity.CredentialsTypeWebAuthn, + Type: CredentialsTypeWebAuthn, Version: 0, Config: webAuthnV0, }}, } require.NoError(t, UpgradeCredentials(i)) - wc := identity.WithCredentialsAndAdminMetadataInJSON(*i) + wc := WithCredentialsAndAdminMetadataInJSON(*i) snapshotx.SnapshotTExcept(t, &wc, nil) - assert.Equal(t, 1, i.Credentials[identity.CredentialsTypeWebAuthn].Version) + assert.Equal(t, 1, i.Credentials[CredentialsTypeWebAuthn].Version) }) t.Run("from=v1", func(t *testing.T) { - i := &identity.Identity{ + i := &Identity{ ID: identityID, - Credentials: map[identity.CredentialsType]identity.Credentials{ - identity.CredentialsTypeWebAuthn: { - Type: identity.CredentialsTypeWebAuthn, + Credentials: map[CredentialsType]Credentials{ + CredentialsTypeWebAuthn: { + Type: CredentialsTypeWebAuthn, Version: 1, Config: webAuthnV1, }}, } require.NoError(t, UpgradeCredentials(i)) - wc := identity.WithCredentialsAndAdminMetadataInJSON(*i) + wc := WithCredentialsAndAdminMetadataInJSON(*i) snapshotx.SnapshotTExcept(t, &wc, nil) - assert.Equal(t, 1, i.Credentials[identity.CredentialsTypeWebAuthn].Version) + assert.Equal(t, 1, i.Credentials[CredentialsTypeWebAuthn].Version) }) }) } diff --git a/identity/expandables.go b/identity/expandables.go new file mode 100644 index 00000000000..a0abf73ff28 --- /dev/null +++ b/identity/expandables.go @@ -0,0 +1,45 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package identity + +import "github.com/ory/x/sqlxx" + +type Expandable = sqlxx.Expandable +type Expandables = sqlxx.Expandables + +const ( + ExpandFieldVerifiableAddresses Expandable = "VerifiableAddresses" + ExpandFieldRecoveryAddresses Expandable = "RecoveryAddresses" + ExpandFieldCredentials Expandable = "InternalCredentials" + ExpandFieldCredentialType Expandable = "InternalCredentials.IdentityCredentialType" + ExpandFieldCredentialIdentifiers Expandable = "InternalCredentials.CredentialIdentifiers" +) + +// ExpandNothing expands nothing +var ExpandNothing = Expandables{} + +// ExpandDefault expands the default fields: +// +// - Verifiable addresses +// - Recovery addresses +var ExpandDefault = Expandables{ + ExpandFieldVerifiableAddresses, + ExpandFieldRecoveryAddresses, +} + +// ExpandCredentials expands the identity's credentials. +var ExpandCredentials = Expandables{ + ExpandFieldCredentials, + ExpandFieldCredentialType, + ExpandFieldCredentialIdentifiers, +} + +// ExpandEverything expands all the fields of an identity. +var ExpandEverything = Expandables{ + ExpandFieldVerifiableAddresses, + ExpandFieldRecoveryAddresses, + ExpandFieldCredentials, + ExpandFieldCredentialType, + ExpandFieldCredentialIdentifiers, +} diff --git a/identity/handler.go b/identity/handler.go index 8fa53b480d2..5b964676e8e 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -136,7 +136,7 @@ type listIdentitiesParameters struct { // default: errorGeneric func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { page, itemsPerPage := x.ParsePagination(r) - is, err := h.r.IdentityPool().ListIdentities(r.Context(), page, itemsPerPage) + is, err := h.r.IdentityPool().ListIdentities(r.Context(), ExpandDefault, page, itemsPerPage) if err != nil { h.r.Writer().WriteError(w, r, err) return diff --git a/identity/handler_test.go b/identity/handler_test.go index ea0acc4a77a..a87d7c4cdb8 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -218,7 +218,11 @@ func TestHandler(t *testing.T) { actual, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, uuid.FromStringOrNil(res.Get("id").String())) require.NoError(t, err) - snapshotx.SnapshotTExceptMatchingKeys(t, identity.WithCredentialsAndAdminMetadataInJSON(*actual), append(ignoreDefault, "hashed_password")) + snapshotx.SnapshotT(t, identity.WithCredentialsAndAdminMetadataInJSON(*actual), snapshotx.ExceptNestedKeys(append(ignoreDefault, "hashed_password")...), snapshotx.ExceptPaths("credentials.oidc.identifiers")) + identifiers := actual.Credentials[identity.CredentialsTypeOIDC].Identifiers + assert.Len(t, identifiers, 2) + assert.Contains(t, identifiers, "google:import-2") + assert.Contains(t, identifiers, "github:import-2") require.NoError(t, hash.Compare(ctx, []byte("123456"), []byte(gjson.GetBytes(actual.Credentials[identity.CredentialsTypePassword].Config, "hashed_password").String()))) }) diff --git a/identity/identity.go b/identity/identity.go index d2dbae879fe..5ec6e8321ab 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -11,6 +11,10 @@ import ( "sync" "time" + "github.com/samber/lo" + + "github.com/gobuffalo/pop/v6" + "github.com/tidwall/sjson" "github.com/tidwall/gjson" @@ -119,6 +123,9 @@ type Identity struct { // Store metadata about the user which is only accessible through admin APIs such as `GET /admin/identities/`. MetadataAdmin sqlxx.NullJSONRawMessage `json:"metadata_admin,omitempty" faker:"-" db:"metadata_admin"` + // InternalCredentials is an internal representation of the credentials. + InternalCredentials CredentialsCollection `json:"-" faker:"-" has_many:"identity_credentials" fk_id:"identity_id" order_by:"id asc"` + // CreatedAt is a helper struct field for gobuffalo.pop. CreatedAt time.Time `json:"created_at" db:"created_at"` @@ -127,6 +134,36 @@ type Identity struct { NID uuid.UUID `json:"-" faker:"-" db:"nid"` } +func (i *Identity) AfterEagerFind(tx *pop.Connection) error { + if err := i.setCredentials(tx); err != nil { + return err + } + + if err := i.validate(); err != nil { + return err + } + + return UpgradeCredentials(i) +} + +func (i *Identity) setCredentials(tx *pop.Connection) error { + creds := i.InternalCredentials + i.Credentials = make(map[CredentialsType]Credentials, len(creds)) + for k := range creds { + cred := &creds[k] + if cred.NID != i.NID { + continue + } + if err := cred.AfterEagerFind(tx); err != nil { + return err + + } + i.Credentials[cred.Type] = *cred + } + + return nil +} + // Traits represent an identity's traits. The identity is able to create, modify, and delete traits // in a self-service manner. The input will always be validated against the JSON Schema defined // in `schema_url`. @@ -340,27 +377,25 @@ func (i WithCredentialsMetadataAndAdminMetadataInJSON) MarshalJSON() ([]byte, er return json.Marshal(localIdentity(i)) } -func (i *Identity) ValidateNID() error { +func (i *Identity) validate() error { expected := i.NID if expected == uuid.Nil { return errors.WithStack(herodot.ErrInternalServerError.WithReason("Received empty nid.")) } - for _, r := range i.RecoveryAddresses { - if r.NID != expected { - return errors.WithStack(herodot.ErrInternalServerError.WithReason("Mismatching nid for recovery addresses.")) - } - } + i.RecoveryAddresses = lo.Filter(i.RecoveryAddresses, func(v RecoveryAddress, key int) bool { + return v.NID == expected + }) - for _, r := range i.VerifiableAddresses { - if r.NID != expected { - return errors.WithStack(herodot.ErrInternalServerError.WithReason("Mismatching nid for verifiable addresses.")) - } - } + i.VerifiableAddresses = lo.Filter(i.VerifiableAddresses, func(v VerifiableAddress, key int) bool { + return v.NID == expected + }) - for _, r := range i.Credentials { - if r.NID != expected { - return errors.WithStack(herodot.ErrInternalServerError.WithReason("Mismatching nid for credentials.")) + for k := range i.Credentials { + c := i.Credentials[k] + if c.NID != expected { + delete(i.Credentials, k) + continue } } diff --git a/identity/identity_test.go b/identity/identity_test.go index 4bcef71720e..e14fee21436 100644 --- a/identity/identity_test.go +++ b/identity/identity_test.go @@ -229,24 +229,53 @@ func TestValidateNID(t *testing.T) { nid := x.NewUUID() for k, tc := range []struct { i *Identity + expect *Identity expectedErr bool }{ + {i: &Identity{}, expectedErr: true}, {i: &Identity{NID: nid}}, - {i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}}, expectedErr: true}, - {i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}}, expectedErr: true}, - {i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}}, expectedErr: true}, - {i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}}, expectedErr: true}, - {i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, expectedErr: true}, - {i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, expectedErr: false}, - {i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, expectedErr: true}, - {i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}}, + { + i: &Identity{NID: nid, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}}, + expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{}}, + }, + { + i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}}, + expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{}}, + }, + { + i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}}, + expect: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{}, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{}}, + }, + { + i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: x.NewUUID()}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}}, + expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}}, + }, + { + i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: x.NewUUID()}}}, + expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{}}, + }, + { + i: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}}, + expect: &Identity{NID: nid, VerifiableAddresses: []VerifiableAddress{{NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}}, + }, + { + i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: x.NewUUID()}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, + expect: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, + }, + { + i: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, + expect: &Identity{NID: nid, Credentials: map[CredentialsType]Credentials{CredentialsTypePassword: {NID: nid}}, RecoveryAddresses: []RecoveryAddress{{NID: nid}}, VerifiableAddresses: []VerifiableAddress{{NID: nid}}}, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - err := tc.i.ValidateNID() + err := tc.i.validate() if tc.expectedErr { require.Error(t, err) } else { require.NoError(t, err) + if tc.expect != nil { + assert.EqualValues(t, tc.expect, tc.i) + } } }) } diff --git a/identity/manager.go b/identity/manager.go index fd17c8b9074..039db4a666e 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -7,6 +7,9 @@ import ( "context" "reflect" + "github.com/ory/kratos/x" + "github.com/ory/x/otelx" + "github.com/ory/kratos/driver/config" "github.com/gofrs/uuid" @@ -28,6 +31,7 @@ type ( managerDependencies interface { config.Provider PoolProvider + x.TracingProvider courier.Provider ValidationProvider ActiveCredentialsCounterStrategyProvider @@ -39,48 +43,54 @@ type ( r managerDependencies } - managerOptions struct { + ManagerOptions struct { ExposeValidationErrors bool AllowWriteProtectedTraits bool } - ManagerOption func(*managerOptions) + ManagerOption func(*ManagerOptions) ) func NewManager(r managerDependencies) *Manager { return &Manager{r: r} } -func ManagerExposeValidationErrorsForInternalTypeAssertion(options *managerOptions) { +func ManagerExposeValidationErrorsForInternalTypeAssertion(options *ManagerOptions) { options.ExposeValidationErrors = true } -func ManagerAllowWriteProtectedTraits(options *managerOptions) { +func ManagerAllowWriteProtectedTraits(options *ManagerOptions) { options.AllowWriteProtectedTraits = true } -func newManagerOptions(opts []ManagerOption) *managerOptions { - var o managerOptions +func newManagerOptions(opts []ManagerOption) *ManagerOptions { + var o ManagerOptions for _, f := range opts { f(&o) } return &o } -func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption) error { +func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption) (err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.Create") + defer otelx.End(span, &err) + if i.SchemaID == "" { i.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx) } o := newManagerOptions(opts) - if err := m.validate(ctx, i, o); err != nil { + if err := m.ValidateIdentity(ctx, i, o); err != nil { return err } return m.r.IdentityPool().(PrivilegedPool).CreateIdentity(ctx, i) } -func (m *Manager) requiresPrivilegedAccess(_ context.Context, original, updated *Identity, o *managerOptions) error { +func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) { + _, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.requiresPrivilegedAccess") + defer otelx.End(span, &err) + if !o.AllowWriteProtectedTraits { if !CredentialsEqual(updated.Credentials, original.Credentials) { // reset the identity @@ -99,9 +109,12 @@ func (m *Manager) requiresPrivilegedAccess(_ context.Context, original, updated return nil } -func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...ManagerOption) error { +func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...ManagerOption) (err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.Update") + defer otelx.End(span, &err) + o := newManagerOptions(opts) - if err := m.validate(ctx, updated, o); err != nil { + if err := m.ValidateIdentity(ctx, updated, o); err != nil { return err } @@ -117,7 +130,10 @@ func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...Manager return m.r.IdentityPool().(PrivilegedPool).UpdateIdentity(ctx, updated) } -func (m *Manager) UpdateSchemaID(ctx context.Context, id uuid.UUID, schemaID string, opts ...ManagerOption) error { +func (m *Manager) UpdateSchemaID(ctx context.Context, id uuid.UUID, schemaID string, opts ...ManagerOption) (err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.UpdateSchemaID") + defer otelx.End(span, &err) + o := newManagerOptions(opts) original, err := m.r.IdentityPool().(PrivilegedPool).GetIdentityConfidential(ctx, id) if err != nil { @@ -129,14 +145,17 @@ func (m *Manager) UpdateSchemaID(ctx context.Context, id uuid.UUID, schemaID str } original.SchemaID = schemaID - if err := m.validate(ctx, original, o); err != nil { + if err := m.ValidateIdentity(ctx, original, o); err != nil { return err } return m.r.IdentityPool().(PrivilegedPool).UpdateIdentity(ctx, original) } -func (m *Manager) SetTraits(ctx context.Context, id uuid.UUID, traits Traits, opts ...ManagerOption) (*Identity, error) { +func (m *Manager) SetTraits(ctx context.Context, id uuid.UUID, traits Traits, opts ...ManagerOption) (_ *Identity, err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.SetTraits") + defer otelx.End(span, &err) + o := newManagerOptions(opts) original, err := m.r.IdentityPool().(PrivilegedPool).GetIdentityConfidential(ctx, id) if err != nil { @@ -146,7 +165,7 @@ func (m *Manager) SetTraits(ctx context.Context, id uuid.UUID, traits Traits, op // original is used to check whether protected traits were modified updated := deepcopy.Copy(original).(*Identity) updated.Traits = traits - if err := m.validate(ctx, updated, o); err != nil { + if err := m.ValidateIdentity(ctx, updated, o); err != nil { return nil, err } @@ -157,7 +176,10 @@ func (m *Manager) SetTraits(ctx context.Context, id uuid.UUID, traits Traits, op return updated, nil } -func (m *Manager) UpdateTraits(ctx context.Context, id uuid.UUID, traits Traits, opts ...ManagerOption) error { +func (m *Manager) UpdateTraits(ctx context.Context, id uuid.UUID, traits Traits, opts ...ManagerOption) (err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.UpdateTraits") + defer otelx.End(span, &err) + updated, err := m.SetTraits(ctx, id, traits, opts...) if err != nil { return err @@ -166,7 +188,10 @@ func (m *Manager) UpdateTraits(ctx context.Context, id uuid.UUID, traits Traits, return m.r.IdentityPool().(PrivilegedPool).UpdateIdentity(ctx, updated) } -func (m *Manager) validate(ctx context.Context, i *Identity, o *managerOptions) error { +func (m *Manager) ValidateIdentity(ctx context.Context, i *Identity, o *ManagerOptions) (err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.validate") + defer otelx.End(span, &err) + if err := m.r.IdentityValidator().Validate(ctx, i); err != nil { if _, ok := errorsx.Cause(err).(*jsonschema.ValidationError); ok && !o.ExposeValidationErrors { return herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err) @@ -178,6 +203,9 @@ func (m *Manager) validate(ctx context.Context, i *Identity, o *managerOptions) } func (m *Manager) CountActiveFirstFactorCredentials(ctx context.Context, i *Identity) (count int, err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CountActiveFirstFactorCredentials") + defer otelx.End(span, &err) + for _, strategy := range m.r.ActiveCredentialsCounterStrategies(ctx) { current, err := strategy.CountActiveFirstFactorCredentials(i.Credentials) if err != nil { @@ -190,6 +218,9 @@ func (m *Manager) CountActiveFirstFactorCredentials(ctx context.Context, i *Iden } func (m *Manager) CountActiveMultiFactorCredentials(ctx context.Context, i *Identity) (count int, err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CountActiveMultiFactorCredentials") + defer otelx.End(span, &err) + for _, strategy := range m.r.ActiveCredentialsCounterStrategies(ctx) { current, err := strategy.CountActiveMultiFactorCredentials(i.Credentials) if err != nil { diff --git a/identity/manager_test.go b/identity/manager_test.go index 0118acc45bb..6764f33ab95 100644 --- a/identity/manager_test.go +++ b/identity/manager_test.go @@ -233,7 +233,7 @@ func TestManager(t *testing.T) { require.NoError(t, reg.IdentityManager().UpdateTraits(context.Background(), original.ID, identity.Traits(`{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_recovery":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "bar"}`))) checkExtensionFieldsForIdentities(t, "baz@ory.sh", original) - actual, err := reg.IdentityPool().GetIdentity(context.Background(), original.ID) + actual, err := reg.IdentityPool().GetIdentity(context.Background(), original.ID, identity.ExpandNothing) require.NoError(t, err) assert.JSONEq(t, `{"email":"baz@ory.sh","email_verify":"baz@ory.sh","email_recovery":"baz@ory.sh","email_creds":"baz@ory.sh","unprotected": "bar"}`, string(actual.Traits)) }) diff --git a/identity/pool.go b/identity/pool.go index ac900388907..e4830535b7f 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -6,20 +6,22 @@ package identity import ( "context" + "github.com/ory/x/sqlxx" + "github.com/gofrs/uuid" ) type ( Pool interface { // ListIdentities lists all identities in the store given the page and itemsPerPage. - ListIdentities(ctx context.Context, page, itemsPerPage int) ([]Identity, error) + ListIdentities(ctx context.Context, expandables sqlxx.Expandables, page, itemsPerPage int) ([]Identity, error) // CountIdentities counts the number of identities in the store. CountIdentities(ctx context.Context) (int64, error) // GetIdentity returns an identity by its id. Will return an error if the identity does not exist or backend // connectivity is broken. - GetIdentity(context.Context, uuid.UUID) (*Identity, error) + GetIdentity(context.Context, uuid.UUID, sqlxx.Expandables) (*Identity, error) // FindVerifiableAddressByValue returns a matching address or sql.ErrNoRows if no address could be found. FindVerifiableAddressByValue(ctx context.Context, via VerifiableAddressType, address string) (*VerifiableAddress, error) @@ -65,5 +67,8 @@ type ( // ListRecoveryAddresses lists all tracked recovery addresses. ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) ([]RecoveryAddress, error) + + // HydrateIdentityAssociations hydrates the associations of an identity. + HydrateIdentityAssociations(ctx context.Context, i *Identity, expandables Expandables) error } ) diff --git a/identity/stub/expand.schema.json b/identity/stub/expand.schema.json new file mode 100644 index 00000000000..5cb8a8eaba0 --- /dev/null +++ b/identity/stub/expand.schema.json @@ -0,0 +1,49 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": false, + "properties": { + "traits": { + "additionalProperties": false, + "properties": { + "email": { + "format": "email", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "webauthn": { + "identifier": true + }, + "totp": { + "account_name": true + } + }, + "recovery": { + "via": "email" + }, + "verification": { + "via": "email" + } + }, + "title": "Email address", + "type": "string", + "maxLength": 320 + }, + "name": { + "minLength": 1, + "title": "Name", + "type": "string", + "maxLength": 256 + } + }, + "required": [ + "email", + "name" + ], + "type": "object" + } + }, + "title": "Person", + "type": "object" +} diff --git a/credentialmigrate/stub/webauthn/v0.json b/identity/stub/webauthn/v0.json similarity index 51% rename from credentialmigrate/stub/webauthn/v0.json rename to identity/stub/webauthn/v0.json index 5845232120b..4017734c53f 100644 --- a/credentialmigrate/stub/webauthn/v0.json +++ b/identity/stub/webauthn/v0.json @@ -11,6 +11,18 @@ }, "display_name": "asdf", "added_at": "2022-02-28T16:40:39Z" + }, + { + "id": "1Q4LaIJ9NiqS1r0CQpWY+K0gMvhOq4yk5BHuO/YlitcurSpBK7weDXOvBcuN4lvn6DAmjGfmj/J/6bpOmtdT8Q==", + "public_key": "pQECAyYgASFYILAYFLoH1T8bQMSbPrNBCMMS5U7OFWRwv2U+GkAoiBADIlggBv+8ni7XVZYBB8ufMbP/d9fDxbmOkVVHOgcJifnoOR4=", + "attestation_type": "none", + "authenticator": { + "aaguid": "AAAAAAAAAAAAAAAAAAAAAA==", + "sign_count": 4, + "clone_warning": false + }, + "display_name": "asdf", + "added_at": "2022-02-28T16:40:39Z" } ] } diff --git a/credentialmigrate/stub/webauthn/v1.json b/identity/stub/webauthn/v1.json similarity index 100% rename from credentialmigrate/stub/webauthn/v1.json rename to identity/stub/webauthn/v1.json diff --git a/identity/test/pool.go b/identity/test/pool.go index 141067cb4da..524b3451368 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -43,10 +43,157 @@ import ( func TestPool(ctx context.Context, conf *config.Config, p interface { persistence.Persister -}) func(t *testing.T) { +}, m *identity.Manager) func(t *testing.T) { return func(t *testing.T) { nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) + t.Run("case=expand", func(t *testing.T) { + expandSchema := schema.Schema{ + ID: "expandSchema", + URL: urlx.ParseOrPanic("file://./stub/expand.schema.json"), + RawURL: "file://./stub/expand.schema.json", + } + + conf.MustSet(ctx, config.ViperKeyIdentitySchemas, []config.Schema{ + { + ID: expandSchema.ID, + URL: expandSchema.RawURL, + }, + }) + + require.NoError(t, p.GetConnection(ctx).RawQuery("DELETE FROM identities WHERE true").Exec()) + t.Cleanup(func() { + require.NoError(t, p.GetConnection(ctx).RawQuery("DELETE FROM identities WHERE true").Exec()) + }) + + expected := identity.NewIdentity(expandSchema.ID) + expected.Traits = identity.Traits(`{"email":"` + uuid.Must(uuid.NewV4()).String() + "@ory.sh" + `","name":"john doe"}`) + require.NoError(t, m.ValidateIdentity(ctx, expected, new(identity.ManagerOptions))) + require.NoError(t, p.CreateIdentity(ctx, expected)) + require.NoError(t, identity.UpgradeCredentials(expected)) + + assert.NotEmpty(t, expected.RecoveryAddresses) + assert.NotEmpty(t, expected.VerifiableAddresses) + assert.NotEmpty(t, expected.Credentials) + assert.NotEqual(t, uuid.Nil, expected.RecoveryAddresses[0].ID) + assert.NotEqual(t, uuid.Nil, expected.VerifiableAddresses[0].ID) + + runner := func(t *testing.T, expand sqlxx.Expandables, cb func(*testing.T, *identity.Identity)) { + assertion := func(t *testing.T, actual *identity.Identity) { + assertx.EqualAsJSONExcept(t, expected, actual, []string{ + "verifiable_addresses", "recovery_addresses", "updated_at", "created_at", "credentials", "state_changed_at", + }) + cb(t, actual) + } + + t.Run("find", func(t *testing.T) { + actual, err := p.GetIdentity(ctx, expected.ID, expand) + require.NoError(t, err) + assertion(t, actual) + }) + + t.Run("list", func(t *testing.T) { + actual, err := p.ListIdentities(ctx, expand, 0, 10) + require.NoError(t, err) + require.Len(t, actual, 1) + assertion(t, &actual[0]) + }) + } + + t.Run("expand=nothing", func(t *testing.T) { + runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) { + assert.Empty(t, actual.RecoveryAddresses) + assert.Empty(t, actual.VerifiableAddresses) + assert.Empty(t, actual.Credentials) + assert.Empty(t, actual.InternalCredentials) + }) + }) + + t.Run("expand=credentials", func(t *testing.T) { + runner(t, identity.ExpandCredentials, func(t *testing.T, actual *identity.Identity) { + assert.Empty(t, actual.RecoveryAddresses) + assert.Empty(t, actual.VerifiableAddresses) + + require.Len(t, actual.InternalCredentials, 2) + require.Len(t, actual.Credentials, 2) + + assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"}) + assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn], []string{"updated_at", "created_at"}) + }) + }) + + t.Run("expand=recovery address", func(t *testing.T) { + runner(t, sqlxx.Expandables{identity.ExpandFieldRecoveryAddresses}, func(t *testing.T, actual *identity.Identity) { + assert.Empty(t, actual.Credentials) + assert.Empty(t, actual.InternalCredentials) + assert.Empty(t, actual.VerifiableAddresses) + + require.Len(t, actual.RecoveryAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.RecoveryAddresses, actual.RecoveryAddresses, []string{"0.updated_at", "0.created_at"}) + }) + }) + + t.Run("expand=verification address", func(t *testing.T) { + runner(t, sqlxx.Expandables{identity.ExpandFieldVerifiableAddresses}, func(t *testing.T, actual *identity.Identity) { + assert.Empty(t, actual.Credentials) + assert.Empty(t, actual.InternalCredentials) + assert.Empty(t, actual.RecoveryAddresses) + + require.Len(t, actual.VerifiableAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.VerifiableAddresses, actual.VerifiableAddresses, []string{"0.updated_at", "0.created_at"}) + }) + }) + + t.Run("expand=default", func(t *testing.T) { + runner(t, identity.ExpandDefault, func(t *testing.T, actual *identity.Identity) { + + assert.Empty(t, actual.Credentials) + assert.Empty(t, actual.InternalCredentials) + + require.Len(t, actual.RecoveryAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.RecoveryAddresses, actual.RecoveryAddresses, []string{"0.updated_at", "0.created_at"}) + + require.Len(t, actual.VerifiableAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.VerifiableAddresses, actual.VerifiableAddresses, []string{"0.updated_at", "0.created_at"}) + }) + }) + + t.Run("expand=everything", func(t *testing.T) { + runner(t, identity.ExpandEverything, func(t *testing.T, actual *identity.Identity) { + + require.Len(t, actual.InternalCredentials, 2) + require.Len(t, actual.Credentials, 2) + + assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"}) + assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn], []string{"updated_at", "created_at"}) + + require.Len(t, actual.RecoveryAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.RecoveryAddresses, actual.RecoveryAddresses, []string{"0.updated_at", "0.created_at"}) + + require.Len(t, actual.VerifiableAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.VerifiableAddresses, actual.VerifiableAddresses, []string{"0.updated_at", "0.created_at"}) + }) + }) + + t.Run("expand=load", func(t *testing.T) { + runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) { + require.NoError(t, p.HydrateIdentityAssociations(ctx, actual, identity.ExpandEverything)) + + require.Len(t, actual.InternalCredentials, 2) + require.Len(t, actual.Credentials, 2) + + assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"}) + assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn], []string{"updated_at", "created_at"}) + + require.Len(t, actual.RecoveryAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.RecoveryAddresses, actual.RecoveryAddresses, []string{"0.updated_at", "0.created_at"}) + + require.Len(t, actual.VerifiableAddresses, 1) + assertx.EqualAsJSONExcept(t, expected.VerifiableAddresses, actual.VerifiableAddresses, []string{"0.updated_at", "0.created_at"}) + }) + }) + }) + exampleServerURL := urlx.ParseOrPanic("http://example.com") conf.MustSet(ctx, config.ViperKeyPublicBaseURL, exampleServerURL.String()) defaultSchema := schema.Schema{ @@ -134,7 +281,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { require.NoError(t, p.CreateIdentity(ctx, expected)) createdIDs = append(createdIDs, expected.ID) - actual, err := p.GetIdentity(ctx, expected.ID) + actual, err := p.GetIdentity(ctx, expected.ID, identity.ExpandDefault) require.NoError(t, err) assert.Equal(t, expected.ID, actual.ID) @@ -148,7 +295,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { t.Run("different network", func(t *testing.T) { _, p := testhelpers.NewNetwork(t, ctx, p) - _, err := p.GetIdentity(ctx, expected.ID) + _, err := p.GetIdentity(ctx, expected.ID, identity.ExpandDefault) require.ErrorIs(t, err, sqlcon.ErrNoRows) count, err := p.CountIdentities(ctx) @@ -158,10 +305,10 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { }) t.Run("case=should error when the identity ID does not exist", func(t *testing.T) { - _, err := p.GetIdentity(ctx, uuid.UUID{}) + _, err := p.GetIdentity(ctx, uuid.UUID{}, identity.ExpandNothing) require.Error(t, err) - _, err = p.GetIdentity(ctx, x.NewUUID()) + _, err = p.GetIdentity(ctx, x.NewUUID(), identity.ExpandNothing) require.Error(t, err) _, err = p.GetIdentityConfidential(ctx, x.NewUUID()) @@ -185,7 +332,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { require.NoError(t, p.CreateIdentity(ctx, expected)) createdIDs = append(createdIDs, expected.ID) - actual, err := p.GetIdentity(ctx, expected.ID) + actual, err := p.GetIdentity(ctx, expected.ID, identity.ExpandDefault) require.NoError(t, err) assert.Equal(t, altSchema.ID, actual.SchemaID) assert.Equal(t, altSchema.SchemaURL(exampleServerURL).String(), actual.SchemaURL) @@ -208,7 +355,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { t.Run("different network", func(t *testing.T) { _, p := testhelpers.NewNetwork(t, ctx, p) - _, err := p.GetIdentity(ctx, expected.ID) + _, err := p.GetIdentity(ctx, expected.ID, identity.ExpandNothing) require.ErrorIs(t, err, sqlcon.ErrNoRows) _, err = p.GetIdentityConfidential(ctx, expected.ID) @@ -226,7 +373,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { err := p.CreateIdentity(ctx, expected) require.ErrorIs(t, err, sqlcon.ErrUniqueViolation, "%+v", err) - _, err = p.GetIdentity(ctx, expected.ID) + _, err = p.GetIdentity(ctx, expected.ID, identity.ExpandNothing) require.Error(t, err) t.Run("succeeds on different network/id="+ids, func(t *testing.T) { @@ -235,7 +382,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { err := p.CreateIdentity(ctx, expected) require.NoError(t, err) - _, err = p.GetIdentity(ctx, expected.ID) + _, err = p.GetIdentity(ctx, expected.ID, identity.ExpandNothing) require.NoError(t, err) }) } @@ -249,7 +396,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { expected := oidcIdentity("", "oidc-1") require.Error(t, p.CreateIdentity(ctx, expected)) - _, err := p.GetIdentity(ctx, expected.ID) + _, err := p.GetIdentity(ctx, expected.ID, identity.ExpandNothing) require.Error(t, err) second := oidcIdentity("", "OIDC-1") @@ -261,7 +408,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { expected := oidcIdentity("", "oidc-1") require.NoError(t, p.CreateIdentity(ctx, expected)) - _, err = p.GetIdentity(ctx, expected.ID) + _, err = p.GetIdentity(ctx, expected.ID, identity.ExpandNothing) require.NoError(t, err) }) }) @@ -402,13 +549,13 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { require.ErrorIs(t, p.DeleteIdentity(ctx, expected.ID), sqlcon.ErrNoRows) p = testhelpers.ExistingNetwork(t, p, nid) - _, err := p.GetIdentity(ctx, expected.ID) + _, err := p.GetIdentity(ctx, expected.ID, identity.ExpandNothing) require.NoError(t, err) }) require.NoError(t, p.DeleteIdentity(ctx, expected.ID)) - _, err := p.GetIdentity(ctx, expected.ID) + _, err := p.GetIdentity(ctx, expected.ID, identity.ExpandNothing) require.Error(t, err) }) @@ -426,14 +573,14 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { }) t.Run("case=list", func(t *testing.T) { - is, err := p.ListIdentities(ctx, 0, 25) + is, err := p.ListIdentities(ctx, identity.ExpandDefault, 0, 25) require.NoError(t, err) assert.Len(t, is, len(createdIDs)) for _, id := range createdIDs { var found bool for _, i := range is { if i.ID == id { - expected, err := p.GetIdentity(ctx, id) + expected, err := p.GetIdentity(ctx, id, identity.ExpandDefault) require.NoError(t, err) assertx.EqualAsJSON(t, expected, i) found = true @@ -444,7 +591,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { t.Run("no results on other network", func(t *testing.T) { _, p := testhelpers.NewNetwork(t, ctx, p) - is, err := p.ListIdentities(ctx, 0, 25) + is, err := p.ListIdentities(ctx, identity.ExpandDefault, 0, 25) require.NoError(t, err) assert.Len(t, is, 0) }) @@ -860,14 +1007,14 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { require.NoError(t, p.GetConnection(ctx).RawQuery("INSERT INTO identity_credential_identifiers (id, identity_credential_id, nid, identifier, created_at, updated_at, identity_credential_type_id) VALUES (?, ?, ?, ?, ?, ?, ?)", ici1, cid1, nid1, "nid1", time.Now(), time.Now(), m[0].ID).Exec()) require.NoError(t, p.GetConnection(ctx).RawQuery("INSERT INTO identity_credential_identifiers (id, identity_credential_id, nid, identifier, created_at, updated_at, identity_credential_type_id) VALUES (?, ?, ?, ?, ?, ?, ?)", ici2, cid2, nid2, "nid2", time.Now(), time.Now(), m[0].ID).Exec()) - _, err := p.GetIdentity(ctx, nid1) + _, err := p.GetIdentity(ctx, nid1, identity.ExpandNothing) require.ErrorIs(t, err, sqlcon.ErrNoRows) _, err = p.GetIdentityConfidential(ctx, nid1) require.ErrorIs(t, err, sqlcon.ErrNoRows) i, c, err := p.FindByCredentialsIdentifier(ctx, m[0].Name, "nid1") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "nid1", c.Identifiers[0]) require.Len(t, i.Credentials, 0) diff --git a/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json b/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json index f80eed09a37..dc03bec2801 100644 --- a/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json +++ b/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json @@ -4,6 +4,7 @@ "password": { "type": "password", "identifiers": [ + "foo-dupe@ory.sh", "foo@ory.sh" ], "config": { @@ -20,6 +21,17 @@ "traits": { "email": "bazbar@ory.sh" }, + "verifiable_addresses": [ + { + "id": "45e867e9-2745-4f16-8dd4-84334a252b61", + "value": "foo@ory.sh", + "verified": false, + "via": "email", + "status": "pending", + "created_at": "2013-10-07T08:23:19Z", + "updated_at": "2013-10-07T08:23:19Z" + } + ], "metadata_public": null, "metadata_admin": null, "created_at": "2013-10-07T08:23:19Z", diff --git a/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json b/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json index 93d23658cb0..9c63190a6fa 100644 --- a/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json +++ b/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json @@ -4,6 +4,7 @@ "password": { "type": "password", "identifiers": [ + "foo-dupe@ory.sh", "foobar@ory.sh" ], "config": { @@ -20,6 +21,44 @@ "traits": { "email": "foobar@ory.sh" }, + "verifiable_addresses": [ + { + "id": "b2d59320-8564-4400-a39f-a22a497a23f1", + "value": "foobar+without-code@ory.sh", + "verified": false, + "via": "email", + "status": "pending", + "created_at": "2013-10-07T08:23:19Z", + "updated_at": "2013-10-07T08:23:19Z" + }, + { + "id": "c2427b6d-312b-46d9-9285-536db7ae11fd", + "value": "foobar@ory.sh", + "verified": false, + "via": "email", + "status": "pending", + "created_at": "2013-10-07T08:23:19Z", + "updated_at": "2013-10-07T08:23:19Z" + }, + { + "id": "d4718a67-aec2-418d-8173-6ebc7bde3b86", + "value": "foobar+11345642c6c0@ory.sh", + "verified": false, + "via": "email", + "status": "pending", + "created_at": "2013-10-07T08:23:19Z", + "updated_at": "2013-10-07T08:23:19Z" + } + ], + "recovery_addresses": [ + { + "id": "b8293f1c-010f-45d9-b809-f3fc5365ba80", + "value": "foobar@ory.sh", + "via": "email", + "created_at": "2013-10-07T08:23:19Z", + "updated_at": "2013-10-07T08:23:19Z" + } + ], "metadata_public": null, "metadata_admin": null, "created_at": "2013-10-07T08:23:19Z", diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index f7caf0aa27b..f7afa928816 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -157,15 +157,14 @@ func TestMigrations(t *testing.T) { defer wg.Done() t.Parallel() - ids, err := d.PrivilegedIdentityPool().ListIdentities(context.Background(), 0, 1000) + ids, err := d.PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) require.NoError(t, err) require.NotEmpty(t, ids) var found []string - for _, id := range ids { + for y, id := range ids { found = append(found, id.ID.String()) - actual, err := d.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), id.ID) - require.NoError(t, err, "ID: %s", id.ID) + actual := &ids[y] for _, a := range actual.VerifiableAddresses { CompareWithFixture(t, a, "identity_verification_address", a.ID.String()) @@ -175,9 +174,27 @@ func TestMigrations(t *testing.T) { CompareWithFixture(t, a, "identity_recovery_address", a.ID.String()) } - // Prevents ordering to get in the way. - actual.VerifiableAddresses = nil - actual.RecoveryAddresses = nil + CompareWithFixture(t, identity.WithCredentialsAndAdminMetadataInJSON(*actual), "identity", id.ID.String()) + } + + migratest.ContainsExpectedIds(t, filepath.Join("fixtures", "identity"), found) + }) + + t.Run("case=identity", func(t *testing.T) { + wg.Add(1) + defer wg.Done() + t.Parallel() + + ids, err := d.PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + require.NoError(t, err) + require.NotEmpty(t, ids) + + var found []string + for _, id := range ids { + actual, err := d.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), id.ID) + require.NoError(t, err) + found = append(found, actual.ID.String()) + CompareWithFixture(t, identity.WithCredentialsAndAdminMetadataInJSON(*actual), "identity", id.ID.String()) } diff --git a/persistence/sql/migratest/testdata/20210410175418_testdata.sql b/persistence/sql/migratest/testdata/20210410175418_testdata.sql index 6fff54f3414..491ddbbc42e 100644 --- a/persistence/sql/migratest/testdata/20210410175418_testdata.sql +++ b/persistence/sql/migratest/testdata/20210410175418_testdata.sql @@ -21,3 +21,4 @@ INSERT INTO identity_verifiable_addresses (id, nid, status, via, verified, value INSERT INTO identities (id, nid, schema_id, traits, created_at, updated_at) VALUES ('196d8c1e-4f04-40f0-94b3-5ec43996b28a', '884f556e-eb3a-4b9f-bee3-11345642c6c0', 'default', '{"email":"foobar@ory.sh"}', '2013-10-07 08:23:19', '2013-10-07 08:23:19'); INSERT INTO identities (id, nid, schema_id, traits, created_at, updated_at) VALUES ('ed253b2c-48ed-4c58-9b6f-1dc963c30a66', '884f556e-eb3a-4b9f-bee3-11345642c6c0', 'default', '{"email":"bazbar@ory.sh"}', '2013-10-07 08:23:19', '2013-10-07 08:23:19'); + diff --git a/persistence/sql/migratest/testdata/20210817181232_testdata.sql b/persistence/sql/migratest/testdata/20210817181232_testdata.sql index 3c994c384e3..b5e2a64c5a3 100644 --- a/persistence/sql/migratest/testdata/20210817181232_testdata.sql +++ b/persistence/sql/migratest/testdata/20210817181232_testdata.sql @@ -1,7 +1,7 @@ -INSERT INTO identity_credential_identifiers (id, identifier, identity_credential_id, created_at, updated_at, identity_credential_type_id) -VALUES ('2672f198-4795-437c-8e14-56459b1d941a', 'foo-dupe@ory.sh', '35b60ecf-30f9-42d6-bf5d-47ad41148691', +INSERT INTO identity_credential_identifiers (id, nid, identifier, identity_credential_id, created_at, updated_at, identity_credential_type_id) +VALUES ('2672f198-4795-437c-8e14-56459b1d941a', '884f556e-eb3a-4b9f-bee3-11345642c6c0','foo-dupe@ory.sh', '35b60ecf-30f9-42d6-bf5d-47ad41148691', '2013-10-07 08:23:19', '2013-10-07 08:23:19', '22bff9ae-f5aa-45d7-803b-97ec0b4e7b32'); -INSERT INTO identity_credential_identifiers (id, identifier, identity_credential_id, created_at, updated_at, identity_credential_type_id) -VALUES ('10985ed1-5b6e-4012-ac10-03d87df65618', 'foo-dupe@ory.sh', '74ac2d31-bccb-442f-a792-7b8bb14817f8', +INSERT INTO identity_credential_identifiers (id, nid, identifier, identity_credential_id, created_at, updated_at, identity_credential_type_id) +VALUES ('10985ed1-5b6e-4012-ac10-03d87df65618', '884f556e-eb3a-4b9f-bee3-11345642c6c0','foo-dupe@ory.sh', '74ac2d31-bccb-442f-a792-7b8bb14817f8', '2013-10-07 08:23:19', '2013-10-07 08:23:19', '6b213fa0-e6ad-46cb-8878-b088d2ce2e3c'); diff --git a/persistence/sql/migratest/testdata/20220301102701_testdata.sql b/persistence/sql/migratest/testdata/20220301102701_testdata.sql index bafd65fde6d..f672b9ddfec 100644 --- a/persistence/sql/migratest/testdata/20220301102701_testdata.sql +++ b/persistence/sql/migratest/testdata/20220301102701_testdata.sql @@ -1 +1 @@ -INSERT INTO identity_credentials (id, config, identity_credential_type_id, identity_id, created_at, updated_at, version) VALUES ('4cefc264-4291-4abc-8f26-cc0217874f14', '{"hashed_password":"$argon2id$v=19$m=131072,t=2,p=1$lQFPaKxXqPL56/mU7vRi4w$6aldHyBnURt8sP8+xu41Ng"}', '22bff9ae-f5aa-45d7-803b-97ec0b4e7b32', '5ff66179-c240-4703-b0d8-494592cefff5', '2013-10-07 08:23:19', '2013-10-07 08:23:19', 0); +INSERT INTO identity_credentials (id, nid, config, identity_credential_type_id, identity_id, created_at, updated_at, version) VALUES ('4cefc264-4291-4abc-8f26-cc0217874f14', '884f556e-eb3a-4b9f-bee3-11345642c6c0', '{"hashed_password":"$argon2id$v=19$m=131072,t=2,p=1$lQFPaKxXqPL56/mU7vRi4w$6aldHyBnURt8sP8+xu41Ng"}', '22bff9ae-f5aa-45d7-803b-97ec0b4e7b32', '5ff66179-c240-4703-b0d8-494592cefff5', '2013-10-07 08:23:19', '2013-10-07 08:23:19', 0); diff --git a/persistence/sql/persister_identity.go b/persistence/sql/persister_identity.go index e8cee3d7174..e733904357e 100644 --- a/persistence/sql/persister_identity.go +++ b/persistence/sql/persister_identity.go @@ -10,7 +10,9 @@ import ( "strings" "time" - "github.com/ory/kratos/credentialmigrate" + "go.opentelemetry.io/otel/attribute" + + "github.com/ory/x/otelx" "github.com/ory/jsonschema/v3" "github.com/ory/x/sqlxx" @@ -109,7 +111,7 @@ WHERE ici.identifier = ? nid, ct, ).First(&find); err != nil { - if errors.Cause(err) == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil, sqlcon.HandleError(err) // herodot.ErrNotFound.WithTrace(err).WithReasonf(`No identity matching credentials identifier "%s" could be found.`, match) } @@ -161,7 +163,7 @@ func (p *Persister) createIdentityCredentials(ctx context.Context, i *identity.I cred.IdentityID = i.ID cred.NID = nid - cred.CredentialTypeID = ct.ID + cred.IdentityCredentialTypeID = ct.ID if err := c.Create(&cred); err != nil { return sqlcon.HandleError(err) } @@ -279,30 +281,6 @@ func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Ide return nil } -func (p *Persister) findVerifiableAddresses(ctx context.Context, i *identity.Identity) error { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findVerifiableAddresses") - defer span.End() - - var addresses []identity.VerifiableAddress - if err := p.GetConnection(ctx).Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Order("id ASC").All(&addresses); err != nil { - return err - } - i.VerifiableAddresses = addresses - return nil -} - -func (p *Persister) findRecoveryAddresses(ctx context.Context, i *identity.Identity) error { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findRecoveryAddresses") - defer span.End() - - var addresses []identity.RecoveryAddress - if err := p.GetConnection(ctx).Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Order("id ASC").All(&addresses); err != nil { - return err - } - i.RecoveryAddresses = addresses - return nil -} - func (p *Persister) CountIdentities(ctx context.Context) (int64, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountIdentities") defer span.End() @@ -361,27 +339,52 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er }) } -func (p *Persister) ListIdentities(ctx context.Context, page, perPage int) ([]identity.Identity, error) { +func (p *Persister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HydrateIdentityAssociations") + defer otelx.End(span, &err) + + con := p.GetConnection(ctx) + if err := con.Load(i, expand.ToEager()...); err != nil { + return err + } + + if err := i.AfterEagerFind(con); err != nil { + return err + } + + return p.injectTraitsSchemaURL(ctx, i) +} + +func (p *Persister) ListIdentities(ctx context.Context, expand identity.Expandables, page, perPage int) (res []identity.Identity, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities") - defer span.End() + defer otelx.End(span, &err) + span.SetAttributes( + attribute.Int("page", page), + attribute.Int("per_page", perPage), + attribute.StringSlice("expand", expand.ToEager()), + attribute.String("network.id", p.NetworkID(ctx).String()), + ) is := make([]identity.Identity, 0) + con := p.GetConnection(ctx) + query := con. + Where("nid = ?", p.NetworkID(ctx)). + Paginate(page, perPage). + Order("id DESC") + + if len(expand) > 0 { + query = query.EagerPreload(expand.ToEager()...) + } + /* #nosec G201 TableName is static */ - if err := sqlcon.HandleError(p.GetConnection(ctx).Where("nid = ?", p.NetworkID(ctx)). - EagerPreload("VerifiableAddresses", "RecoveryAddresses"). - Paginate(page, perPage).Order("id DESC"). - All(&is)); err != nil { + if err := sqlcon.HandleError(query.All(&is)); err != nil { return nil, err } schemaCache := map[string]string{} - for k := range is { i := &is[k] - if err := i.ValidateNID(); err != nil { - return nil, sqlcon.HandleError(err) - } if u, ok := schemaCache[i.SchemaID]; ok { i.SchemaURL = u @@ -447,17 +450,27 @@ func (p *Persister) DeleteIdentity(ctx context.Context, id uuid.UUID) error { return p.delete(ctx, new(identity.Identity), id) } -func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID) (*identity.Identity, error) { +func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identity.Expandables) (res *identity.Identity, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentity") - defer span.End() + defer otelx.End(span, &err) + + span.SetAttributes( + attribute.String("identity.id", id.String()), + attribute.StringSlice("expand", expand.ToEager()), + attribute.String("network.id", p.NetworkID(ctx).String()), + ) + + con := p.GetConnection(ctx) + query := con.Where("id = ? AND nid = ?", id, p.NetworkID(ctx)) + if len(expand) > 0 { + query = query.EagerPreload(expand.ToEager()...) + } var i identity.Identity - if err := p.GetConnection(ctx).EagerPreload("VerifiableAddresses", "RecoveryAddresses").Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).First(&i); err != nil { + if err := query.First(&i); err != nil { return nil, sqlcon.HandleError(err) } - i.Credentials = nil - if err := p.injectTraitsSchemaURL(ctx, &i); err != nil { return nil, err } @@ -465,61 +478,11 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID) (*identity.Id return &i, nil } -func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (*identity.Identity, error) { +func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (res *identity.Identity, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentityConfidential") - defer span.End() + defer otelx.End(span, &err) - var i identity.Identity - - nid := p.NetworkID(ctx) - if err := p.GetConnection(ctx).Where("id = ? AND nid = ?", id, nid).First(&i); err != nil { - return nil, sqlcon.HandleError(err) - } - - var creds identity.CredentialsCollection - if err := p.GetConnection(ctx).Where("identity_id = ? AND nid = ?", id, nid).All(&creds); err != nil { - return nil, sqlcon.HandleError(err) - } - - i.Credentials = make(map[identity.CredentialsType]identity.Credentials) - for k := range creds { - cred := &creds[k] - - var ct identity.CredentialsTypeTable - if err := p.GetConnection(ctx).Find(&ct, cred.CredentialTypeID); err != nil { - return nil, sqlcon.HandleError(err) - } - cred.Type = ct.Name - - var cids identity.CredentialIdentifierCollection - if err := p.GetConnection(ctx).Where("identity_credential_id = ? AND nid = ?", cred.ID, nid).All(&cids); err != nil { - return nil, sqlcon.HandleError(err) - } - - cred.Identifiers = make([]string, len(cids)) - for kk, cid := range cids { - cred.Identifiers[kk] = cid.Identifier - } - - i.Credentials[cred.Type] = *cred - } - - if err := credentialmigrate.UpgradeCredentials(&i); err != nil { - return nil, err - } - - if err := p.findRecoveryAddresses(ctx, &i); err != nil { - return nil, err - } - if err := p.findVerifiableAddresses(ctx, &i); err != nil { - return nil, err - } - - if err := p.injectTraitsSchemaURL(ctx, &i); err != nil { - return nil, err - } - - return &i, nil + return p.GetIdentity(ctx, id, identity.ExpandEverything) } func (p *Persister) FindVerifiableAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (*identity.VerifiableAddress, error) { diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index a4fc1618d91..50b6e7110ae 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -8,6 +8,10 @@ import ( "fmt" "time" + "github.com/ory/x/otelx" + + "github.com/ory/kratos/identity" + "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/stringsx" @@ -30,9 +34,9 @@ const SessionDeviceLocationMaxLength = 512 const paginationMaxItemsSize = 1000 const paginationDefaultItemsSize = 250 -func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables session.Expandables) (*session.Session, error) { +func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables session.Expandables) (_ *session.Session, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetSession") - defer span.End() + defer otelx.End(span, &err) var s session.Session s.Devices = make([]session.Device, 0) @@ -51,7 +55,7 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s if expandables.Has(session.ExpandSessionIdentity) { // This is needed because of how identities are fetched from the store (if we use eager not all fields are // available!). - i, err := p.GetIdentity(ctx, s.IdentityID) + i, err := p.GetIdentity(ctx, s.IdentityID, identity.ExpandDefault) if err != nil { return nil, err } @@ -62,9 +66,9 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s return &s, nil } -func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpts []keysetpagination.Option, expandables session.Expandables) ([]session.Session, int64, *keysetpagination.Paginator, error) { +func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpts []keysetpagination.Option, expandables session.Expandables) (_ []session.Session, _ int64, _ *keysetpagination.Paginator, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListSessions") - defer span.End() + defer otelx.End(span, &err) s := make([]session.Session, 0) t := int64(0) @@ -106,7 +110,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt for index := range s { sess := &(s[index]) - i, err := p.GetIdentity(ctx, sess.IdentityID) + i, err := p.GetIdentity(ctx, sess.IdentityID, identity.ExpandDefault) if err != nil { return err } @@ -125,9 +129,9 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt } // ListSessionsByIdentity retrieves sessions for an identity from the store. -func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables session.Expandables) ([]*session.Session, int64, error) { +func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables session.Expandables) (_ []*session.Session, _ int64, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListSessionsByIdentity") - defer span.End() + defer otelx.End(span, &err) s := make([]*session.Session, 0) t := int64(0) @@ -162,7 +166,7 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a } if expandables.Has(session.ExpandSessionIdentity) { - i, err := p.GetIdentity(ctx, iID) + i, err := p.GetIdentity(ctx, iID, identity.ExpandDefault) if err != nil { return sqlcon.HandleError(err) } @@ -183,9 +187,9 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a // UpsertSession creates a session if not found else updates. // This operation also inserts Session device records when a session is being created. // The update operation skips updating Session device records since only one record would need to be updated in this case. -func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) error { +func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpsertSession") - defer span.End() + defer otelx.End(span, &err) s.NID = p.NetworkID(ctx) @@ -227,16 +231,16 @@ func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) error })) } -func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) error { +func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSession") - defer span.End() + defer otelx.End(span, &err) return p.delete(ctx, new(session.Session), sid) } -func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uuid.UUID) error { +func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uuid.UUID) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSessionsByIdentity") - defer span.End() + defer otelx.End(span, &err) // #nosec G201 count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( @@ -255,17 +259,17 @@ func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uui return nil } -func (p *Persister) GetSessionByToken(ctx context.Context, token string, expandables session.Expandables) (*session.Session, error) { +func (p *Persister) GetSessionByToken(ctx context.Context, token string, expand session.Expandables, identityExpand identity.Expandables) (res *session.Session, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetSessionByToken") - defer span.End() + defer otelx.End(span, &err) var s session.Session s.Devices = make([]session.Device, 0) nid := p.NetworkID(ctx) q := p.GetConnection(ctx).Q() - if len(expandables) > 0 { - q = q.Eager(expandables.ToEager()...) + if len(expand) > 0 { + q = q.Eager(expand.ToEager()...) } if err := q.Where("token = ? AND nid = ?", token, nid).First(&s); err != nil { @@ -274,8 +278,8 @@ func (p *Persister) GetSessionByToken(ctx context.Context, token string, expanda // This is needed because of how identities are fetched from the store (if we use eager not all fields are // available!). - if expandables.Has(session.ExpandSessionIdentity) { - i, err := p.GetIdentity(ctx, s.IdentityID) + if expand.Has(session.ExpandSessionIdentity) { + i, err := p.GetIdentity(ctx, s.IdentityID, identityExpand) if err != nil { return nil, err } @@ -284,9 +288,9 @@ func (p *Persister) GetSessionByToken(ctx context.Context, token string, expanda return &s, nil } -func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) error { +func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSessionByToken") - defer span.End() + defer otelx.End(span, &err) // #nosec G201 count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( @@ -305,9 +309,9 @@ func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) erro return nil } -func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) error { +func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionByToken") - defer span.End() + defer otelx.End(span, &err) // #nosec G201 count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( @@ -327,9 +331,9 @@ func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) erro } // RevokeSessionById revokes a given session -func (p *Persister) RevokeSessionById(ctx context.Context, sID uuid.UUID) error { +func (p *Persister) RevokeSessionById(ctx context.Context, sID uuid.UUID) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionById") - defer span.End() + defer otelx.End(span, &err) // #nosec G201 count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( @@ -350,12 +354,12 @@ func (p *Persister) RevokeSessionById(ctx context.Context, sID uuid.UUID) error // RevokeSession revokes a given session. If the session does not exist or was not modified, // it effectively has been revoked already, and therefore that case does not return an error. -func (p *Persister) RevokeSession(ctx context.Context, iID, sID uuid.UUID) error { +func (p *Persister) RevokeSession(ctx context.Context, iID, sID uuid.UUID) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSession") - defer span.End() + defer otelx.End(span, &err) // #nosec G201 - err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + err = p.GetConnection(ctx).RawQuery(fmt.Sprintf( "UPDATE %s SET active = false WHERE id = ? AND identity_id = ? AND nid = ?", "sessions", ), @@ -370,9 +374,9 @@ func (p *Persister) RevokeSession(ctx context.Context, iID, sID uuid.UUID) error } // RevokeSessionsIdentityExcept marks all except the given session of an identity inactive. -func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID uuid.UUID) (int, error) { +func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID uuid.UUID) (res int, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionsIdentityExcept") - defer span.End() + defer otelx.End(span, &err) // #nosec G201 count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( @@ -389,8 +393,10 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u return count, nil } -func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Time, limit int) error { - err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( +func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Time, limit int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteExpiredSessions") + defer otelx.End(span, &err) + err = p.GetConnection(ctx).RawQuery(fmt.Sprintf( "DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )", "sessions", "sessions", diff --git a/persistence/sql/persister_settings.go b/persistence/sql/persister_settings.go index 55635ef9931..a511831334d 100644 --- a/persistence/sql/persister_settings.go +++ b/persistence/sql/persister_settings.go @@ -8,6 +8,8 @@ import ( "fmt" "time" + "github.com/ory/kratos/identity" + "github.com/gofrs/uuid" "github.com/ory/x/sqlcon" @@ -37,7 +39,7 @@ func (p *Persister) GetSettingsFlow(ctx context.Context, id uuid.UUID) (*setting return nil, sqlcon.HandleError(err) } - r.Identity, err = p.GetIdentity(ctx, r.IdentityID) + r.Identity, err = p.GetIdentity(ctx, r.IdentityID, identity.ExpandDefault) if err != nil { return nil, err } diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 65e2b904bea..258bb4e1702 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -7,11 +7,9 @@ import ( "context" "fmt" "os" - "path/filepath" "sync" "testing" - - "github.com/ory/x/dbal" + "time" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/schema" @@ -23,7 +21,6 @@ import ( "github.com/go-errors/errors" "github.com/gobuffalo/pop/v6" "github.com/gobuffalo/pop/v6/logging" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -52,11 +49,12 @@ import ( "github.com/ory/x/sqlcon/dockertest" ) -var sqlite = fmt.Sprintf("sqlite3://%s.sqlite?_fk=true&mode=rwc", filepath.Join(os.TempDir(), uuid.New().String())) - func init() { corpx.RegisterFakes() - // op.Debug = true + pop.SetNowFunc(func() time.Time { + return time.Now().UTC().Round(time.Second) + }) + //pop.Debug = true } // nolint:staticcheck @@ -99,7 +97,7 @@ func pl(t *testing.T) func(lvl logging.Level, s string, args ...interface{}) { } func createCleanDatabases(t *testing.T) map[string]*driver.RegistryDefault { - conns := map[string]string{"sqlite": dbal.NewSQLiteTestDatabase(t)} + conns := map[string]string{"sqlite": "sqlite://file:" + t.TempDir() + "/db.sqlite?_fk=true"} var l sync.Mutex if !testing.Short() { @@ -131,7 +129,7 @@ func createCleanDatabases(t *testing.T) map[string]*driver.RegistryDefault { for name, dsn := range conns { go func(name, dsn string) { defer wg.Done() - t.Logf("Connecting to %s", name) + t.Logf("Connecting to %s: %s", name, dsn) _, reg := internal.NewRegistryDefaultWithDSN(t, dsn) p := reg.Persister().(*sql.Persister) @@ -220,7 +218,7 @@ func TestPersister(t *testing.T) { t.Run("contract=identity.TestPool", func(t *testing.T) { pop.SetLogger(pl(t)) - identity.TestPool(ctx, conf, p)(t) + identity.TestPool(ctx, conf, p, reg.IdentityManager())(t) }) t.Run("contract=registration.TestFlowPersister", func(t *testing.T) { pop.SetLogger(pl(t)) @@ -299,7 +297,7 @@ func TestPersister_Transaction(t *testing.T) { }) require.Error(t, err) assert.Contains(t, err.Error(), errMessage) - _, err = p.GetIdentity(context.Background(), i.ID) + _, err = p.GetIdentity(context.Background(), i.ID, ri.ExpandNothing) require.Error(t, err) assert.Equal(t, sqlcon.ErrNoRows.Error(), err.Error()) }) diff --git a/persistence/sql/stub/expand.schema.json b/persistence/sql/stub/expand.schema.json new file mode 100644 index 00000000000..5cb8a8eaba0 --- /dev/null +++ b/persistence/sql/stub/expand.schema.json @@ -0,0 +1,49 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": false, + "properties": { + "traits": { + "additionalProperties": false, + "properties": { + "email": { + "format": "email", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "webauthn": { + "identifier": true + }, + "totp": { + "account_name": true + } + }, + "recovery": { + "via": "email" + }, + "verification": { + "via": "email" + } + }, + "title": "Email address", + "type": "string", + "maxLength": 320 + }, + "name": { + "minLength": 1, + "title": "Name", + "type": "string", + "maxLength": 256 + } + }, + "required": [ + "email", + "name" + ], + "type": "object" + } + }, + "title": "Person", + "type": "object" +} diff --git a/selfservice/flow/registration/hook_test.go b/selfservice/flow/registration/hook_test.go index 513c714ca95..5a95188a5d3 100644 --- a/selfservice/flow/registration/hook_test.go +++ b/selfservice/flow/registration/hook_test.go @@ -77,7 +77,7 @@ func TestRegistrationExecutor(t *testing.T) { assert.EqualValues(t, http.StatusOK, res.StatusCode) assert.EqualValues(t, "https://www.ory.sh/", res.Request.URL.String()) - actual, err := reg.IdentityPool().GetIdentity(context.Background(), i.ID) + actual, err := reg.IdentityPool().GetIdentity(context.Background(), i.ID, identity.ExpandNothing) require.NoError(t, err) assert.Equal(t, actual.Traits, i.Traits) }) @@ -100,7 +100,7 @@ func TestRegistrationExecutor(t *testing.T) { assert.EqualValues(t, http.StatusOK, res.StatusCode) assert.Equal(t, "", body) - _, err := reg.IdentityPool().GetIdentity(context.Background(), i.ID) + _, err := reg.IdentityPool().GetIdentity(context.Background(), i.ID, identity.ExpandNothing) require.Error(t, err) }) diff --git a/selfservice/hook/verification_test.go b/selfservice/hook/verification_test.go index 23228a9bd7c..67c444620d3 100644 --- a/selfservice/hook/verification_test.go +++ b/selfservice/hook/verification_test.go @@ -72,7 +72,7 @@ func TestVerifier(t *testing.T) { actual.VerifiedAt = &verifiedAt require.NoError(t, reg.PrivilegedIdentityPool().UpdateVerifiableAddress(context.Background(), actual)) - i, err = reg.IdentityPool().GetIdentity(context.Background(), i.ID) + i, err = reg.IdentityPool().GetIdentity(context.Background(), i.ID, identity.ExpandDefault) require.NoError(t, err) var originalFlow flow.Flow diff --git a/selfservice/strategy/code/code_sender.go b/selfservice/strategy/code/code_sender.go index d5ef8529267..ea265b4c2b2 100644 --- a/selfservice/strategy/code/code_sender.go +++ b/selfservice/strategy/code/code_sender.go @@ -77,7 +77,7 @@ func (s *Sender) SendRecoveryCode(ctx context.Context, r *http.Request, f *recov } // Get the identity associated with the recovery address - i, err := s.deps.IdentityPool().GetIdentity(ctx, address.IdentityID) + i, err := s.deps.IdentityPool().GetIdentity(ctx, address.IdentityID, identity.ExpandDefault) if err != nil { return err } @@ -160,7 +160,7 @@ func (s *Sender) SendVerificationCode(ctx context.Context, f *verification.Flow, } // Get the identity associated with the recovery address - i, err := s.deps.IdentityPool().GetIdentity(ctx, address.IdentityID) + i, err := s.deps.IdentityPool().GetIdentity(ctx, address.IdentityID, identity.ExpandDefault) if err != nil { return err } diff --git a/selfservice/strategy/code/strategy_recovery.go b/selfservice/strategy/code/strategy_recovery.go index 8c226759fee..b7ffc85320b 100644 --- a/selfservice/strategy/code/strategy_recovery.go +++ b/selfservice/strategy/code/strategy_recovery.go @@ -196,7 +196,7 @@ func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http. return } - id, err := s.deps.IdentityPool().GetIdentity(ctx, p.IdentityID) + id, err := s.deps.IdentityPool().GetIdentity(ctx, p.IdentityID, identity.ExpandDefault) if notFoundErr := sqlcon.ErrNoRows; errors.As(err, ¬FoundErr) { s.deps.Writer().WriteError(w, r, notFoundErr.WithReasonf("could not find identity")) return @@ -427,7 +427,7 @@ func (s *Strategy) recoveryUseCode(w http.ResponseWriter, r *http.Request, body return s.retryRecoveryFlowWithError(w, r, f.Type, err) } - recovered, err := s.deps.IdentityPool().GetIdentity(ctx, code.IdentityID) + recovered, err := s.deps.IdentityPool().GetIdentity(ctx, code.IdentityID, identity.ExpandDefault) if err != nil { return s.HandleRecoveryError(w, r, f, nil, err) } diff --git a/selfservice/strategy/code/strategy_verification.go b/selfservice/strategy/code/strategy_verification.go index 22b2cfd25e0..0868bfca7b5 100644 --- a/selfservice/strategy/code/strategy_verification.go +++ b/selfservice/strategy/code/strategy_verification.go @@ -259,7 +259,7 @@ func (s *Strategy) verificationUseCode(w http.ResponseWriter, r *http.Request, c return s.retryVerificationFlowWithError(w, r, f.Type, err) } - i, err := s.deps.IdentityPool().GetIdentity(r.Context(), code.VerifiableAddress.IdentityID) + i, err := s.deps.IdentityPool().GetIdentity(r.Context(), code.VerifiableAddress.IdentityID, identity.ExpandDefault) if err != nil { return s.retryVerificationFlowWithError(w, r, f.Type, err) } diff --git a/selfservice/strategy/link/sender.go b/selfservice/strategy/link/sender.go index 13b603a7d16..c50a5ff937e 100644 --- a/selfservice/strategy/link/sender.go +++ b/selfservice/strategy/link/sender.go @@ -77,7 +77,7 @@ func (s *Sender) SendRecoveryLink(ctx context.Context, r *http.Request, f *recov } // Get the identity associated with the recovery address - i, err := s.r.IdentityPool().GetIdentity(ctx, address.IdentityID) + i, err := s.r.IdentityPool().GetIdentity(ctx, address.IdentityID, identity.ExpandDefault) if err != nil { return err } @@ -119,7 +119,7 @@ func (s *Sender) SendVerificationLink(ctx context.Context, f *verification.Flow, } // Get the identity associated with the recovery address - i, err := s.r.IdentityPool().GetIdentity(ctx, address.IdentityID) + i, err := s.r.IdentityPool().GetIdentity(ctx, address.IdentityID, identity.ExpandDefault) if err != nil { return err } diff --git a/selfservice/strategy/link/strategy_recovery.go b/selfservice/strategy/link/strategy_recovery.go index 967aaba8949..62267147ec4 100644 --- a/selfservice/strategy/link/strategy_recovery.go +++ b/selfservice/strategy/link/strategy_recovery.go @@ -171,7 +171,7 @@ func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http. return } - id, err := s.d.IdentityPool().GetIdentity(r.Context(), p.IdentityID) + id, err := s.d.IdentityPool().GetIdentity(r.Context(), p.IdentityID, identity.ExpandDefault) if errors.Is(err, sqlcon.ErrNoRows) { s.d.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("The requested identity id does not exist.").WithWrap(err))) return @@ -350,7 +350,7 @@ func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, fID return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err) } - recovered, err := s.d.IdentityPool().GetIdentity(r.Context(), token.IdentityID) + recovered, err := s.d.IdentityPool().GetIdentity(r.Context(), token.IdentityID, identity.ExpandDefault) if err != nil { return s.HandleRecoveryError(w, r, f, nil, err) } diff --git a/selfservice/strategy/link/strategy_verification.go b/selfservice/strategy/link/strategy_verification.go index 12f296233af..8b11c4213cb 100644 --- a/selfservice/strategy/link/strategy_verification.go +++ b/selfservice/strategy/link/strategy_verification.go @@ -208,7 +208,7 @@ func (s *Strategy) verificationUseToken(w http.ResponseWriter, r *http.Request, return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) } - i, err := s.d.IdentityPool().GetIdentity(r.Context(), token.VerifiableAddress.IdentityID) + i, err := s.d.IdentityPool().GetIdentity(r.Context(), token.VerifiableAddress.IdentityID, identity.ExpandDefault) if err != nil { return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) } diff --git a/selfservice/strategy/link/test/persistence.go b/selfservice/strategy/link/test/persistence.go index 1a6b63b34a4..ae669a88845 100644 --- a/selfservice/strategy/link/test/persistence.go +++ b/selfservice/strategy/link/test/persistence.go @@ -106,7 +106,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { expected, f := newRecoveryToken(t, "some-user@ory.sh") require.NoError(t, p.CreateRecoveryToken(ctx, expected)) - id, err := p.GetIdentity(ctx, expected.IdentityID) + id, err := p.GetIdentity(ctx, expected.IdentityID, identity.ExpandDefault) require.NoError(t, err) require.NoError(t, p.UpdateIdentity(ctx, id)) diff --git a/session/manager_http.go b/session/manager_http.go index 1cef8d12485..84ea691a359 100644 --- a/session/manager_http.go +++ b/session/manager_http.go @@ -9,6 +9,8 @@ import ( "net/url" "time" + "github.com/ory/x/otelx" + "github.com/ory/x/randx" "github.com/gorilla/sessions" @@ -37,6 +39,7 @@ type ( identity.ManagementProvider x.CookieProvider x.CSRFProvider + x.TracingProvider PersistenceProvider } ManagerHTTP struct { @@ -54,7 +57,10 @@ func NewManagerHTTP(r managerHTTPDependencies) *ManagerHTTP { } } -func (s *ManagerHTTP) UpsertAndIssueCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, ss *Session) error { +func (s *ManagerHTTP) UpsertAndIssueCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, ss *Session) (err error) { + ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.UpsertAndIssueCookie") + defer otelx.End(span, &err) + if err := s.r.SessionPersister().UpsertSession(ctx, ss); err != nil { return err } @@ -66,7 +72,10 @@ func (s *ManagerHTTP) UpsertAndIssueCookie(ctx context.Context, w http.ResponseW return nil } -func (s *ManagerHTTP) RefreshCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, session *Session) error { +func (s *ManagerHTTP) RefreshCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, session *Session) (err error) { + ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.RefreshCookie") + defer otelx.End(span, &err) + // If it is a session token there is nothing to do. _, cookieErr := r.Cookie(s.cookieName(r.Context())) if errors.Is(cookieErr, http.ErrNoCookie) { @@ -88,7 +97,10 @@ func (s *ManagerHTTP) RefreshCookie(ctx context.Context, w http.ResponseWriter, return nil } -func (s *ManagerHTTP) IssueCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, session *Session) error { +func (s *ManagerHTTP) IssueCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, session *Session) (err error) { + ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.IssueCookie") + defer otelx.End(span, &err) + cookie, err := s.r.CookieManager(r.Context()).Get(r, s.cookieName(ctx)) // Fix for https://github.com/ory/kratos/issues/1695 if err != nil && cookie == nil { @@ -159,6 +171,9 @@ func (s *ManagerHTTP) getCookie(r *http.Request) (*sessions.Session, error) { } func (s *ManagerHTTP) extractToken(r *http.Request) string { + _, span := s.r.Tracer(r.Context()).Tracer().Start(r.Context(), "sessions.ManagerHTTP.extractToken") + defer span.End() + if token := r.Header.Get("X-Session-Token"); len(token) > 0 { return token } @@ -178,13 +193,16 @@ func (s *ManagerHTTP) extractToken(r *http.Request) string { return token } -func (s *ManagerHTTP) FetchFromRequest(ctx context.Context, r *http.Request) (*Session, error) { +func (s *ManagerHTTP) FetchFromRequest(ctx context.Context, r *http.Request) (_ *Session, err error) { + ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.FetchFromRequest") + defer otelx.End(span, &err) + token := s.extractToken(r) if token == "" { return nil, errors.WithStack(NewErrNoCredentialsForSession()) } - se, err := s.r.SessionPersister().GetSessionByToken(ctx, token, ExpandEverything) + se, err := s.r.SessionPersister().GetSessionByToken(ctx, token, ExpandEverything, identity.ExpandDefault) if err != nil { if errors.Is(err, herodot.ErrNotFound) || errors.Is(err, sqlcon.ErrNoRows) { return nil, errors.WithStack(NewErrNoActiveSessionFound()) @@ -200,7 +218,10 @@ func (s *ManagerHTTP) FetchFromRequest(ctx context.Context, r *http.Request) (*S return se, nil } -func (s *ManagerHTTP) PurgeFromRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (s *ManagerHTTP) PurgeFromRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) { + ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.PurgeFromRequest") + defer otelx.End(span, &err) + if token, ok := bearerTokenFromRequest(r); ok { return errors.WithStack(s.r.SessionPersister().RevokeSessionByToken(ctx, token)) } @@ -222,7 +243,10 @@ func (s *ManagerHTTP) PurgeFromRequest(ctx context.Context, w http.ResponseWrite return nil } -func (s *ManagerHTTP) DoesSessionSatisfy(r *http.Request, sess *Session, requestedAAL string) error { +func (s *ManagerHTTP) DoesSessionSatisfy(r *http.Request, sess *Session, requestedAAL string) (err error) { + _, span := s.r.Tracer(r.Context()).Tracer().Start(r.Context(), "sessions.ManagerHTTP.DoesSessionSatisfy") + defer otelx.End(span, &err) + sess.SetAuthenticatorAssuranceLevel() switch requestedAAL { case string(identity.AuthenticatorAssuranceLevel1): @@ -230,19 +254,23 @@ func (s *ManagerHTTP) DoesSessionSatisfy(r *http.Request, sess *Session, request return nil } case config.HighestAvailableAAL: - i, err := s.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), sess.IdentityID) - if err != nil { - return err + i := *sess.Identity + + // If credentials are not expanded, we load them here. + if len(i.Credentials) == 0 { + if err := s.r.PrivilegedIdentityPool().HydrateIdentityAssociations(r.Context(), &i, identity.ExpandCredentials); err != nil { + return err + } } available := identity.NoAuthenticatorAssuranceLevel - if firstCount, err := s.r.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), i); err != nil { + if firstCount, err := s.r.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), &i); err != nil { return err } else if firstCount > 0 { available = identity.AuthenticatorAssuranceLevel1 } - if secondCount, err := s.r.IdentityManager().CountActiveMultiFactorCredentials(r.Context(), i); err != nil { + if secondCount, err := s.r.IdentityManager().CountActiveMultiFactorCredentials(r.Context(), &i); err != nil { return err } else if secondCount > 0 { available = identity.AuthenticatorAssuranceLevel2 @@ -255,10 +283,14 @@ func (s *ManagerHTTP) DoesSessionSatisfy(r *http.Request, sess *Session, request return NewErrAALNotSatisfied( urlx.CopyWithQuery(urlx.AppendPaths(s.r.Config().SelfPublicURL(r.Context()), "/self-service/login/browser"), url.Values{"aal": {"aal2"}}).String()) } + return errors.Errorf("requested unknown aal: %s", requestedAAL) } -func (s *ManagerHTTP) SessionAddAuthenticationMethods(ctx context.Context, sid uuid.UUID, ams ...AuthenticationMethod) error { +func (s *ManagerHTTP) SessionAddAuthenticationMethods(ctx context.Context, sid uuid.UUID, ams ...AuthenticationMethod) (err error) { + ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.SessionAddAuthenticationMethods") + defer otelx.End(span, &err) + // Since we added the method, it also means that we have authenticated it sess, err := s.r.SessionPersister().GetSession(ctx, sid, ExpandNothing) if err != nil { diff --git a/session/persistence.go b/session/persistence.go index 5c40c0890b3..ef99ec486e3 100644 --- a/session/persistence.go +++ b/session/persistence.go @@ -7,6 +7,10 @@ import ( "context" "time" + "github.com/gobuffalo/pop/v6" + + "github.com/ory/kratos/identity" + "github.com/ory/x/pagination/keysetpagination" "github.com/gofrs/uuid" @@ -17,6 +21,8 @@ type PersistenceProvider interface { } type Persister interface { + GetConnection(ctx context.Context) *pop.Connection + // GetSession retrieves a session from the store. GetSession(ctx context.Context, sid uuid.UUID, expandables Expandables) (*Session, error) @@ -39,7 +45,7 @@ type Persister interface { // // Functionality is similar to GetSession but accepts a session token // instead of a session ID. - GetSessionByToken(ctx context.Context, token string, expandables Expandables) (*Session, error) + GetSessionByToken(ctx context.Context, token string, expandables Expandables, identityExpandables identity.Expandables) (*Session, error) // DeleteExpiredSessions deletes sessions that expired before the given time. DeleteExpiredSessions(context.Context, time.Time, int) error diff --git a/session/test/persistence.go b/session/test/persistence.go index 64316c610cd..430b27c9a6a 100644 --- a/session/test/persistence.go +++ b/session/test/persistence.go @@ -102,13 +102,13 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { }) t.Run("method=get by token", func(t *testing.T) { - sess, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandEverything) + sess, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandEverything, identity.ExpandDefault) check(sess, err) checkDevices(sess.Devices, err) t.Run("on another network", func(t *testing.T) { _, p := testhelpers.NewNetwork(t, ctx, p) - _, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandNothing) + _, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandNothing, identity.ExpandDefault) assert.ErrorIs(t, err, sqlcon.ErrNoRows) }) }) @@ -117,7 +117,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { expected.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel3 require.NoError(t, p.UpsertSession(ctx, &expected)) - actual, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandDefault) + actual, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandDefault, identity.ExpandDefault) check(actual, err) assert.Equal(t, identity.AuthenticatorAssuranceLevel3, actual.AuthenticatorAssuranceLevel) }) @@ -126,7 +126,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { expected.AMR = nil require.NoError(t, p.UpsertSession(ctx, &expected)) - actual, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandDefault) + actual, err := p.GetSessionByToken(ctx, expected.Token, session.ExpandDefault, identity.ExpandDefault) check(actual, err) assert.Empty(t, actual.AMR) }) @@ -376,7 +376,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { err := other.DeleteSessionByToken(ctx, expected.Token) assert.ErrorIs(t, err, sqlcon.ErrNoRows) - _, err = p.GetSessionByToken(ctx, expected.Token, session.ExpandNothing) + _, err = p.GetSessionByToken(ctx, expected.Token, session.ExpandNothing, identity.ExpandDefault) assert.NoError(t, err) }) @@ -584,9 +584,9 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { _, err = p.GetSession(ctx, sid2, session.ExpandNothing) require.ErrorIs(t, err, sqlcon.ErrNoRows) - _, err = p.GetSessionByToken(ctx, t1, session.ExpandNothing) + _, err = p.GetSessionByToken(ctx, t1, session.ExpandNothing, identity.ExpandDefault) require.NoError(t, err) - _, err = p.GetSessionByToken(ctx, t2, session.ExpandNothing) + _, err = p.GetSessionByToken(ctx, t2, session.ExpandNothing, identity.ExpandDefault) require.ErrorIs(t, err, sqlcon.ErrNoRows) }) }