Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Fixed bug in is false evaluation for nil values. Added tests for bool…
Browse files Browse the repository at this point in the history
… evaluation of string fields. This required allowing conversion of strings to bool (always false, as MySQL). This in turn exposed a bug in optimization, where non-column expressions were inappropriately getting coerced into boolean values. Removed the boolean casting, which is slightly less performant but correct in all cases.

Signed-off-by: Zach Musgrave <zach@liquidata.co>
  • Loading branch information
zachmu committed Jul 9, 2019
1 parent f86f7e2 commit ded9391
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 31 deletions.
14 changes: 11 additions & 3 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ var queries = []struct {
"SELECT i FROM mytable ORDER BY i DESC;",
[]sql.Row{{int64(3)}, {int64(2)}, {int64(1)}},
},
{
"SELECT i FROM mytable WHERE 'hello';",
[]sql.Row{},
},
{
"SELECT i FROM mytable WHERE not 'hello';",
[]sql.Row{{int64(1)}, {int64(2)}, {int64(3)}},
},
{
"SELECT i FROM mytable WHERE s = 'first row' ORDER BY i DESC;",
[]sql.Row{{int64(1)}},
Expand Down Expand Up @@ -157,12 +165,12 @@ var queries = []struct {
[]sql.Row{{int64(2)}, {nil}, {nil}},
},
{
"SELECT i FROM niltable WHERE b IS FALSE",
[]sql.Row{{int64(2)}, {nil}, {nil}},
"SELECT f FROM niltable WHERE b IS FALSE",
[]sql.Row{{3.0}},
},
{
"SELECT i FROM niltable WHERE b IS NOT FALSE",
[]sql.Row{{int64(1)}, {int64(4)}},
[]sql.Row{{int64(1)}, {int64(2)}, {int64(4)}, {nil}},
},
{
"SELECT COUNT(*) FROM mytable;",
Expand Down
25 changes: 4 additions & 21 deletions sql/analyzer/optimization_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,37 +380,20 @@ func evalFilter(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
return e.Left, nil
}

return e, nil
case *expression.Literal, expression.Tuple:
return e, nil
default:
if !isEvaluable(e) {
return e, nil
}

if _, ok := e.(*expression.Literal); ok {
return e, nil
}

// UnaryMinus expressions come back from the parser when a negative float is evaluated. Treat them just like
// normal literal expressions.
if um, ok := e.(*expression.UnaryMinus); ok {
negated, err := e.Eval(ctx, nil)
if err != nil {
return nil, err
}
return expression.NewLiteral(negated, um.Type()), nil
}

// All other expressions types can be evaluated once and turned into literals for the rest of query execution
val, err := e.Eval(ctx, nil)
if err != nil {
return e, nil
}

val, err = sql.Boolean.Convert(val)
if err != nil {
return e, nil
}

return expression.NewLiteral(val.(bool), sql.Boolean), nil
return expression.NewLiteral(val, e.Type()), nil
}
})
if err != nil {
Expand Down
7 changes: 3 additions & 4 deletions sql/expression/istrue.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package expression

import "github.com/src-d/go-mysql-server/sql"

// IsNull is an expression that checks if an expression is true.
// IsTrue is an expression that checks if an expression is true.
type IsTrue struct {
UnaryExpression
invert bool
Expand Down Expand Up @@ -40,7 +40,7 @@ func (e *IsTrue) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {

var boolVal interface{}
if v == nil {
boolVal = false
return false, nil
} else {
boolVal, err = sql.Boolean.Convert(v)
if err != nil {
Expand Down Expand Up @@ -70,8 +70,7 @@ func (e *IsTrue) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
}
if e.invert {
return f(NewIsFalse(child))
} else {
return f(NewIsTrue(child))
}
return f(NewIsTrue(child))
}

40 changes: 38 additions & 2 deletions sql/expression/istrue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ func TestIsTrue(t *testing.T) {
require.Equal(true, eval(t, e, sql.NewRow(100)))
require.Equal(true, eval(t, e, sql.NewRow(-1)))
require.Equal(false, eval(t, e, sql.NewRow(0)))

floatF := NewGetField(0, sql.Float64, "col1", true)
e = NewIsTrue(floatF)
require.Equal(sql.Boolean, e.Type())
require.False(e.IsNullable())
require.Equal(false, eval(t, e, sql.NewRow(nil)))
require.Equal(true, eval(t, e, sql.NewRow(1.5)))
require.Equal(true, eval(t, e, sql.NewRow(-1.5)))
require.Equal(false, eval(t, e, sql.NewRow(0)))

stringF := NewGetField(0, sql.Text, "col1", true)
e = NewIsTrue(stringF)
require.Equal(sql.Boolean, e.Type())
require.False(e.IsNullable())
require.Equal(false, eval(t, e, sql.NewRow(nil)))
require.Equal(false, eval(t, e, sql.NewRow("")))
require.Equal(false, eval(t, e, sql.NewRow("false")))
require.Equal(false, eval(t, e, sql.NewRow("true")))
}

func TestIsFalse(t *testing.T) {
Expand All @@ -34,16 +52,34 @@ func TestIsFalse(t *testing.T) {
e := NewIsFalse(boolF)
require.Equal(sql.Boolean, e.Type())
require.False(e.IsNullable())
require.Equal(true, eval(t, e, sql.NewRow(nil)))
require.Equal(false, eval(t, e, sql.NewRow(nil)))
require.Equal(false, eval(t, e, sql.NewRow(true)))
require.Equal(true, eval(t, e, sql.NewRow(false)))

intF := NewGetField(0, sql.Int64, "col1", true)
e = NewIsFalse(intF)
require.Equal(sql.Boolean, e.Type())
require.False(e.IsNullable())
require.Equal(true, eval(t, e, sql.NewRow(nil)))
require.Equal(false, eval(t, e, sql.NewRow(nil)))
require.Equal(false, eval(t, e, sql.NewRow(100)))
require.Equal(false, eval(t, e, sql.NewRow(-1)))
require.Equal(true, eval(t, e, sql.NewRow(0)))

floatF := NewGetField(0, sql.Float64, "col1", true)
e = NewIsFalse(floatF)
require.Equal(sql.Boolean, e.Type())
require.False(e.IsNullable())
require.Equal(false, eval(t, e, sql.NewRow(nil)))
require.Equal(false, eval(t, e, sql.NewRow(1.5)))
require.Equal(false, eval(t, e, sql.NewRow(-1.5)))
require.Equal(true, eval(t, e, sql.NewRow(0)))

stringF := NewGetField(0, sql.Text, "col1", true)
e = NewIsFalse(stringF)
require.Equal(sql.Boolean, e.Type())
require.False(e.IsNullable())
require.Equal(false, eval(t, e, sql.NewRow(nil)))
require.Equal(true, eval(t, e, sql.NewRow("")))
require.Equal(true, eval(t, e, sql.NewRow("false")))
require.Equal(true, eval(t, e, sql.NewRow("true")))
}
2 changes: 1 addition & 1 deletion sql/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ func (t booleanT) Convert(v interface{}) (interface{}, error) {
case float32, float64:
return int(math.Round(v.(float64))) != 0, nil
case string:
return false, fmt.Errorf("unable to cast string to bool")
return false, nil

case nil:
return nil, fmt.Errorf("unable to cast nil to bool")
Expand Down
16 changes: 16 additions & 0 deletions sql/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ func TestText(t *testing.T) {
convert(t, Text, var3, "abc")
}

func TestBoolean(t *testing.T) {
convert(t, Boolean, "", false)
convert(t, Boolean, "true", false)
convert(t, Boolean, 0, false)
convert(t, Boolean, 1, true)
convert(t, Boolean, -1, true)
convert(t, Boolean, 0.0, false)
convert(t, Boolean, 0.4, false)
convert(t, Boolean, 0.5, true)
convert(t, Boolean, 1.0, true)
convert(t, Boolean, -1.0, true)

eq(t, Boolean, true, true)
eq(t, Boolean, false, false)
}

func TestInt32(t *testing.T) {
convert(t, Int32, int32(1), int32(1))
convert(t, Int32, 1, int32(1))
Expand Down

0 comments on commit ded9391

Please sign in to comment.