Skip to content

Commit

Permalink
Fix handling of slice parameters. Closes #279
Browse files Browse the repository at this point in the history
  • Loading branch information
xiam committed Oct 21, 2016
1 parent 9e3ad92 commit ddf461a
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 60 deletions.
15 changes: 15 additions & 0 deletions internal/sqladapter/exql/order_by_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ func TestOrderBy(t *testing.T) {
}
}

func TestOrderByRaw(t *testing.T) {
o := JoinWithOrderBy(
JoinSortColumns(
&SortColumn{Column: RawValue("CASE WHEN id IN ? THEN 0 ELSE 1 END")},
),
)

s := o.Compile(defaultTemplate)
e := `ORDER BY CASE WHEN id IN ? THEN 0 ELSE 1 END`

if trim(s) != e {
t.Fatalf("Got: %s, Expecting: %s", s, e)
}
}

func TestOrderByDesc(t *testing.T) {
o := JoinWithOrderBy(
JoinSortColumns(
Expand Down
101 changes: 90 additions & 11 deletions lib/sqlbuilder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,71 @@ func TestSelect(t *testing.T) {
b.SelectFrom("artist").String(),
)

{
rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000, 2000})
sel := b.SelectFrom("artist").OrderBy(rawCase)
assert.Equal(
`SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`,
sel.String(),
)
assert.Equal(
[]interface{}{1000, 2000},
sel.Arguments(),
)
}

{
rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000})
sel := b.SelectFrom("artist").OrderBy(rawCase)
assert.Equal(
`SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1) THEN 0 ELSE 1 END`,
sel.String(),
)
assert.Equal(
[]interface{}{1000},
sel.Arguments(),
)
}

{
rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{})
sel := b.SelectFrom("artist").OrderBy(rawCase)
assert.Equal(
`SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`,
sel.String(),
)
assert.Equal(
[]interface{}(nil),
sel.Arguments(),
)
}

{
rawCase := db.Raw("CASE WHEN id IN (NULL) THEN 0 ELSE 1 END")
sel := b.SelectFrom("artist").OrderBy(rawCase)
assert.Equal(
`SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`,
sel.String(),
)
assert.Equal(
[]interface{}(nil),
rawCase.Arguments(),
)
}

{
rawCase := db.Raw("CASE WHEN id IN (?, ?) THEN 0 ELSE 1 END", 1000, 2000)
sel := b.SelectFrom("artist").OrderBy(rawCase)
assert.Equal(
`SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`,
sel.String(),
)
assert.Equal(
[]interface{}{1000, 2000},
rawCase.Arguments(),
)
}

{
sel := b.Select(db.Func("DISTINCT", "name")).From("artist")
assert.Equal(
Expand Down Expand Up @@ -49,15 +114,29 @@ func TestSelect(t *testing.T) {
b.Select().From("artist").Where(db.Cond{1: db.Func("ANY", db.Raw("column"))}).String(),
)

assert.Equal(
`SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`,
b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{0, -1}}).String(),
)
{
q := b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{0, -1}})
assert.Equal(
`SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`,
q.String(),
)
assert.Equal(
[]interface{}{0, -1},
q.Arguments(),
)
}

assert.Equal(
`SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`,
b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{-1}}).String(),
)
{
q := b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{-1}})
assert.Equal(
`SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`,
q.String(),
)
assert.Equal(
[]interface{}{-1},
q.Arguments(),
)
}

assert.Equal(
`SELECT * FROM "artist" WHERE ("id" IN ($1, $2))`,
Expand Down Expand Up @@ -288,7 +367,7 @@ func TestSelect(t *testing.T) {
)

assert.Equal(
`SELECT * FROM "artist" WHERE ("id" IS NULL)`,
`SELECT * FROM "artist" WHERE ("id" IN (NULL))`,
b.SelectFrom("artist").Where(db.Cond{"id": []int64{}}).String(),
)

Expand Down Expand Up @@ -671,7 +750,7 @@ func TestUpdate(t *testing.T) {
idSlice := []int64{}
q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice})
assert.Equal(
`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`,
`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`,
q.String(),
)
assert.Equal(
Expand All @@ -684,7 +763,7 @@ func TestUpdate(t *testing.T) {
idSlice := []int64{}
q := b.Update("artist").Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}).Set(db.Cond{"some_column": 10})
assert.Equal(
`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`,
`UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`,
q.String(),
)
assert.Equal(
Expand Down
98 changes: 55 additions & 43 deletions lib/sqlbuilder/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,42 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils {

func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) {
argn := 0
argx := make([]interface{}, 0, len(args))
for i := 0; i < len(in); i++ {
if in[i] == '?' {
if len(args) > argn { // we have arguments to match.
u := toInterfaceArguments(args[argn])
if len(args) > argn {
k := `?`

if len(u) > 1 {
// An array of arguments
k = `(?` + strings.Repeat(`, ?`, len(u)-1) + `)`
} else if len(u) == 1 {
if rawValue, ok := u[0].(db.RawValue); ok {
k = rawValue.Raw()
u = []interface{}{}
values, isSlice := toInterfaceArguments(args[argn])
if isSlice {
if len(values) == 0 {
k = `(NULL)`
} else {
k = `(?` + strings.Repeat(`, ?`, len(values)-1) + `)`
}
} else {
if len(values) == 1 {
if rawValue, ok := values[0].(db.RawValue); ok {
k, values = rawValue.Raw(), nil
}
} else if len(values) == 0 {
k = `NULL`
}
}

lk := len(k)
if lk > 1 {
if k != `?` {
in = in[:i] + k + in[i+1:]
i += len(k) - 1
}
args = append(args[:argn], append(u, args[argn+1:]...)...)
argn += len(u)

if len(values) > 0 {
argx = append(argx, values...)
}
argn++
}
}
}
return in, args
return in, argx
}

// ToWhereWithArguments converts the given parameters into a exql.Where
Expand Down Expand Up @@ -154,7 +163,7 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []
fnName := t.Name()
fnArgs := []interface{}{}

args := toInterfaceArguments(t.Arguments())
args, _ := toInterfaceArguments(t.Arguments())
fragments := []string{}
for i := range args {
frag, args := tu.PlaceholderValue(args[i])
Expand All @@ -169,33 +178,30 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []
}

// toInterfaceArguments converts the given value into an array of interfaces.
func toInterfaceArguments(value interface{}) (args []interface{}) {
func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) {
v := reflect.ValueOf(value)

if value == nil {
return nil
return nil, false
}

v := reflect.ValueOf(value)

switch v.Type().Kind() {
case reflect.Slice:
if v.Type().Kind() == reflect.Slice {
var i, total int

// Byte slice gets transformed into a string.
if v.Type().Elem().Kind() == reflect.Uint8 {
return []interface{}{string(value.([]byte))}
return []interface{}{string(value.([]byte))}, false
}

total = v.Len()
if total > 0 {
args = make([]interface{}, total)
for i = 0; i < total; i++ {
args[i] = v.Index(i).Interface()
}
return args
args = make([]interface{}, total)
for i = 0; i < total; i++ {
args[i] = v.Index(i).Interface()
}
return nil
default:
args = []interface{}{value}
return args, true
}

return args
return []interface{}{value}, false
}

// ToColumnValues converts the given conditions into a exql.ColumnValues struct.
Expand Down Expand Up @@ -265,35 +271,41 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal
// A function with one or more arguments.
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
}
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
expanded, fnArgs := expandPlaceholders(fnName, fnArgs...)
columnValue.Value = exql.RawValue(expanded)
args = append(args, fnArgs...)
case db.RawValue:
expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()...)
columnValue.Value = exql.RawValue(expanded)
args = append(args, rawArgs...)
default:
v := toInterfaceArguments(value)
v, isSlice := toInterfaceArguments(value)

if v == nil {
// Nil value given.
columnValue.Value = sqlNull
if isSlice {
if columnValue.Operator == "" {
columnValue.Operator = sqlIsOperator
columnValue.Operator = sqlInOperator
}
} else {
if len(v) > 1 || reflect.TypeOf(value).Kind() == reflect.Slice {
if len(v) > 0 {
// Array value given.
columnValue.Value = exql.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1)))
} else {
// Single value given.
columnValue.Value = exql.RawValue(`(NULL)`)
}
args = append(args, v...)
} else {
if v == nil {
// Nil value given.
columnValue.Value = sqlNull
if columnValue.Operator == "" {
columnValue.Operator = sqlInOperator
columnValue.Operator = sqlIsOperator
}
} else {
// Single value given.
columnValue.Value = sqlPlaceholder
args = append(args, v...)
}
args = append(args, v...)
}

}

// Using guessed operator if no operator was given.
Expand Down
2 changes: 1 addition & 1 deletion lib/sqlbuilder/placeholder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestPlaceholderArray(t *testing.T) {

{
ret, _ := expandPlaceholders("??", []interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{})
assert.Equal(t, "(?, ?, ?)?", ret)
assert.Equal(t, "(?, ?, ?)(NULL)", ret)
}
}

Expand Down
12 changes: 7 additions & 5 deletions lib/sqlbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type selector struct {
groupBy *exql.GroupBy
groupByArgs []interface{}

orderBy exql.OrderBy
orderBy *exql.OrderBy
orderByArgs []interface{}

limit exql.Limit
Expand Down Expand Up @@ -161,7 +161,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector {
Column: exql.RawValue(col),
}
qs.mu.Lock()
qs.orderByArgs = args
qs.orderByArgs = append(qs.orderByArgs, args...)
qs.mu.Unlock()
case db.Function:
fnName, fnArgs := value.Name(), value.Arguments()
Expand All @@ -175,7 +175,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector {
Column: exql.RawValue(expanded),
}
qs.mu.Lock()
qs.orderByArgs = fnArgs
qs.orderByArgs = append(qs.orderByArgs, fnArgs...)
qs.mu.Unlock()
case string:
if strings.HasPrefix(value, "-") {
Expand Down Expand Up @@ -204,7 +204,9 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector {
}

qs.mu.Lock()
qs.orderBy.SortColumns = &sortColumns
qs.orderBy = &exql.OrderBy{
SortColumns: &sortColumns,
}
qs.mu.Unlock()

return qs
Expand Down Expand Up @@ -332,7 +334,7 @@ func (qs *selector) statement() *exql.Statement {
Offset: qs.offset,
Joins: exql.JoinConditions(qs.joins...),
Where: qs.where,
OrderBy: &qs.orderBy,
OrderBy: qs.orderBy,
GroupBy: qs.groupBy,
}
}
Expand Down

0 comments on commit ddf461a

Please sign in to comment.