From 8ae8783935292fb011b1018ac7417ed77eb6abb7 Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Tue, 30 May 2023 13:37:55 +0200 Subject: [PATCH] fix: use RETURNING clause for batch create (#3293) --- ...t_buildInsertQueryArgs-case=cockroach.json | 15 ++ ...yValues-case=testModel-case=cockroach.json | 10 ++ persistence/sql/batch/create.go | 149 ++++++++++++++++-- persistence/sql/batch/create_test.go | 77 ++++++--- .../sql/identity/persister_identity.go | 3 +- script/testenv.sh | 2 +- 6 files changed, 222 insertions(+), 34 deletions(-) create mode 100644 persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json create mode 100644 persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json diff --git a/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json b/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json new file mode 100644 index 00000000000..4fc722f33af --- /dev/null +++ b/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json @@ -0,0 +1,15 @@ +{ + "TableName": "\"test_models\"", + "ColumnsDecl": "\"created_at\", \"id\", \"int\", \"nid\", \"null_time_ptr\", \"string\", \"traits\", \"updated_at\"", + "Columns": [ + "created_at", + "id", + "int", + "nid", + "null_time_ptr", + "string", + "traits", + "updated_at" + ], + "Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?)" +} diff --git a/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json b/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json new file mode 100644 index 00000000000..9c8e755cacd --- /dev/null +++ b/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json @@ -0,0 +1,10 @@ +[ + "0001-01-01T00:00:00Z", + "0001-01-01T00:00:00Z", + "string", + 42, + null, + { + "foo": "bar" + } +] diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 4d04331fedc..801dcbdd96b 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -5,15 +5,19 @@ package batch import ( "context" + "database/sql" "fmt" "reflect" "sort" "strings" "time" + "github.com/jmoiron/sqlx/reflectx" + + "github.com/ory/x/dbal" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - "github.com/jmoiron/sqlx/reflectx" "github.com/pkg/errors" "github.com/ory/x/otelx" @@ -38,7 +42,7 @@ type ( } ) -func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T) insertQueryArgs { +func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs { var ( v T model = pop.NewModel(v, ctx) @@ -60,8 +64,41 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T for _, col := range columns { quotedColumns = append(quotedColumns, quoter.Quote(col)) } - for range models { - placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(placeholderRow, ", "))) + + // We generate a list (for every row one) of VALUE statements here that + // will be substituted by their column values later: + // + // (?, ?, ?, ?), + // (?, ?, ?, ?), + // (?, ?, ?, ?) + for _, m := range models { + m := reflect.ValueOf(m) + + pl := make([]string, len(placeholderRow)) + copy(pl, placeholderRow) + + // There is a special case - when using CockroachDB we want to generate + // UUIDs using "gen_random_uuid()" which ends up in a VALUE statement of: + // + // (gen_random_uuid(), ?, ?, ?), + for k := range placeholderRow { + if columns[k] != "id" { + continue + } + + field := mapper.FieldByName(m, columns[k]) + val, ok := field.Interface().(uuid.UUID) + if !ok { + continue + } + + if val == uuid.Nil && dialect == dbal.DriverCockroachDB { + pl[k] = "gen_random_uuid()" + break + } + } + + placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", "))) } return insertQueryArgs{ @@ -72,12 +109,11 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T } } -func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, models []*T) (values []any, err error) { - now := time.Now().UTC().Truncate(time.Microsecond) - +func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) { for _, m := range models { m := reflect.ValueOf(m) + now := nowFunc() // Append model fields to args for _, c := range columns { field := mapper.FieldByName(m, c) @@ -89,17 +125,31 @@ func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, mo } case "updated_at": field.Set(reflect.ValueOf(now)) - case "id": if field.Interface().(uuid.UUID) != uuid.Nil { break // breaks switch, not for + } else if dialect == dbal.DriverCockroachDB { + // This is a special case: + // 1. We're using cockroach + // 2. It's the primary key field ("ID") + // 3. A UUID was not yet set. + // + // If all these conditions meet, the VALUE statement will look as such: + // + // (gen_random_uuid(), ?, ?, ?, ...) + // + // For that reason, we do not add the ID value to the list of arguments, + // because one of the arguments is using a built-in and thus doesn't need a value. + continue // break switch, not for } + id, err := uuid.NewV4() if err != nil { return nil, err } field.Set(reflect.ValueOf(id)) } + values = append(values, field.Interface()) // Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time, @@ -125,26 +175,101 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e return nil } + var v T + model := pop.NewModel(v, ctx) + conn := p.Connection quoter, ok := conn.Dialect.(quoter) if !ok { return errors.Errorf("store is not a quoter: %T", conn.Store) } - queryArgs := buildInsertQueryArgs(ctx, quoter, models) - values, err := buildInsertQueryValues(conn.TX.Mapper, queryArgs.Columns, models) + queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models) + values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) }) if err != nil { return err } + var returningClause string + if conn.Dialect.Name() != dbal.DriverMySQL { + // PostgreSQL, CockroachDB, SQLite support RETURNING. + returningClause = fmt.Sprintf("RETURNING %s", model.IDField()) + } + query := conn.Dialect.TranslateSQL(fmt.Sprintf( - "INSERT INTO %s (%s) VALUES\n%s", + "INSERT INTO %s (%s) VALUES\n%s\n%s", queryArgs.TableName, queryArgs.ColumnsDecl, queryArgs.Placeholders, + returningClause, )) - _, err = conn.Store.ExecContext(ctx, query, values...) + rows, err := conn.TX.QueryContext(ctx, query, values...) + if err != nil { + return sqlcon.HandleError(err) + } + defer rows.Close() + + // Hydrate the models from the RETURNING clause. + // + // Databases not supporting RETURNING will just return 0 rows. + count := 0 + for rows.Next() { + if err := rows.Err(); err != nil { + return sqlcon.HandleError(err) + } + + if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil { + return err + } + count++ + } + + if err := rows.Err(); err != nil { + return sqlcon.HandleError(err) + } + + if err := rows.Close(); err != nil { + return sqlcon.HandleError(err) + } return sqlcon.HandleError(err) } + +// setModelID was copy & pasted from pop. It basically sets +// the primary key to the given value read from the SQL row. +func setModelID(row *sql.Rows, model *pop.Model) error { + el := reflect.ValueOf(model.Value).Elem() + fbn := el.FieldByName("ID") + if !fbn.IsValid() { + return errors.New("model does not have a field named id") + } + + pkt, err := model.PrimaryKeyType() + if err != nil { + return errors.WithStack(err) + } + + switch pkt { + case "UUID": + var id uuid.UUID + if err := row.Scan(&id); err != nil { + return errors.WithStack(err) + } + fbn.Set(reflect.ValueOf(id)) + default: + var id interface{} + if err := row.Scan(&id); err != nil { + return errors.WithStack(err) + } + v := reflect.ValueOf(id) + switch fbn.Kind() { + case reflect.Int, reflect.Int64: + fbn.SetInt(v.Int()) + default: + fbn.Set(reflect.ValueOf(id)) + } + } + + return nil +} diff --git a/persistence/sql/batch/create_test.go b/persistence/sql/batch/create_test.go index f5d6ee47ce2..f5c81664a48 100644 --- a/persistence/sql/batch/create_test.go +++ b/persistence/sql/batch/create_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/ory/x/dbal" + "github.com/gofrs/uuid" "github.com/jmoiron/sqlx/reflectx" "github.com/stretchr/testify/assert" @@ -39,11 +41,20 @@ func (i testModel) TableName(ctx context.Context) string { func (tq testQuoter) Quote(s string) string { return fmt.Sprintf("%q", s) } +func makeModels[T any]() []*T { + models := make([]*T, 10) + for k := range models { + models[k] = new(T) + } + return models +} + func Test_buildInsertQueryArgs(t *testing.T) { ctx := context.Background() t.Run("case=testModel", func(t *testing.T) { - models := make([]*testModel, 10) - args := buildInsertQueryArgs(ctx, testQuoter{}, models) + models := makeModels[testModel]() + mapper := reflectx.NewMapper("db") + args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) snapshotx.SnapshotT(t, args) query := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n%s", args.TableName, args.ColumnsDecl, args.Placeholders) @@ -61,20 +72,35 @@ func Test_buildInsertQueryArgs(t *testing.T) { }) t.Run("case=Identities", func(t *testing.T) { - models := make([]*identity.Identity, 10) - args := buildInsertQueryArgs(ctx, testQuoter{}, models) + models := makeModels[identity.Identity]() + mapper := reflectx.NewMapper("db") + args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) snapshotx.SnapshotT(t, args) }) t.Run("case=RecoveryAddress", func(t *testing.T) { - models := make([]*identity.RecoveryAddress, 10) - args := buildInsertQueryArgs(ctx, testQuoter{}, models) + models := makeModels[identity.RecoveryAddress]() + mapper := reflectx.NewMapper("db") + args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) snapshotx.SnapshotT(t, args) }) t.Run("case=RecoveryAddress", func(t *testing.T) { - models := make([]*identity.RecoveryAddress, 10) - args := buildInsertQueryArgs(ctx, testQuoter{}, models) + models := makeModels[identity.RecoveryAddress]() + mapper := reflectx.NewMapper("db") + args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + snapshotx.SnapshotT(t, args) + }) + + t.Run("case=cockroach", func(t *testing.T) { + models := makeModels[testModel]() + for k := range models { + if k%3 == 0 { + models[k].ID = uuid.FromStringOrNil(fmt.Sprintf("ae0125a9-2786-4ada-82d2-d169cf75047%d", k)) + } + } + mapper := reflectx.NewMapper("db") + args := buildInsertQueryArgs(ctx, "cockroach", mapper, testQuoter{}, models) snapshotx.SnapshotT(t, args) }) } @@ -87,21 +113,34 @@ func Test_buildInsertQueryValues(t *testing.T) { Traits: []byte(`{"foo": "bar"}`), } mapper := reflectx.NewMapper("db") - values, err := buildInsertQueryValues(mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}) - require.NoError(t, err) - assert.NotNil(t, model.CreatedAt) - assert.Equal(t, model.CreatedAt, values[0]) + nowFunc := func() time.Time { + return time.Time{} + } + t.Run("case=cockroach", func(t *testing.T) { + values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) + require.NoError(t, err) + snapshotx.SnapshotT(t, values) + }) + + t.Run("case=others", func(t *testing.T) { + values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) + require.NoError(t, err) + + assert.NotNil(t, model.CreatedAt) + assert.Equal(t, model.CreatedAt, values[0]) + + assert.NotNil(t, model.UpdatedAt) + assert.Equal(t, model.UpdatedAt, values[1]) - assert.NotNil(t, model.UpdatedAt) - assert.Equal(t, model.UpdatedAt, values[1]) + assert.NotZero(t, model.ID) + assert.Equal(t, model.ID, values[2]) - assert.NotNil(t, model.ID) - assert.Equal(t, model.ID, values[2]) + assert.Equal(t, model.String, values[3]) + assert.Equal(t, model.Int, values[4]) - assert.Equal(t, model.String, values[3]) - assert.Equal(t, model.Int, values[4]) + assert.Nil(t, model.NullTimePtr) - assert.Nil(t, model.NullTimePtr) + }) }) } diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index a9074df8e84..4bba9e6faf4 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -297,9 +297,8 @@ func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, conn work = append(work, &id.VerifiableAddresses[i]) } } - err = batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work) - return err + return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work) } func updateAssociation[T interface { diff --git a/script/testenv.sh b/script/testenv.sh index 4a01e5b3e71..f7a5467a0f7 100755 --- a/script/testenv.sh +++ b/script/testenv.sh @@ -1,7 +1,7 @@ #!/bin/bash docker rm -f kratos_test_database_mysql kratos_test_database_postgres kratos_test_database_cockroach kratos_test_hydra || true -docker run --platform linux/amd64 --name kratos_test_database_mysql -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.23 +docker run --platform linux/amd64 --name kratos_test_database_mysql -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.26 docker run --platform linux/amd64 --name kratos_test_database_postgres -p 3445:5432 -e POSTGRES_PASSWORD=secret -e POSTGRES_DB=postgres -d postgres:11.8 postgres -c log_statement=all docker run --platform linux/amd64 --name kratos_test_database_cockroach -p 3446:26257 -p 3447:8080 -d cockroachdb/cockroach:v22.2.6 start-single-node --insecure docker run --platform linux/amd64 --name kratos_test_hydra -p 4444:4444 -p 4445:4445 -d -e DSN=memory -e URLS_SELF_ISSUER=http://localhost:4444/ -e URLS_LOGIN=http://localhost:4446/login -e URLS_CONSENT=http://localhost:4446/consent oryd/hydra:v2.0.2 serve all --dev