From 55bf7ae528287036636fffab8cc08ee61555e5e6 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 2 Apr 2024 02:47:32 +0530 Subject: [PATCH 1/2] Make comparisons strict --- .../parser/query_arithmetic_parser_test.go | 44 +++++++++++++++++++ .../parser/query_expression_parser.go | 2 +- internal/bloblang/query/arithmetic.go | 15 ++++--- internal/bloblang/query/methods_structured.go | 12 ++--- internal/config/test/output.go | 2 +- internal/value/type_helpers.go | 38 ++++++++-------- 6 files changed, 77 insertions(+), 36 deletions(-) diff --git a/internal/bloblang/parser/query_arithmetic_parser_test.go b/internal/bloblang/parser/query_arithmetic_parser_test.go index ac7107f001..70245b88ef 100644 --- a/internal/bloblang/parser/query_arithmetic_parser_test.go +++ b/internal/bloblang/parser/query_arithmetic_parser_test.go @@ -392,3 +392,47 @@ func TestArithmeticLiteralsParser(t *testing.T) { assert.Equal(t, v, res, k) } } + +func TestArithmeticLiteralsParserErrors(t *testing.T) { + tests := map[string]string{ + `false == "false"`: "line 1 char 1: expected bool value, got string (\"false\")", + `"false" == false`: "line 1 char 1: expected string value, got bool (false)", + `true == "true"`: "line 1 char 1: expected bool value, got string (\"true\")", + `"true" == true`: "line 1 char 1: expected string value, got bool (true)", + `420 == "420"`: "line 1 char 1: expected number value, got string (\"420\")", + `"420" == 420`: "line 1 char 1: expected string value, got number (420)", + `420.69 == "420.69"`: "line 1 char 1: expected number value, got string (\"420.69\")", + `"420.69" == 420.69`: "line 1 char 1: expected string value, got number (420.69)", + `false >= "false"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)", + `"false" <= false`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)", + `true <= "true"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)", + `"true" >= true`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)", + `420 <= "420"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)", + `"420" >= 420`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)", + `420.69 >= "420.69"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)", + `"420.69" <= 420.69`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)", + `false > "false"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)", + `"false" < false`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)", + `true > "true"`: "line 1 char 1: cannot compare types bool (from bool literal) and string (from string literal)", + `"true" < true`: "line 1 char 1: cannot compare types string (from string literal) and bool (from bool literal)", + `420 > "420"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)", + `"420" < 420`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)", + `420.69 > "420.69"`: "line 1 char 1: cannot compare types number (from number literal) and string (from string literal)", + `"420.69" < 420.69`: "line 1 char 1: cannot compare types string (from string literal) and number (from number literal)", + `false != "false"`: "line 1 char 1: expected bool value, got string (\"false\")", + `"false" != false`: "line 1 char 1: expected string value, got bool (false)", + `true != "true"`: "line 1 char 1: expected bool value, got string (\"true\")", + `"true" != true`: "line 1 char 1: expected string value, got bool (true)", + `420 != "420"`: "line 1 char 1: expected number value, got string (\"420\")", + `"420" != 420`: "line 1 char 1: expected string value, got number (420)", + `420.69 != "420.69"`: "line 1 char 1: expected number value, got string (\"420.69\")", + `"420.69" != 420.69`: "line 1 char 1: expected string value, got number (420.69)", + } + + for test, err := range tests { + test := test + _, pErr := tryParseQuery(test) + require.NotNil(t, pErr) + assert.Equal(t, err, pErr.ErrorAtPosition([]rune(test))) + } +} diff --git a/internal/bloblang/parser/query_expression_parser.go b/internal/bloblang/parser/query_expression_parser.go index 7d5903fdfd..a0b9e1e21d 100644 --- a/internal/bloblang/parser/query_expression_parser.go +++ b/internal/bloblang/parser/query_expression_parser.go @@ -49,7 +49,7 @@ func matchCaseParser(pCtx Context) Func[query.MatchCase] { if v == nil { return false, nil } - return value.ICompare(*v, lit.Value), nil + return value.ICompare(*v, lit.Value) }, nil) } else { caseFn = p diff --git a/internal/bloblang/query/arithmetic.go b/internal/bloblang/query/arithmetic.go index 8de294472b..c5e3f79420 100644 --- a/internal/bloblang/query/arithmetic.go +++ b/internal/bloblang/query/arithmetic.go @@ -297,13 +297,14 @@ func compareBoolFn(op ArithmeticOperator) func(lhs, rhs bool) bool { return nil } -func compareGenericFn(op ArithmeticOperator) func(lhs, rhs any) bool { +func compareGenericFn(op ArithmeticOperator) func(lhs, rhs any) (bool, error) { switch op { case ArithmeticEq: return value.ICompare case ArithmeticNeq: - return func(lhs, rhs any) bool { - return !value.ICompare(lhs, rhs) + return func(lhs, rhs any) (bool, error) { + val, err := value.ICompare(lhs, rhs) + return !val, err } } return nil @@ -332,7 +333,7 @@ func compareOp(op ArithmeticOperator) (arithmeticOpFunc, bool) { if genericOpFn == nil { return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right) } - return genericOpFn(lhs, value.RestrictForComparison(right)), nil + return genericOpFn(lhs, value.RestrictForComparison(right)) } return strOpFn(lhs, rhs), nil case float64: @@ -344,7 +345,7 @@ func compareOp(op ArithmeticOperator) (arithmeticOpFunc, bool) { if genericOpFn == nil { return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right) } - return genericOpFn(lhs, value.RestrictForComparison(right)), nil + return genericOpFn(lhs, value.RestrictForComparison(right)) } return numOpFn(lhs, rhs), nil case bool: @@ -356,14 +357,14 @@ func compareOp(op ArithmeticOperator) (arithmeticOpFunc, bool) { if genericOpFn == nil { return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right) } - return genericOpFn(lhs, value.RestrictForComparison(right)), nil + return genericOpFn(lhs, value.RestrictForComparison(right)) } return boolOpFn(lhs, rhs), nil default: if genericOpFn == nil { return nil, NewTypeMismatch(op.String(), lFn, rFn, left, right) } - return genericOpFn(left, right), nil + return genericOpFn(left, right) } }, true } diff --git a/internal/bloblang/query/methods_structured.go b/internal/bloblang/query/methods_structured.go index 7cbe7f5988..77582bdccc 100644 --- a/internal/bloblang/query/methods_structured.go +++ b/internal/bloblang/query/methods_structured.go @@ -221,15 +221,11 @@ var _ = registerSimpleMethod( return bytes.Contains(t, bsub), nil case []any: for _, compareLeft := range t { - if value.ICompare(compareRight, compareLeft) { - return true, nil - } + return value.ICompare(compareRight, compareLeft) } case map[string]any: for _, compareLeft := range t { - if value.ICompare(compareRight, compareLeft) { - return true, nil - } + return value.ICompare(compareRight, compareLeft) } default: return nil, value.NewTypeError(v, value.TString, value.TArray, value.TObject) @@ -458,7 +454,7 @@ var _ = registerSimpleMethod( } for i, elem := range array { - if value.ICompare(val, elem) { + if val, _ := value.ICompare(val, elem); val { return i, nil } } @@ -500,7 +496,7 @@ var _ = registerSimpleMethod( output := []any{} for i, elem := range array { - if value.ICompare(val, elem) { + if val, _ := value.ICompare(val, elem); val { output = append(output, i) } } diff --git a/internal/config/test/output.go b/internal/config/test/output.go index 0d703fe02f..9c2af9ada8 100644 --- a/internal/config/test/output.go +++ b/internal/config/test/output.go @@ -300,7 +300,7 @@ func (m MetadataEqualsCondition) Check(fs fs.FS, dir string, p *message.Part) er if !exists { return fmt.Errorf("metadata key '%v' expected but not found", k) } - if !value.ICompare(exp, act) { + if val, _ := value.ICompare(exp, act); !val { return fmt.Errorf("metadata key '%v' mismatch\n expected: %v\n received: %v", k, blue(exp), red(act)) } } diff --git a/internal/value/type_helpers.go b/internal/value/type_helpers.go index d18422ec26..cf6b79a560 100644 --- a/internal/value/type_helpers.go +++ b/internal/value/type_helpers.go @@ -817,59 +817,59 @@ func IClone(root any) any { // - The types are both either a string or byte slice and the underlying data is the same // - The types are both numerical and have the same value // - Both types are a matching slice or map containing values matching these same conditions. -func ICompare(left, right any) bool { +func ICompare(left, right any) (bool, error) { if left == nil && right == nil { - return true + return true, nil } switch lhs := RestrictForComparison(left).(type) { case string: rhs, err := IGetString(right) if err != nil { - return false + return false, err } - return lhs == rhs + return lhs == rhs, nil case float64: rhs, err := IGetNumber(right) if err != nil { - return false + return false, err } - return lhs == rhs + return lhs == rhs, nil case bool: rhs, err := IGetBool(right) if err != nil { - return false + return false, err } - return lhs == rhs + return lhs == rhs, nil case []any: rhs, matches := right.([]any) if !matches { - return false + return false, NewTypeError(rhs, ITypeOf(rhs)) } if len(lhs) != len(rhs) { - return false + return false, errors.New("length mismatch") } for i, vl := range lhs { - if !ICompare(vl, rhs[i]) { - return false + if val, err := ICompare(vl, rhs[i]); !val { + return false, err } } - return true + return true, nil case map[string]any: rhs, matches := right.(map[string]any) if !matches { - return false + return false, NewTypeError(rhs, ITypeOf(rhs)) } if len(lhs) != len(rhs) { - return false + return false, errors.New("length mismatch") } for k, vl := range lhs { - if !ICompare(vl, rhs[k]) { - return false + if val, err := ICompare(vl, rhs[k]); !val { + return false, err } } - return true + return true, nil } - return false + return false, NewTypeError(left, ITypeOf(left)) } func IGetStringMap(v any) (map[string]string, error) { From 7f0820869fb72583d8f7773d7d4bbbf6ac0c808a Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Apr 2024 05:18:52 +0530 Subject: [PATCH 2/2] Fix few tests --- .../parser/query_arithmetic_parser_test.go | 2 +- internal/bloblang/query/arithmetic_test.go | 182 ++++++++++-------- 2 files changed, 99 insertions(+), 85 deletions(-) diff --git a/internal/bloblang/parser/query_arithmetic_parser_test.go b/internal/bloblang/parser/query_arithmetic_parser_test.go index 70245b88ef..8efb068bcc 100644 --- a/internal/bloblang/parser/query_arithmetic_parser_test.go +++ b/internal/bloblang/parser/query_arithmetic_parser_test.go @@ -393,7 +393,7 @@ func TestArithmeticLiteralsParser(t *testing.T) { } } -func TestArithmeticLiteralsParserErrors(t *testing.T) { +func TestArithmeticParserErrors(t *testing.T) { tests := map[string]string{ `false == "false"`: "line 1 char 1: expected bool value, got string (\"false\")", `"false" == false`: "line 1 char 1: expected string value, got bool (false)", diff --git a/internal/bloblang/query/arithmetic_test.go b/internal/bloblang/query/arithmetic_test.go index 21f01377a8..40974064ee 100644 --- a/internal/bloblang/query/arithmetic_test.go +++ b/internal/bloblang/query/arithmetic_test.go @@ -126,34 +126,6 @@ func TestArithmeticComparisons(t *testing.T) { result any errContains string }{ - { - name: "right null equal to int", - left: int64(12), - right: nil, - op: ArithmeticEq, - result: false, - }, - { - name: "right null not equal to int", - left: int64(12), - right: nil, - op: ArithmeticNeq, - result: true, - }, - { - name: "left null equal to int", - left: nil, - right: int64(10), - op: ArithmeticEq, - result: false, - }, - { - name: "left null not equal to int", - left: nil, - right: int64(12), - op: ArithmeticNeq, - result: true, - }, { name: "null equal to null", left: nil, @@ -161,62 +133,6 @@ func TestArithmeticComparisons(t *testing.T) { op: ArithmeticEq, result: true, }, - { - name: "right null equal to string", - left: "foo", - right: nil, - op: ArithmeticEq, - result: false, - }, - { - name: "right null not equal to string", - left: "foo", - right: nil, - op: ArithmeticNeq, - result: true, - }, - { - name: "left null equal to string", - left: nil, - right: "foo", - op: ArithmeticEq, - result: false, - }, - { - name: "left null not equal to string", - left: nil, - right: "foo", - op: ArithmeticNeq, - result: true, - }, - { - name: "right null equal to bool", - left: true, - right: nil, - op: ArithmeticEq, - result: false, - }, - { - name: "right null not equal to bool", - left: true, - right: nil, - op: ArithmeticNeq, - result: true, - }, - { - name: "left null equal to bool", - left: nil, - right: true, - op: ArithmeticEq, - result: false, - }, - { - name: "left null not equal to bool", - left: nil, - right: true, - op: ArithmeticNeq, - result: true, - }, { name: "false equal true", left: false, @@ -902,3 +818,101 @@ func TestArithmeticTargets(t *testing.T) { }) } } + +func TestArithmeticComparisonsErrors(t *testing.T) { + testCases := []struct { + name string + left any + right any + op ArithmeticOperator + result any + errContains string + }{ + { + name: "right null equal to int", + left: int64(12), + right: nil, + op: ArithmeticEq, + }, + { + name: "right null not equal to int", + left: int64(12), + right: nil, + op: ArithmeticNeq, + }, + { + name: "left null equal to int", + left: nil, + right: int64(10), + op: ArithmeticEq, + }, + { + name: "left null not equal to int", + left: nil, + right: int64(12), + op: ArithmeticNeq, + }, + { + name: "right null equal to string", + left: "foo", + right: nil, + op: ArithmeticEq, + }, + { + name: "right null not equal to string", + left: "foo", + right: nil, + op: ArithmeticNeq, + }, + { + name: "left null equal to string", + left: nil, + right: "foo", + op: ArithmeticEq, + }, + { + name: "left null not equal to string", + left: nil, + right: "foo", + op: ArithmeticNeq, + }, + { + name: "right null equal to bool", + left: true, + right: nil, + op: ArithmeticEq, + }, + { + name: "right null not equal to bool", + left: true, + right: nil, + op: ArithmeticNeq, + }, + { + name: "left null equal to bool", + left: nil, + right: true, + op: ArithmeticEq, + }, + { + name: "left null not equal to bool", + left: nil, + right: true, + op: ArithmeticNeq, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.name, func(t *testing.T) { + _, err := NewArithmeticExpression( + []Function{ + NewLiteralFunction("left", test.left), + NewLiteralFunction("right", test.right), + }, + []ArithmeticOperator{test.op}, + ) + require.Error(t, err) + }) + } +}