Skip to content

Commit

Permalink
Merge pull request #274 from uptrace/feat/where-pk-cols
Browse files Browse the repository at this point in the history
feat: accept columns in WherePK
  • Loading branch information
vmihailenco committed Oct 27, 2021
2 parents 6601d13 + b3e7035 commit 9a44f93
Show file tree
Hide file tree
Showing 15 changed files with 100 additions and 35 deletions.
14 changes: 14 additions & 0 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,20 @@ func TestQuery(t *testing.T) {
IfNotExists().
ColumnExpr("column_name VARCHAR(123)")
},
func(db *bun.DB) schema.QueryAppender {
models := []Model{
{ID: 1},
{ID: 2},
}
return db.NewSelect().Model(&models).WherePK()
},
func(db *bun.DB) schema.QueryAppender {
models := []Model{
{ID: 1, Str: "hello"},
{ID: 2, Str: "world"},
}
return db.NewSelect().Model(&models).WherePK("id", "str")
},
}

timeRE := regexp.MustCompile(`'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?'`)
Expand Down
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-100
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE (`model`.`id`, `model`.`str`) IN ((1, 'hello'), (2, 'world'))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-99
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE `model`.`id` IN (1, 2)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-100
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE (`model`.`id`, `model`.`str`) IN ((1, 'hello'), (2, 'world'))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-99
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE `model`.`id` IN (1, 2)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-100
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE ("model"."id", "model"."str") IN ((1, 'hello'), (2, 'world'))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-99
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE "model"."id" IN (1, 2)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-100
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE ("model"."id", "model"."str") IN ((1, 'hello'), (2, 'world'))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-99
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE "model"."id" IN (1, 2)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-100
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE ("model"."id", "model"."str") IN ((1, 'hello'), (2, 'world'))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-99
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE "model"."id" IN (1, 2)
99 changes: 70 additions & 29 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ import (
)

const (
wherePKFlag internal.Flag = 1 << iota
forceDeleteFlag
forceDeleteFlag internal.Flag = 1 << iota
deletedFlag
allWithDeletedFlag
)
Expand Down Expand Up @@ -580,7 +579,8 @@ func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) s
type whereBaseQuery struct {
baseQuery

where []schema.QueryWithSep
where []schema.QueryWithSep
whereFields []*schema.Field
}

func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) {
Expand All @@ -601,10 +601,46 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep)
q.addWhere(schema.SafeQueryWithSep("", nil, ")"))
}

func (q *whereBaseQuery) addWhereCols(cols []string) {
if q.table == nil {
err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
q.setErr(err)
return
}

var fields []*schema.Field

if cols == nil {
if err := q.table.CheckPKs(); err != nil {
q.setErr(err)
return
}
fields = q.table.PKs
} else {
fields = make([]*schema.Field, len(cols))
for i, col := range cols {
field, err := q.table.Field(col)
if err != nil {
q.setErr(err)
return
}
fields[i] = field
}
}

if q.whereFields != nil {
err := errors.New("bun: WherePK can only be called once")
q.setErr(err)
return
}

q.whereFields = fields
}

func (q *whereBaseQuery) mustAppendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
if len(q.where) == 0 && !q.flags.Has(wherePKFlag) {
if len(q.where) == 0 && q.whereFields == nil {
err := errors.New("bun: Update and Delete queries require at least one Where")
return nil, err
}
Expand All @@ -614,7 +650,7 @@ func (q *whereBaseQuery) mustAppendWhere(
func (q *whereBaseQuery) appendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) {
if len(q.where) == 0 && q.whereFields == nil && !q.isSoftDelete() {
return b, nil
}

Expand Down Expand Up @@ -656,11 +692,11 @@ func (q *whereBaseQuery) appendWhere(
}
}

if q.flags.Has(wherePKFlag) {
if q.whereFields != nil {
if len(b) > startLen {
b = append(b, " AND "...)
}
b, err = q.appendWherePK(fmter, b, withAlias)
b, err = q.appendWhereFields(fmter, b, q.whereFields, withAlias)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -691,37 +727,38 @@ func appendWhere(
return b, nil
}

func (q *whereBaseQuery) appendWherePK(
fmter schema.Formatter, b []byte, withAlias bool,
func (q *whereBaseQuery) appendWhereFields(
fmter schema.Formatter, b []byte, fields []*schema.Field, withAlias bool,
) (_ []byte, err error) {
if q.table == nil {
err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
return nil, err
}
if err := q.table.CheckPKs(); err != nil {
err := fmt.Errorf("bun: got %T, but WherePK requires struct or slice-based model", q.model)
return nil, err
}

switch model := q.tableModel.(type) {
case *structTableModel:
return q.appendWherePKStruct(fmter, b, model, withAlias)
return q.appendWhereStructFields(fmter, b, model, fields, withAlias)
case *sliceTableModel:
return q.appendWherePKSlice(fmter, b, model, withAlias)
return q.appendWhereSliceFields(fmter, b, model, fields, withAlias)
default:
return nil, fmt.Errorf("bun: WhereColumn does not support %T", q.tableModel)
}

return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel)
}

func (q *whereBaseQuery) appendWherePKStruct(
fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool,
func (q *whereBaseQuery) appendWhereStructFields(
fmter schema.Formatter,
b []byte,
model *structTableModel,
fields []*schema.Field,
withAlias bool,
) (_ []byte, err error) {
if !model.strct.IsValid() {
return nil, errNilModel
}

isTemplate := fmter.IsNop()
b = append(b, '(')
for i, f := range q.table.PKs {
for i, f := range fields {
if i > 0 {
b = append(b, " AND "...)
}
Expand All @@ -741,18 +778,22 @@ func (q *whereBaseQuery) appendWherePKStruct(
return b, nil
}

func (q *whereBaseQuery) appendWherePKSlice(
fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool,
func (q *whereBaseQuery) appendWhereSliceFields(
fmter schema.Formatter,
b []byte,
model *sliceTableModel,
fields []*schema.Field,
withAlias bool,
) (_ []byte, err error) {
if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, '(')
}
if withAlias {
b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
b = appendColumns(b, q.table.SQLAlias, fields)
} else {
b = appendColumns(b, "", q.table.PKs)
b = appendColumns(b, "", fields)
}
if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, ')')
}

Expand All @@ -771,10 +812,10 @@ func (q *whereBaseQuery) appendWherePKSlice(

el := indirect(slice.Index(i))

if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, '(')
}
for i, f := range q.table.PKs {
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
Expand All @@ -784,7 +825,7 @@ func (q *whereBaseQuery) appendWherePKSlice(
b = f.AppendValue(fmter, b, el)
}
}
if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, ')')
}
}
Expand Down
4 changes: 2 additions & 2 deletions query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQ

//------------------------------------------------------------------------------

func (q *DeleteQuery) WherePK() *DeleteQuery {
q.flags = q.flags.Set(wherePKFlag)
func (q *DeleteQuery) WherePK(cols ...string) *DeleteQuery {
q.addWhereCols(cols)
return q
}

Expand Down
4 changes: 2 additions & 2 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery {

//------------------------------------------------------------------------------

func (q *SelectQuery) WherePK() *SelectQuery {
q.flags = q.flags.Set(wherePKFlag)
func (q *SelectQuery) WherePK(cols ...string) *SelectQuery {
q.addWhereCols(cols)
return q
}

Expand Down
4 changes: 2 additions & 2 deletions query_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ func (q *UpdateQuery) OmitZero() *UpdateQuery {

//------------------------------------------------------------------------------

func (q *UpdateQuery) WherePK() *UpdateQuery {
q.flags = q.flags.Set(wherePKFlag)
func (q *UpdateQuery) WherePK(cols ...string) *UpdateQuery {
q.addWhereCols(cols)
return q
}

Expand Down

0 comments on commit 9a44f93

Please sign in to comment.