Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strict comparisons #2483

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 44 additions & 0 deletions internal/bloblang/parser/query_arithmetic_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,47 @@ func TestArithmeticLiteralsParser(t *testing.T) {
assert.Equal(t, v, res, k)
}
}

func TestArithmeticParserErrors(t *testing.T) {
tests := map[string]string{
`false == "false"`: "line 1 char 1: expected bool value, got string (\"false\")",
`"false" == false`: "line 1 char 1: expected string value, got bool (false)",
`true == "true"`: "line 1 char 1: expected bool value, got string (\"true\")",
`"true" == true`: "line 1 char 1: expected string value, got bool (true)",
`420 == "420"`: "line 1 char 1: expected number value, got string (\"420\")",
`"420" == 420`: "line 1 char 1: expected string value, got number (420)",
`420.69 == "420.69"`: "line 1 char 1: expected number value, got string (\"420.69\")",
`"420.69" == 420.69`: "line 1 char 1: expected string value, got number (420.69)",
`false >= "false"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)",
`"false" <= false`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)",
`true <= "true"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)",
`"true" >= true`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)",
`420 <= "420"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)",
`"420" >= 420`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)",
`420.69 >= "420.69"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)",
`"420.69" <= 420.69`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)",
`false > "false"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)",
`"false" < false`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)",
`true > "true"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)",
`"true" < true`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)",
`420 > "420"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)",
`"420" < 420`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)",
`420.69 > "420.69"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)",
`"420.69" < 420.69`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)",
`false != "false"`: "line 1 char 1: expected bool value, got string (\"false\")",
`"false" != false`: "line 1 char 1: expected string value, got bool (false)",
`true != "true"`: "line 1 char 1: expected bool value, got string (\"true\")",
`"true" != true`: "line 1 char 1: expected string value, got bool (true)",
`420 != "420"`: "line 1 char 1: expected number value, got string (\"420\")",
`"420" != 420`: "line 1 char 1: expected string value, got number (420)",
`420.69 != "420.69"`: "line 1 char 1: expected number value, got string (\"420.69\")",
`"420.69" != 420.69`: "line 1 char 1: expected string value, got number (420.69)",
}

for test, err := range tests {
test := test
_, pErr := tryParseQuery(test)
require.NotNil(t, pErr)
assert.Equal(t, err, pErr.ErrorAtPosition([]rune(test)))
}
}
2 changes: 1 addition & 1 deletion internal/bloblang/parser/query_expression_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func matchCaseParser(pCtx Context) Func[query.MatchCase] {
if v == nil {
return false, nil
}
return value.ICompare(*v, lit.Value), nil
return value.ICompare(*v, lit.Value)
}, nil)
} else {
caseFn = p
Expand Down
15 changes: 8 additions & 7 deletions internal/bloblang/query/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,14 @@ func compareBoolFn(op ArithmeticOperator) func(lhs, rhs bool) bool {
return nil
}

func compareGenericFn(op ArithmeticOperator) func(lhs, rhs any) bool {
func compareGenericFn(op ArithmeticOperator) func(lhs, rhs any) (bool, error) {
switch op {
case ArithmeticEq:
return value.ICompare
case ArithmeticNeq:
return func(lhs, rhs any) bool {
return !value.ICompare(lhs, rhs)
return func(lhs, rhs any) (bool, error) {
val, err := value.ICompare(lhs, rhs)
return !val, err
}
}
return nil
Expand Down Expand Up @@ -332,7 +333,7 @@ func compareOp(op ArithmeticOperator) (arithmeticOpFunc, bool) {
if genericOpFn == nil {
return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right)
}
return genericOpFn(lhs, value.RestrictForComparison(right)), nil
return genericOpFn(lhs, value.RestrictForComparison(right))
}
return strOpFn(lhs, rhs), nil
case float64:
Expand All @@ -344,7 +345,7 @@ func compareOp(op ArithmeticOperator) (arithmeticOpFunc, bool) {
if genericOpFn == nil {
return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right)
}
return genericOpFn(lhs, value.RestrictForComparison(right)), nil
return genericOpFn(lhs, value.RestrictForComparison(right))
}
return numOpFn(lhs, rhs), nil
case bool:
Expand All @@ -356,14 +357,14 @@ func compareOp(op ArithmeticOperator) (arithmeticOpFunc, bool) {
if genericOpFn == nil {
return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right)
}
return genericOpFn(lhs, value.RestrictForComparison(right)), nil
return genericOpFn(lhs, value.RestrictForComparison(right))
}
return boolOpFn(lhs, rhs), nil
default:
if genericOpFn == nil {
return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right)
}
return genericOpFn(left, right), nil
return genericOpFn(left, right)
}
}, true
}
Expand Down
182 changes: 98 additions & 84 deletions internal/bloblang/query/arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,97 +126,13 @@ func TestArithmeticComparisons(t *testing.T) {
result any
errContains string
}{
{
name: "right null equal to int",
left: int64(12),
right: nil,
op: ArithmeticEq,
result: false,
},
{
name: "right null not equal to int",
left: int64(12),
right: nil,
op: ArithmeticNeq,
result: true,
},
{
name: "left null equal to int",
left: nil,
right: int64(10),
op: ArithmeticEq,
result: false,
},
{
name: "left null not equal to int",
left: nil,
right: int64(12),
op: ArithmeticNeq,
result: true,
},
{
name: "null equal to null",
left: nil,
right: nil,
op: ArithmeticEq,
result: true,
},
{
name: "right null equal to string",
left: "foo",
right: nil,
op: ArithmeticEq,
result: false,
},
{
name: "right null not equal to string",
left: "foo",
right: nil,
op: ArithmeticNeq,
result: true,
},
{
name: "left null equal to string",
left: nil,
right: "foo",
op: ArithmeticEq,
result: false,
},
{
name: "left null not equal to string",
left: nil,
right: "foo",
op: ArithmeticNeq,
result: true,
},
{
name: "right null equal to bool",
left: true,
right: nil,
op: ArithmeticEq,
result: false,
},
{
name: "right null not equal to bool",
left: true,
right: nil,
op: ArithmeticNeq,
result: true,
},
{
name: "left null equal to bool",
left: nil,
right: true,
op: ArithmeticEq,
result: false,
},
{
name: "left null not equal to bool",
left: nil,
right: true,
op: ArithmeticNeq,
result: true,
},
{
name: "false equal true",
left: false,
Expand Down Expand Up @@ -902,3 +818,101 @@ func TestArithmeticTargets(t *testing.T) {
})
}
}

func TestArithmeticComparisonsErrors(t *testing.T) {
testCases := []struct {
name string
left any
right any
op ArithmeticOperator
result any
errContains string
}{
{
name: "right null equal to int",
left: int64(12),
right: nil,
op: ArithmeticEq,
},
{
name: "right null not equal to int",
left: int64(12),
right: nil,
op: ArithmeticNeq,
},
{
name: "left null equal to int",
left: nil,
right: int64(10),
op: ArithmeticEq,
},
{
name: "left null not equal to int",
left: nil,
right: int64(12),
op: ArithmeticNeq,
},
{
name: "right null equal to string",
left: "foo",
right: nil,
op: ArithmeticEq,
},
{
name: "right null not equal to string",
left: "foo",
right: nil,
op: ArithmeticNeq,
},
{
name: "left null equal to string",
left: nil,
right: "foo",
op: ArithmeticEq,
},
{
name: "left null not equal to string",
left: nil,
right: "foo",
op: ArithmeticNeq,
},
{
name: "right null equal to bool",
left: true,
right: nil,
op: ArithmeticEq,
},
{
name: "right null not equal to bool",
left: true,
right: nil,
op: ArithmeticNeq,
},
{
name: "left null equal to bool",
left: nil,
right: true,
op: ArithmeticEq,
},
{
name: "left null not equal to bool",
left: nil,
right: true,
op: ArithmeticNeq,
},
}

for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
_, err := NewArithmeticExpression(
[]Function{
NewLiteralFunction("left", test.left),
NewLiteralFunction("right", test.right),
},
[]ArithmeticOperator{test.op},
)
require.Error(t, err)
})
}
}
12 changes: 4 additions & 8 deletions internal/bloblang/query/methods_structured.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,11 @@ var _ = registerSimpleMethod(
return bytes.Contains(t, bsub), nil
case []any:
for _, compareLeft := range t {
if value.ICompare(compareRight, compareLeft) {
return true, nil
}
return value.ICompare(compareRight, compareLeft)
}
case map[string]any:
for _, compareLeft := range t {
if value.ICompare(compareRight, compareLeft) {
return true, nil
}
return value.ICompare(compareRight, compareLeft)
}
default:
return nil, value.NewTypeError(v, value.TString, value.TArray, value.TObject)
Expand Down Expand Up @@ -458,7 +454,7 @@ var _ = registerSimpleMethod(
}

for i, elem := range array {
if value.ICompare(val, elem) {
if val, _ := value.ICompare(val, elem); val {
return i, nil
}
}
Expand Down Expand Up @@ -500,7 +496,7 @@ var _ = registerSimpleMethod(

output := []any{}
for i, elem := range array {
if value.ICompare(val, elem) {
if val, _ := value.ICompare(val, elem); val {
output = append(output, i)
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/config/test/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ func (m MetadataEqualsCondition) Check(fs fs.FS, dir string, p *message.Part) er
if !exists {
return fmt.Errorf("metadata key '%v' expected but not found", k)
}
if !value.ICompare(exp, act) {
if val, _ := value.ICompare(exp, act); !val {
return fmt.Errorf("metadata key '%v' mismatch\n expected: %v\n received: %v", k, blue(exp), red(act))
}
}
Expand Down