Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use RETURNING clause for batch create #3293

Merged
merged 9 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[
"0001-01-01T00:00:00Z",
"0001-01-01T00:00:00Z",
"gen_random_uuid()",
"string",
42,
null,
{
"foo": "bar"
}
]
103 changes: 91 additions & 12 deletions persistence/sql/batch/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ package batch

import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"time"

"github.com/jmoiron/sqlx/reflectx"

"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/pkg/errors"

"github.com/ory/x/otelx"
Expand Down Expand Up @@ -72,12 +76,11 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T
}
}

func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, models []*T) (values []any, err error) {
now := time.Now().UTC().Truncate(time.Microsecond)

func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) {
for _, m := range models {
m := reflect.ValueOf(m)

now := nowFunc()
// Append model fields to args
for _, c := range columns {
field := mapper.FieldByName(m, c)
Expand All @@ -89,17 +92,23 @@ func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, mo
}
case "updated_at":
field.Set(reflect.ValueOf(now))

case "id":
if field.Interface().(uuid.UUID) != uuid.Nil {
break // breaks switch, not for
}
id, err := uuid.NewV4()
if err != nil {
return nil, err

if dialect == dbal.DriverCockroachDB {
values = append(values, "gen_random_uuid()")
continue
} else {
id, err := uuid.NewV4()
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(id))
}
field.Set(reflect.ValueOf(id))
}

values = append(values, field.Interface())

// Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time,
Expand All @@ -125,26 +134,96 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e
return nil
}

var v T
model := pop.NewModel(v, ctx)

conn := p.Connection
quoter, ok := conn.Dialect.(quoter)
if !ok {
return errors.Errorf("store is not a quoter: %T", conn.Store)
}

queryArgs := buildInsertQueryArgs(ctx, quoter, models)
values, err := buildInsertQueryValues(conn.TX.Mapper, queryArgs.Columns, models)
values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) })
if err != nil {
return err
}

var returningClause string
if conn.Dialect.Name() != dbal.DriverMySQL {
// PostgreSQL, CockroachDB, SQLite support RETURNING.
returningClause = fmt.Sprintf("RETURNING %s", model.IDField())
}

query := conn.Dialect.TranslateSQL(fmt.Sprintf(
"INSERT INTO %s (%s) VALUES\n%s",
"INSERT INTO %s (%s) VALUES\n%s\n%s",
queryArgs.TableName,
queryArgs.ColumnsDecl,
queryArgs.Placeholders,
returningClause,
))

_, err = conn.Store.ExecContext(ctx, query, values...)
rows, err := conn.TX.QueryContext(ctx, query, values...)
if err != nil {
return sqlcon.HandleError(err)
}
defer rows.Close()

count := 0
for rows.Next() {
if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil {
return err
}
count++
}

if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := rows.Close(); err != nil {
return sqlcon.HandleError(err)
}

return sqlcon.HandleError(err)
}

func setModelID(row *sql.Rows, model *pop.Model) error {
el := reflect.ValueOf(model.Value).Elem()
fbn := el.FieldByName("ID")
if !fbn.IsValid() {
return errors.New("model does not have a field named id")
}

pkt, err := model.PrimaryKeyType()
if err != nil {
return errors.WithStack(err)
}

switch pkt {
case "UUID":
var id uuid.UUID
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
fbn.Set(reflect.ValueOf(id))
default:
var id interface{}
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
v := reflect.ValueOf(id)
switch fbn.Kind() {
case reflect.Int, reflect.Int64:
fbn.SetInt(v.Int())
default:
fbn.Set(reflect.ValueOf(id))
}
}

return nil
}
37 changes: 26 additions & 11 deletions persistence/sql/batch/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"testing"
"time"

"github.com/ory/x/dbal"

"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -87,21 +89,34 @@ func Test_buildInsertQueryValues(t *testing.T) {
Traits: []byte(`{"foo": "bar"}`),
}
mapper := reflectx.NewMapper("db")
values, err := buildInsertQueryValues(mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model})
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])
nowFunc := func() time.Time {
return time.Time{}
}
t.Run("case=cockroach", func(t *testing.T) {
values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)
snapshotx.SnapshotT(t, values)
})

t.Run("case=others", func(t *testing.T) {
values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])
assert.NotNil(t, model.ID)
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
assert.Equal(t, model.ID, values[2])

assert.NotNil(t, model.ID)
assert.Equal(t, model.ID, values[2])
assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])

assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])
assert.Nil(t, model.NullTimePtr)

assert.Nil(t, model.NullTimePtr)
})
})
}
3 changes: 1 addition & 2 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,8 @@ func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, conn
work = append(work, &id.VerifiableAddresses[i])
}
}
err = batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)

return err
return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)
}

func updateAssociation[T interface {
Expand Down