Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 106 additions & 9 deletions operators.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
package kallax

import (
"database/sql/driver"
"fmt"

"github.com/src-d/go-kallax/types"

"github.com/Masterminds/squirrel"
)

// Condition represents a condition of filtering in a query.
type Condition func(Schema) squirrel.Sqlizer

type not struct {
cond squirrel.Sqlizer
}

// Eq returns a condition that will be true when `col` is equal to `value`.
func Eq(col SchemaField, value interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
Expand Down Expand Up @@ -94,14 +93,100 @@ func NotIn(col SchemaField, values ...interface{}) Condition {
}
}

func condsToSqlizers(conds []Condition, schema Schema) []squirrel.Sqlizer {
var result = make([]squirrel.Sqlizer, len(conds))
for i, v := range conds {
result[i] = v(schema)
// ArrayEq returns a condition that will be true when `col` is equal to an
// array with the given elements.
func ArrayEq(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), "=", types.Slice(values)}
}
}

// ArrayNotEq returns a condition that will be true when `col` is not equal to
// an array with the given elements.
func ArrayNotEq(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), "<>", types.Slice(values)}
}
}

// ArrayLt returns a condition that will be true when all elements in `col`
// are lower or equal than their counterparts in the given values, and one of
// the elements at least is lower than its counterpart in the given values.
// For example: for a col with values [1,2,2] and values [1,2,3], it will be
// true.
func ArrayLt(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), "<", types.Slice(values)}
}
}

// ArrayGt returns a condition that will be true when all elements in `col`
// are greater or equal than their counterparts in the given values, and one of
// the elements at least is greater than its counterpart in the given values.
// For example: for a col with values [1,2,3] and values [1,2,2], it will be
// true.
func ArrayGt(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), ">", types.Slice(values)}
}
}

// ArrayLtOrEq returns a condition that will be true when all elements in `col`
// are lower or equal than their counterparts in the given values.
// For example: for a col with values [1,2,2] and values [1,2,2], it will be
// true.
func ArrayLtOrEq(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), "<=", types.Slice(values)}
}
}

// ArrayGtOrEq returns a condition that will be true when all elements in `col`
// are greater or equal than their counterparts in the given values.
// For example: for a col with values [1,2,2] and values [1,2,2], it will be
// true.
func ArrayGtOrEq(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), ">=", types.Slice(values)}
}
}

// ArrayContains returns a condition that will be true when `col` contains all the
// given values.
func ArrayContains(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), "@>", types.Slice(values)}
}
}

// ArrayContainedBy returns a condition that will be true when `col` has all
// its elements present in the given values.
func ArrayContainedBy(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), "<@", types.Slice(values)}
}
return result
}

// ArrayOverlap returns a condition that will be true when `col` has elements
// in common with an array formed by the given values.
func ArrayOverlap(col SchemaField, values ...interface{}) Condition {
return func(schema Schema) squirrel.Sqlizer {
return &colOp{col.QualifiedName(schema), "&&", types.Slice(values)}
}
}

type (
not struct {
cond squirrel.Sqlizer
}

colOp struct {
col string
op string
valuer driver.Valuer
}
)

func (n not) ToSql() (string, []interface{}, error) {
sql, args, err := n.cond.ToSql()
if err != nil {
Expand All @@ -110,3 +195,15 @@ func (n not) ToSql() (string, []interface{}, error) {

return fmt.Sprintf("NOT (%s)", sql), args, err
}

func (o colOp) ToSql() (string, []interface{}, error) {
return fmt.Sprintf("%s %s ?", o.col, o.op), []interface{}{o.valuer}, nil
}

func condsToSqlizers(conds []Condition, schema Schema) []squirrel.Sqlizer {
var result = make([]squirrel.Sqlizer, len(conds))
for i, v := range conds {
result[i] = v(schema)
}
return result
}
131 changes: 131 additions & 0 deletions operators_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package kallax

import (
"database/sql"
"testing"

"github.com/src-d/go-kallax/types"
"github.com/stretchr/testify/suite"
)

type OpsSuite struct {
suite.Suite
db *sql.DB
store *Store
}

func (s *OpsSuite) SetupTest() {
var err error
s.db, err = openTestDB()
s.Nil(err)
_, err = s.db.Exec(`CREATE TABLE model (
id uuid PRIMARY KEY,
name varchar(255) not null,
email varchar(255) not null,
age int not null
)`)
s.Nil(err)
_, err = s.db.Exec(`CREATE TABLE slices (
id uuid PRIMARY KEY,
elems bigint[]
)`)
s.Nil(err)
s.store = NewStore(s.db)
}

func (s *OpsSuite) TearDownTest() {
_, err := s.db.Exec("DROP TABLE slices")
s.NoError(err)

_, err = s.db.Exec("DROP TABLE model")
s.NoError(err)
}

func (s *OpsSuite) TestOperators() {
cases := []struct {
name string
cond Condition
count int64
}{
{"Eq", Eq(f("name"), "Joe"), 1},
{"Gt", Gt(f("age"), 1), 2},
{"Lt", Lt(f("age"), 2), 1},
{"Neq", Neq(f("name"), "Joe"), 2},
{"GtOrEq", GtOrEq(f("age"), 2), 2},
{"LtOrEq", LtOrEq(f("age"), 3), 3},
{"Not", Not(Eq(f("name"), "Joe")), 2},
{"And", And(Neq(f("name"), "Joe"), Gt(f("age"), 1)), 2},
{"Or", Or(Neq(f("name"), "Joe"), Eq(f("age"), 1)), 3},
{"In", In(f("name"), "Joe", "Jane"), 2},
{"NotIn", NotIn(f("name"), "Joe", "Jane"), 1},
}

s.Nil(s.store.Insert(ModelSchema, newModel("Joe", "", 1)))
s.Nil(s.store.Insert(ModelSchema, newModel("Jane", "", 2)))
s.Nil(s.store.Insert(ModelSchema, newModel("Anna", "", 2)))

for _, c := range cases {
q := NewBaseQuery(ModelSchema)
q.Where(c.cond)

s.Equal(s.store.MustCount(q), c.count, c.name)
}
}

func (s *OpsSuite) TestArrayOperators() {
f := f("elems")

cases := []struct {
name string
cond Condition
ok bool
}{
{"ArrayEq", ArrayEq(f, 1, 2, 3), true},
{"ArrayEq fail", ArrayEq(f, 1, 2, 2), false},
{"ArrayNotEq", ArrayNotEq(f, 1, 2, 2), true},
{"ArrayNotEq fail", ArrayNotEq(f, 1, 2, 3), false},
{"ArrayGt", ArrayGt(f, 1, 2, 2), true},
{"ArrayGt all eq", ArrayGt(f, 1, 2, 3), false},
{"ArrayGt some lt", ArrayGt(f, 1, 3, 1), false},
{"ArrayLt", ArrayLt(f, 1, 2, 4), true},
{"ArrayLt all eq", ArrayLt(f, 1, 2, 3), false},
{"ArrayLt some gt", ArrayLt(f, 1, 1, 4), false},
{"ArrayGtOrEq", ArrayGtOrEq(f, 1, 2, 2), true},
{"ArrayGtOrEq all eq", ArrayGtOrEq(f, 1, 2, 3), true},
{"ArrayGtOrEq some lt", ArrayGtOrEq(f, 1, 3, 1), false},
{"ArrayLtOrEq", ArrayLtOrEq(f, 1, 2, 4), true},
{"ArrayLtOrEq all eq", ArrayLtOrEq(f, 1, 2, 3), true},
{"ArrayLtOrEq some gt", ArrayLtOrEq(f, 1, 1, 4), false},
{"ArrayContains", ArrayContains(f, 1, 2), true},
{"ArrayContains fail", ArrayContains(f, 5, 6), false},
{"ArrayContainedBy", ArrayContainedBy(f, 1, 2, 3, 5, 6), true},
{"ArrayContainedBy fail", ArrayContainedBy(f, 1, 2, 5, 6), false},
{"ArrayOverlap", ArrayOverlap(f, 5, 1, 7), true},
{"ArrayOverlap fail", ArrayOverlap(f, 6, 7, 8, 9), false},
}

_, err := s.db.Exec("INSERT INTO slices (id,elems) VALUES ($1, $2)", NewID(), types.Slice([]int64{1, 2, 3}))
s.NoError(err)

for _, c := range cases {
q := NewBaseQuery(SlicesSchema)
q.Where(c.cond)
cnt, err := s.store.Count(q)
s.NoError(err, c.name)
s.Equal(c.ok, cnt > 0, "success: %s", c.name)
}
}

func TestOperators(t *testing.T) {
suite.Run(t, new(OpsSuite))
}

var SlicesSchema = &BaseSchema{
alias: "_sl",
table: "slices",
id: f("id"),
columns: []SchemaField{
f("id"),
f("elems"),
},
}
31 changes: 0 additions & 31 deletions store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,37 +385,6 @@ func (s *StoreSuite) TestFind_1toNMultiple() {
s.Equal(100, i)
}

func (s *StoreSuite) TestOperators() {
cases := []struct {
name string
cond Condition
count int64
}{
{"Eq", Eq(f("name"), "Joe"), 1},
{"Gt", Gt(f("age"), 1), 2},
{"Lt", Lt(f("age"), 2), 1},
{"Neq", Neq(f("name"), "Joe"), 2},
{"GtOrEq", GtOrEq(f("age"), 2), 2},
{"LtOrEq", LtOrEq(f("age"), 3), 3},
{"Not", Not(Eq(f("name"), "Joe")), 2},
{"And", And(Neq(f("name"), "Joe"), Gt(f("age"), 1)), 2},
{"Or", Or(Neq(f("name"), "Joe"), Eq(f("age"), 1)), 3},
{"In", In(f("name"), "Joe", "Jane"), 2},
{"NotIn", NotIn(f("name"), "Joe", "Jane"), 1},
}

s.Nil(s.store.Insert(ModelSchema, newModel("Joe", "", 1)))
s.Nil(s.store.Insert(ModelSchema, newModel("Jane", "", 2)))
s.Nil(s.store.Insert(ModelSchema, newModel("Anna", "", 2)))

for _, c := range cases {
q := NewBaseQuery(ModelSchema)
q.Where(c.cond)

s.Equal(s.store.MustCount(q), c.count, c.name)
}
}

func (s *StoreSuite) assertModel(m *model) {
var result model
err := s.db.QueryRow("SELECT id, name, email, age FROM model WHERE id = $1", m.ID).
Expand Down