From a9cd307b085c748815bc8de50312be47373b7d25 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Mon, 30 Jan 2017 16:29:52 +0100 Subject: [PATCH] Implement array operators --- operators.go | 115 ++++++++++++++++++++++++++++++++++++---- operators_test.go | 131 ++++++++++++++++++++++++++++++++++++++++++++++ store_test.go | 31 ----------- 3 files changed, 237 insertions(+), 40 deletions(-) create mode 100644 operators_test.go diff --git a/operators.go b/operators.go index 860134b..b772751 100644 --- a/operators.go +++ b/operators.go @@ -1,18 +1,17 @@ package kallax import ( + "database/sql/driver" "fmt" + "github.com/src-d/go-kallax/types" + "github.com/Masterminds/squirrel" ) // Condition represents a condition of filtering in a query. type Condition func(Schema) squirrel.Sqlizer -type not struct { - cond squirrel.Sqlizer -} - // Eq returns a condition that will be true when `col` is equal to `value`. func Eq(col SchemaField, value interface{}) Condition { return func(schema Schema) squirrel.Sqlizer { @@ -94,14 +93,100 @@ func NotIn(col SchemaField, values ...interface{}) Condition { } } -func condsToSqlizers(conds []Condition, schema Schema) []squirrel.Sqlizer { - var result = make([]squirrel.Sqlizer, len(conds)) - for i, v := range conds { - result[i] = v(schema) +// ArrayEq returns a condition that will be true when `col` is equal to an +// array with the given elements. +func ArrayEq(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), "=", types.Slice(values)} + } +} + +// ArrayNotEq returns a condition that will be true when `col` is not equal to +// an array with the given elements. +func ArrayNotEq(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), "<>", types.Slice(values)} + } +} + +// ArrayLt returns a condition that will be true when all elements in `col` +// are lower or equal than their counterparts in the given values, and one of +// the elements at least is lower than its counterpart in the given values. +// For example: for a col with values [1,2,2] and values [1,2,3], it will be +// true. +func ArrayLt(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), "<", types.Slice(values)} + } +} + +// ArrayGt returns a condition that will be true when all elements in `col` +// are greater or equal than their counterparts in the given values, and one of +// the elements at least is greater than its counterpart in the given values. +// For example: for a col with values [1,2,3] and values [1,2,2], it will be +// true. +func ArrayGt(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), ">", types.Slice(values)} + } +} + +// ArrayLtOrEq returns a condition that will be true when all elements in `col` +// are lower or equal than their counterparts in the given values. +// For example: for a col with values [1,2,2] and values [1,2,2], it will be +// true. +func ArrayLtOrEq(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), "<=", types.Slice(values)} + } +} + +// ArrayGtOrEq returns a condition that will be true when all elements in `col` +// are greater or equal than their counterparts in the given values. +// For example: for a col with values [1,2,2] and values [1,2,2], it will be +// true. +func ArrayGtOrEq(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), ">=", types.Slice(values)} + } +} + +// ArrayContains returns a condition that will be true when `col` contains all the +// given values. +func ArrayContains(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), "@>", types.Slice(values)} + } +} + +// ArrayContainedBy returns a condition that will be true when `col` has all +// its elements present in the given values. +func ArrayContainedBy(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), "<@", types.Slice(values)} } - return result } +// ArrayOverlap returns a condition that will be true when `col` has elements +// in common with an array formed by the given values. +func ArrayOverlap(col SchemaField, values ...interface{}) Condition { + return func(schema Schema) squirrel.Sqlizer { + return &colOp{col.QualifiedName(schema), "&&", types.Slice(values)} + } +} + +type ( + not struct { + cond squirrel.Sqlizer + } + + colOp struct { + col string + op string + valuer driver.Valuer + } +) + func (n not) ToSql() (string, []interface{}, error) { sql, args, err := n.cond.ToSql() if err != nil { @@ -110,3 +195,15 @@ func (n not) ToSql() (string, []interface{}, error) { return fmt.Sprintf("NOT (%s)", sql), args, err } + +func (o colOp) ToSql() (string, []interface{}, error) { + return fmt.Sprintf("%s %s ?", o.col, o.op), []interface{}{o.valuer}, nil +} + +func condsToSqlizers(conds []Condition, schema Schema) []squirrel.Sqlizer { + var result = make([]squirrel.Sqlizer, len(conds)) + for i, v := range conds { + result[i] = v(schema) + } + return result +} diff --git a/operators_test.go b/operators_test.go new file mode 100644 index 0000000..bc37354 --- /dev/null +++ b/operators_test.go @@ -0,0 +1,131 @@ +package kallax + +import ( + "database/sql" + "testing" + + "github.com/src-d/go-kallax/types" + "github.com/stretchr/testify/suite" +) + +type OpsSuite struct { + suite.Suite + db *sql.DB + store *Store +} + +func (s *OpsSuite) SetupTest() { + var err error + s.db, err = openTestDB() + s.Nil(err) + _, err = s.db.Exec(`CREATE TABLE model ( + id uuid PRIMARY KEY, + name varchar(255) not null, + email varchar(255) not null, + age int not null + )`) + s.Nil(err) + _, err = s.db.Exec(`CREATE TABLE slices ( + id uuid PRIMARY KEY, + elems bigint[] + )`) + s.Nil(err) + s.store = NewStore(s.db) +} + +func (s *OpsSuite) TearDownTest() { + _, err := s.db.Exec("DROP TABLE slices") + s.NoError(err) + + _, err = s.db.Exec("DROP TABLE model") + s.NoError(err) +} + +func (s *OpsSuite) TestOperators() { + cases := []struct { + name string + cond Condition + count int64 + }{ + {"Eq", Eq(f("name"), "Joe"), 1}, + {"Gt", Gt(f("age"), 1), 2}, + {"Lt", Lt(f("age"), 2), 1}, + {"Neq", Neq(f("name"), "Joe"), 2}, + {"GtOrEq", GtOrEq(f("age"), 2), 2}, + {"LtOrEq", LtOrEq(f("age"), 3), 3}, + {"Not", Not(Eq(f("name"), "Joe")), 2}, + {"And", And(Neq(f("name"), "Joe"), Gt(f("age"), 1)), 2}, + {"Or", Or(Neq(f("name"), "Joe"), Eq(f("age"), 1)), 3}, + {"In", In(f("name"), "Joe", "Jane"), 2}, + {"NotIn", NotIn(f("name"), "Joe", "Jane"), 1}, + } + + s.Nil(s.store.Insert(ModelSchema, newModel("Joe", "", 1))) + s.Nil(s.store.Insert(ModelSchema, newModel("Jane", "", 2))) + s.Nil(s.store.Insert(ModelSchema, newModel("Anna", "", 2))) + + for _, c := range cases { + q := NewBaseQuery(ModelSchema) + q.Where(c.cond) + + s.Equal(s.store.MustCount(q), c.count, c.name) + } +} + +func (s *OpsSuite) TestArrayOperators() { + f := f("elems") + + cases := []struct { + name string + cond Condition + ok bool + }{ + {"ArrayEq", ArrayEq(f, 1, 2, 3), true}, + {"ArrayEq fail", ArrayEq(f, 1, 2, 2), false}, + {"ArrayNotEq", ArrayNotEq(f, 1, 2, 2), true}, + {"ArrayNotEq fail", ArrayNotEq(f, 1, 2, 3), false}, + {"ArrayGt", ArrayGt(f, 1, 2, 2), true}, + {"ArrayGt all eq", ArrayGt(f, 1, 2, 3), false}, + {"ArrayGt some lt", ArrayGt(f, 1, 3, 1), false}, + {"ArrayLt", ArrayLt(f, 1, 2, 4), true}, + {"ArrayLt all eq", ArrayLt(f, 1, 2, 3), false}, + {"ArrayLt some gt", ArrayLt(f, 1, 1, 4), false}, + {"ArrayGtOrEq", ArrayGtOrEq(f, 1, 2, 2), true}, + {"ArrayGtOrEq all eq", ArrayGtOrEq(f, 1, 2, 3), true}, + {"ArrayGtOrEq some lt", ArrayGtOrEq(f, 1, 3, 1), false}, + {"ArrayLtOrEq", ArrayLtOrEq(f, 1, 2, 4), true}, + {"ArrayLtOrEq all eq", ArrayLtOrEq(f, 1, 2, 3), true}, + {"ArrayLtOrEq some gt", ArrayLtOrEq(f, 1, 1, 4), false}, + {"ArrayContains", ArrayContains(f, 1, 2), true}, + {"ArrayContains fail", ArrayContains(f, 5, 6), false}, + {"ArrayContainedBy", ArrayContainedBy(f, 1, 2, 3, 5, 6), true}, + {"ArrayContainedBy fail", ArrayContainedBy(f, 1, 2, 5, 6), false}, + {"ArrayOverlap", ArrayOverlap(f, 5, 1, 7), true}, + {"ArrayOverlap fail", ArrayOverlap(f, 6, 7, 8, 9), false}, + } + + _, err := s.db.Exec("INSERT INTO slices (id,elems) VALUES ($1, $2)", NewID(), types.Slice([]int64{1, 2, 3})) + s.NoError(err) + + for _, c := range cases { + q := NewBaseQuery(SlicesSchema) + q.Where(c.cond) + cnt, err := s.store.Count(q) + s.NoError(err, c.name) + s.Equal(c.ok, cnt > 0, "success: %s", c.name) + } +} + +func TestOperators(t *testing.T) { + suite.Run(t, new(OpsSuite)) +} + +var SlicesSchema = &BaseSchema{ + alias: "_sl", + table: "slices", + id: f("id"), + columns: []SchemaField{ + f("id"), + f("elems"), + }, +} diff --git a/store_test.go b/store_test.go index 5f51095..d9532f7 100644 --- a/store_test.go +++ b/store_test.go @@ -385,37 +385,6 @@ func (s *StoreSuite) TestFind_1toNMultiple() { s.Equal(100, i) } -func (s *StoreSuite) TestOperators() { - cases := []struct { - name string - cond Condition - count int64 - }{ - {"Eq", Eq(f("name"), "Joe"), 1}, - {"Gt", Gt(f("age"), 1), 2}, - {"Lt", Lt(f("age"), 2), 1}, - {"Neq", Neq(f("name"), "Joe"), 2}, - {"GtOrEq", GtOrEq(f("age"), 2), 2}, - {"LtOrEq", LtOrEq(f("age"), 3), 3}, - {"Not", Not(Eq(f("name"), "Joe")), 2}, - {"And", And(Neq(f("name"), "Joe"), Gt(f("age"), 1)), 2}, - {"Or", Or(Neq(f("name"), "Joe"), Eq(f("age"), 1)), 3}, - {"In", In(f("name"), "Joe", "Jane"), 2}, - {"NotIn", NotIn(f("name"), "Joe", "Jane"), 1}, - } - - s.Nil(s.store.Insert(ModelSchema, newModel("Joe", "", 1))) - s.Nil(s.store.Insert(ModelSchema, newModel("Jane", "", 2))) - s.Nil(s.store.Insert(ModelSchema, newModel("Anna", "", 2))) - - for _, c := range cases { - q := NewBaseQuery(ModelSchema) - q.Where(c.cond) - - s.Equal(s.store.MustCount(q), c.count, c.name) - } -} - func (s *StoreSuite) assertModel(m *model) { var result model err := s.db.QueryRow("SELECT id, name, email, age FROM model WHERE id = $1", m.ID).