Skip to content

Commit

Permalink
fix: use RETURNING clause for batch create (#3293)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed May 30, 2023
1 parent 0f3cf22 commit 8ae8783
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 34 deletions.
@@ -0,0 +1,15 @@
{
"TableName": "\"test_models\"",
"ColumnsDecl": "\"created_at\", \"id\", \"int\", \"nid\", \"null_time_ptr\", \"string\", \"traits\", \"updated_at\"",
"Columns": [
"created_at",
"id",
"int",
"nid",
"null_time_ptr",
"string",
"traits",
"updated_at"
],
"Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?)"
}
@@ -0,0 +1,10 @@
[
"0001-01-01T00:00:00Z",
"0001-01-01T00:00:00Z",
"string",
42,
null,
{
"foo": "bar"
}
]
149 changes: 137 additions & 12 deletions persistence/sql/batch/create.go
Expand Up @@ -5,15 +5,19 @@ package batch

import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"time"

"github.com/jmoiron/sqlx/reflectx"

"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/pkg/errors"

"github.com/ory/x/otelx"
Expand All @@ -38,7 +42,7 @@ type (
}
)

func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T) insertQueryArgs {
func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs {
var (
v T
model = pop.NewModel(v, ctx)
Expand All @@ -60,8 +64,41 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T
for _, col := range columns {
quotedColumns = append(quotedColumns, quoter.Quote(col))
}
for range models {
placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(placeholderRow, ", ")))

// We generate a list (for every row one) of VALUE statements here that
// will be substituted by their column values later:
//
// (?, ?, ?, ?),
// (?, ?, ?, ?),
// (?, ?, ?, ?)
for _, m := range models {
m := reflect.ValueOf(m)

pl := make([]string, len(placeholderRow))
copy(pl, placeholderRow)

// There is a special case - when using CockroachDB we want to generate
// UUIDs using "gen_random_uuid()" which ends up in a VALUE statement of:
//
// (gen_random_uuid(), ?, ?, ?),
for k := range placeholderRow {
if columns[k] != "id" {
continue
}

field := mapper.FieldByName(m, columns[k])
val, ok := field.Interface().(uuid.UUID)
if !ok {
continue
}

if val == uuid.Nil && dialect == dbal.DriverCockroachDB {
pl[k] = "gen_random_uuid()"
break
}
}

placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", ")))
}

return insertQueryArgs{
Expand All @@ -72,12 +109,11 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T
}
}

func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, models []*T) (values []any, err error) {
now := time.Now().UTC().Truncate(time.Microsecond)

func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) {
for _, m := range models {
m := reflect.ValueOf(m)

now := nowFunc()
// Append model fields to args
for _, c := range columns {
field := mapper.FieldByName(m, c)
Expand All @@ -89,17 +125,31 @@ func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, mo
}
case "updated_at":
field.Set(reflect.ValueOf(now))

case "id":
if field.Interface().(uuid.UUID) != uuid.Nil {
break // breaks switch, not for
} else if dialect == dbal.DriverCockroachDB {
// This is a special case:
// 1. We're using cockroach
// 2. It's the primary key field ("ID")
// 3. A UUID was not yet set.
//
// If all these conditions meet, the VALUE statement will look as such:
//
// (gen_random_uuid(), ?, ?, ?, ...)
//
// For that reason, we do not add the ID value to the list of arguments,
// because one of the arguments is using a built-in and thus doesn't need a value.
continue // break switch, not for
}

id, err := uuid.NewV4()
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(id))
}

values = append(values, field.Interface())

// Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time,
Expand All @@ -125,26 +175,101 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e
return nil
}

var v T
model := pop.NewModel(v, ctx)

conn := p.Connection
quoter, ok := conn.Dialect.(quoter)
if !ok {
return errors.Errorf("store is not a quoter: %T", conn.Store)
}

queryArgs := buildInsertQueryArgs(ctx, quoter, models)
values, err := buildInsertQueryValues(conn.TX.Mapper, queryArgs.Columns, models)
queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models)
values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) })
if err != nil {
return err
}

var returningClause string
if conn.Dialect.Name() != dbal.DriverMySQL {
// PostgreSQL, CockroachDB, SQLite support RETURNING.
returningClause = fmt.Sprintf("RETURNING %s", model.IDField())
}

query := conn.Dialect.TranslateSQL(fmt.Sprintf(
"INSERT INTO %s (%s) VALUES\n%s",
"INSERT INTO %s (%s) VALUES\n%s\n%s",
queryArgs.TableName,
queryArgs.ColumnsDecl,
queryArgs.Placeholders,
returningClause,
))

_, err = conn.Store.ExecContext(ctx, query, values...)
rows, err := conn.TX.QueryContext(ctx, query, values...)
if err != nil {
return sqlcon.HandleError(err)
}
defer rows.Close()

// Hydrate the models from the RETURNING clause.
//
// Databases not supporting RETURNING will just return 0 rows.
count := 0
for rows.Next() {
if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil {
return err
}
count++
}

if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := rows.Close(); err != nil {
return sqlcon.HandleError(err)
}

return sqlcon.HandleError(err)
}

// setModelID was copy & pasted from pop. It basically sets
// the primary key to the given value read from the SQL row.
func setModelID(row *sql.Rows, model *pop.Model) error {
el := reflect.ValueOf(model.Value).Elem()
fbn := el.FieldByName("ID")
if !fbn.IsValid() {
return errors.New("model does not have a field named id")
}

pkt, err := model.PrimaryKeyType()
if err != nil {
return errors.WithStack(err)
}

switch pkt {
case "UUID":
var id uuid.UUID
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
fbn.Set(reflect.ValueOf(id))
default:
var id interface{}
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
v := reflect.ValueOf(id)
switch fbn.Kind() {
case reflect.Int, reflect.Int64:
fbn.SetInt(v.Int())
default:
fbn.Set(reflect.ValueOf(id))
}
}

return nil
}
77 changes: 58 additions & 19 deletions persistence/sql/batch/create_test.go
Expand Up @@ -9,6 +9,8 @@ import (
"testing"
"time"

"github.com/ory/x/dbal"

"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -39,11 +41,20 @@ func (i testModel) TableName(ctx context.Context) string {

func (tq testQuoter) Quote(s string) string { return fmt.Sprintf("%q", s) }

func makeModels[T any]() []*T {
models := make([]*T, 10)
for k := range models {
models[k] = new(T)
}
return models
}

func Test_buildInsertQueryArgs(t *testing.T) {
ctx := context.Background()
t.Run("case=testModel", func(t *testing.T) {
models := make([]*testModel, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[testModel]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)

query := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n%s", args.TableName, args.ColumnsDecl, args.Placeholders)
Expand All @@ -61,20 +72,35 @@ func Test_buildInsertQueryArgs(t *testing.T) {
})

t.Run("case=Identities", func(t *testing.T) {
models := make([]*identity.Identity, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[identity.Identity]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})

t.Run("case=RecoveryAddress", func(t *testing.T) {
models := make([]*identity.RecoveryAddress, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[identity.RecoveryAddress]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})

t.Run("case=RecoveryAddress", func(t *testing.T) {
models := make([]*identity.RecoveryAddress, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[identity.RecoveryAddress]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})

t.Run("case=cockroach", func(t *testing.T) {
models := makeModels[testModel]()
for k := range models {
if k%3 == 0 {
models[k].ID = uuid.FromStringOrNil(fmt.Sprintf("ae0125a9-2786-4ada-82d2-d169cf75047%d", k))
}
}
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "cockroach", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})
}
Expand All @@ -87,21 +113,34 @@ func Test_buildInsertQueryValues(t *testing.T) {
Traits: []byte(`{"foo": "bar"}`),
}
mapper := reflectx.NewMapper("db")
values, err := buildInsertQueryValues(mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model})
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])
nowFunc := func() time.Time {
return time.Time{}
}
t.Run("case=cockroach", func(t *testing.T) {
values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)
snapshotx.SnapshotT(t, values)
})

t.Run("case=others", func(t *testing.T) {
values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])
assert.NotZero(t, model.ID)
assert.Equal(t, model.ID, values[2])

assert.NotNil(t, model.ID)
assert.Equal(t, model.ID, values[2])
assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])

assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])
assert.Nil(t, model.NullTimePtr)

assert.Nil(t, model.NullTimePtr)
})
})
}
3 changes: 1 addition & 2 deletions persistence/sql/identity/persister_identity.go
Expand Up @@ -297,9 +297,8 @@ func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, conn
work = append(work, &id.VerifiableAddresses[i])
}
}
err = batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)

return err
return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)
}

func updateAssociation[T interface {
Expand Down
2 changes: 1 addition & 1 deletion script/testenv.sh
@@ -1,7 +1,7 @@
#!/bin/bash

docker rm -f kratos_test_database_mysql kratos_test_database_postgres kratos_test_database_cockroach kratos_test_hydra || true
docker run --platform linux/amd64 --name kratos_test_database_mysql -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.23
docker run --platform linux/amd64 --name kratos_test_database_mysql -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.26
docker run --platform linux/amd64 --name kratos_test_database_postgres -p 3445:5432 -e POSTGRES_PASSWORD=secret -e POSTGRES_DB=postgres -d postgres:11.8 postgres -c log_statement=all
docker run --platform linux/amd64 --name kratos_test_database_cockroach -p 3446:26257 -p 3447:8080 -d cockroachdb/cockroach:v22.2.6 start-single-node --insecure
docker run --platform linux/amd64 --name kratos_test_hydra -p 4444:4444 -p 4445:4445 -d -e DSN=memory -e URLS_SELF_ISSUER=http://localhost:4444/ -e URLS_LOGIN=http://localhost:4446/login -e URLS_CONSENT=http://localhost:4446/consent oryd/hydra:v2.0.2 serve all --dev
Expand Down

0 comments on commit 8ae8783

Please sign in to comment.