Skip to content

Commit

Permalink
fix: use RETURNING clause for batch create
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed May 26, 2023
1 parent 61cb722 commit d01ce4c
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 25 deletions.
@@ -0,0 +1,11 @@
[
"0001-01-01T00:00:00Z",
"0001-01-01T00:00:00Z",
"gen_random_uuid()",
"string",
42,
null,
{
"foo": "bar"
}
]
103 changes: 91 additions & 12 deletions persistence/sql/batch/create.go
Expand Up @@ -5,15 +5,19 @@ package batch

import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"time"

"github.com/jmoiron/sqlx/reflectx"

"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/pkg/errors"

"github.com/ory/x/otelx"
Expand Down Expand Up @@ -72,12 +76,11 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T
}
}

func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, models []*T) (values []any, err error) {
now := time.Now().UTC().Truncate(time.Microsecond)

func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) {
for _, m := range models {
m := reflect.ValueOf(m)

now := nowFunc()
// Append model fields to args
for _, c := range columns {
field := mapper.FieldByName(m, c)
Expand All @@ -89,17 +92,23 @@ func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, mo
}
case "updated_at":
field.Set(reflect.ValueOf(now))

case "id":
if field.Interface().(uuid.UUID) != uuid.Nil {
break // breaks switch, not for
}
id, err := uuid.NewV4()
if err != nil {
return nil, err

if dialect == dbal.DriverCockroachDB {
values = append(values, "gen_random_uuid()")
continue
} else {
id, err := uuid.NewV4()
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(id))
}
field.Set(reflect.ValueOf(id))
}

values = append(values, field.Interface())

// Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time,
Expand All @@ -125,26 +134,96 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e
return nil
}

var v T
model := pop.NewModel(v, ctx)

conn := p.Connection
quoter, ok := conn.Dialect.(quoter)
if !ok {
return errors.Errorf("store is not a quoter: %T", conn.Store)
}

queryArgs := buildInsertQueryArgs(ctx, quoter, models)
values, err := buildInsertQueryValues(conn.TX.Mapper, queryArgs.Columns, models)
values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) })
if err != nil {
return err
}

var returningClause string
if conn.Dialect.Name() != dbal.DriverMySQL {
// PostgreSQL, CockroachDB, SQLite support RETURNING.
returningClause = fmt.Sprintf("RETURNING %s", model.IDField())
}

query := conn.Dialect.TranslateSQL(fmt.Sprintf(
"INSERT INTO %s (%s) VALUES\n%s",
"INSERT INTO %s (%s) VALUES\n%s\n%s",
queryArgs.TableName,
queryArgs.ColumnsDecl,
queryArgs.Placeholders,
returningClause,
))

_, err = conn.Store.ExecContext(ctx, query, values...)
rows, err := conn.TX.QueryContext(ctx, query, values...)
if err != nil {
return sqlcon.HandleError(err)
}
defer rows.Close()

count := 0
for rows.Next() {
if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil {
return err
}
count++
}

if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := rows.Close(); err != nil {
return sqlcon.HandleError(err)
}

return sqlcon.HandleError(err)
}

func setModelID(row *sql.Rows, model *pop.Model) error {
el := reflect.ValueOf(model.Value).Elem()
fbn := el.FieldByName("ID")
if !fbn.IsValid() {
return errors.New("model does not have a field named id")
}

pkt, err := model.PrimaryKeyType()
if err != nil {
return errors.WithStack(err)
}

switch pkt {
case "UUID":
var id uuid.UUID
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
fbn.Set(reflect.ValueOf(id))
default:
var id interface{}
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
v := reflect.ValueOf(id)
switch fbn.Kind() {
case reflect.Int, reflect.Int64:
fbn.SetInt(v.Int())
default:
fbn.Set(reflect.ValueOf(id))
}
}

return nil
}
37 changes: 26 additions & 11 deletions persistence/sql/batch/create_test.go
Expand Up @@ -9,6 +9,8 @@ import (
"testing"
"time"

"github.com/ory/x/dbal"

"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -87,21 +89,34 @@ func Test_buildInsertQueryValues(t *testing.T) {
Traits: []byte(`{"foo": "bar"}`),
}
mapper := reflectx.NewMapper("db")
values, err := buildInsertQueryValues(mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model})
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])
nowFunc := func() time.Time {
return time.Time{}
}
t.Run("case=cockroach", func(t *testing.T) {
values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)
snapshotx.SnapshotT(t, values)
})

t.Run("case=others", func(t *testing.T) {
values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])
assert.NotNil(t, model.ID)
assert.Equal(t, model.ID, values[2])

assert.NotNil(t, model.ID)
assert.Equal(t, model.ID, values[2])
assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])

assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])
assert.Nil(t, model.NullTimePtr)

assert.Nil(t, model.NullTimePtr)
})
})
}
3 changes: 1 addition & 2 deletions persistence/sql/identity/persister_identity.go
Expand Up @@ -297,9 +297,8 @@ func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, conn
work = append(work, &id.VerifiableAddresses[i])
}
}
err = batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)

return err
return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)
}

func updateAssociation[T interface {
Expand Down

0 comments on commit d01ce4c

Please sign in to comment.