Skip to content

Commit

Permalink
Merge pull request #318 from upper/feature/amend-query-before-executi…
Browse files Browse the repository at this point in the history
…ng-it

Add Amend method and tests.
  • Loading branch information
José Carlos committed Jan 6, 2017
2 parents 9b757d6 + a3f82e5 commit 0b3a983
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 7 deletions.
19 changes: 16 additions & 3 deletions internal/sqladapter/exql/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ type Statement struct {

SQL string

hash hash
hash hash
amendFn func(string) string
}

type statementT struct {
Expand Down Expand Up @@ -63,6 +64,17 @@ func (s *Statement) Hash() string {
return s.hash.Hash(s)
}

func (s *Statement) SetAmendment(amendFn func(string) string) {
s.amendFn = amendFn
}

func (s *Statement) Amend(in string) string {
if s.amendFn == nil {
return in
}
return s.amendFn(in)
}

// Compile transforms the Statement into an equivalent SQL query.
func (s *Statement) Compile(layout *Template) (compiled string) {
if s.Type == SQL {
Expand All @@ -71,7 +83,7 @@ func (s *Statement) Compile(layout *Template) (compiled string) {
}

if z, ok := layout.Read(s); ok {
return z
return s.Amend(z)
}

data := statementT{
Expand Down Expand Up @@ -112,7 +124,8 @@ func (s *Statement) Compile(layout *Template) (compiled string) {

compiled = strings.TrimSpace(compiled)
layout.Write(s, compiled)
return compiled

return s.Amend(compiled)
}

// RawSQL represents a raw SQL statement.
Expand Down
11 changes: 10 additions & 1 deletion internal/sqladapter/testing/adapter.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,16 @@ func TestBatchInsert(t *testing.T) {
err := sess.Collection("artist").Truncate()
assert.NoError(t, err)
batch := sess.InsertInto("artist").Columns("name").Batch(batchSize)
q := sess.InsertInto("artist").Columns("name")
if Adapter == "postgresql" {
q = q.Amend(func(query string) string {
return query + ` ON CONFLICT DO NOTHING`
})
}


batch := q.Batch(batchSize)

totalItems := int(rand.Int31n(21))

Expand Down
28 changes: 28 additions & 0 deletions lib/sqlbuilder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ func TestSelect(t *testing.T) {
b.Select(db.Func("DATE")).String(),
)

assert.Equal(
`SELECT DATE() FOR UPDATE`,
b.Select(db.Func("DATE")).Amend(func(query string) string {
return query + " FOR UPDATE"
}).String(),
)

assert.Equal(
`SELECT * FROM "artist"`,
b.SelectFrom("artist").String(),
Expand Down Expand Up @@ -693,6 +700,13 @@ func TestInsert(t *testing.T) {
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(),
)

assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`,
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Amend(func(query string) string {
return query + ` RETURNING "id"`
}).String(),
)

assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(),
Expand Down Expand Up @@ -882,6 +896,13 @@ func TestUpdate(t *testing.T) {
b.Update("artist").Set("name", "Artist").String(),
)

assert.Equal(
`UPDATE "artist" SET "name" = $1 RETURNING "name"`,
b.Update("artist").Set("name", "Artist").Amend(func(query string) string {
return query + ` RETURNING "name"`
}).String(),
)

{
idSlice := []int64{8, 7, 6}
q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice})
Expand Down Expand Up @@ -1014,6 +1035,13 @@ func TestDelete(t *testing.T) {
bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(),
)

assert.Equal(
`DELETE FROM "artist" WHERE (name = $1) RETURNING 1`,
bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Amend(func(query string) string {
return fmt.Sprintf("%s RETURNING 1", query)
}).String(),
)

assert.Equal(
`DELETE FROM "artist" WHERE (id > 5)`,
bt.DeleteFrom("artist").Where("id > 5").String(),
Expand Down
8 changes: 8 additions & 0 deletions lib/sqlbuilder/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type deleter struct {
limit int
where *exql.Where
arguments []interface{}
amendFn func(string) string
}

func (qd *deleter) Where(terms ...interface{}) Deleter {
Expand All @@ -27,6 +28,11 @@ func (qd *deleter) Limit(limit int) Deleter {
return qd
}

func (qd *deleter) Amend(fn func(string) string) Deleter {
qd.amendFn = fn
return qd
}

func (qd *deleter) Arguments() []interface{} {
return qd.arguments
}
Expand All @@ -49,5 +55,7 @@ func (qd *deleter) statement() *exql.Statement {
stmt.Limit = exql.Limit(qd.limit)
}

stmt.SetAmendment(qd.amendFn)

return stmt
}
11 changes: 10 additions & 1 deletion lib/sqlbuilder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ type inserter struct {
returning []exql.Fragment
columns []exql.Fragment
arguments []interface{}
extra string

amendFn func(string) string
extra string
}

func (qi *inserter) clone() *inserter {
Expand All @@ -31,6 +33,11 @@ func (qi *inserter) Batch(n int) *BatchInserter {
return newBatchInserter(qi.clone(), n)
}

func (qi *inserter) Amend(fn func(string) string) Inserter {
qi.amendFn = fn
return qi
}

func (qi *inserter) Arguments() []interface{} {
_ = qi.statement()
return qi.arguments
Expand Down Expand Up @@ -169,5 +176,7 @@ func (qi *inserter) statement() *exql.Statement {
stmt.Returning = exql.ReturningColumns(qi.returning...)
}

stmt.SetAmendment(qi.amendFn)

return stmt
}
16 changes: 16 additions & 0 deletions lib/sqlbuilder/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ type Selector interface {
// return results.
Offset(int) Selector

// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Selector

// Iterator provides methods to iterate over the results returned by the
// Selector.
Iterator() Iterator
Expand Down Expand Up @@ -330,6 +334,10 @@ type Inserter interface {
// Inserter. This is only possible when using Returning().
Iterator() Iterator

// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Inserter

// Batch provies a BatchInserter that can be used to insert many elements at
// once by issuing several calls to Values(). It accepts a size parameter
// which defines the batch size. If size is < 1, the batch size is set to 1.
Expand Down Expand Up @@ -359,6 +367,10 @@ type Deleter interface {
// See Selector.Limit for documentation and usage examples.
Limit(int) Deleter

// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Deleter

// Execer provides the Exec method.
Execer

Expand Down Expand Up @@ -394,6 +406,10 @@ type Updater interface {

// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}

// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Updater
}

// Execer provides methods for executing statements that do not return results.
Expand Down
14 changes: 12 additions & 2 deletions lib/sqlbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ type selector struct {
joins []*exql.Join
joinsArgs []interface{}

mu sync.Mutex
mu sync.Mutex
amendFn func(string) string

err error
}
Expand Down Expand Up @@ -117,6 +118,11 @@ func (qs *selector) And(terms ...interface{}) Selector {
return qs
}

func (qs *selector) Amend(fn func(string) string) Selector {
qs.amendFn = fn
return qs
}

func (qs *selector) Arguments() []interface{} {
qs.mu.Lock()
defer qs.mu.Unlock()
Expand Down Expand Up @@ -326,7 +332,7 @@ func (qs *selector) Offset(n int) Selector {
}

func (qs *selector) statement() *exql.Statement {
return &exql.Statement{
stmt := &exql.Statement{
Type: exql.Select,
Table: qs.table,
Columns: qs.columns,
Expand All @@ -337,6 +343,10 @@ func (qs *selector) statement() *exql.Statement {
OrderBy: qs.orderBy,
GroupBy: qs.groupBy,
}

stmt.SetAmendment(qs.amendFn)

return stmt
}

func (qs *selector) Query() (*sql.Rows, error) {
Expand Down
9 changes: 9 additions & 0 deletions lib/sqlbuilder/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type updater struct {
where *exql.Where
whereArgs []interface{}

amendFn func(string) string

mu sync.Mutex
}

Expand Down Expand Up @@ -58,6 +60,11 @@ func (qu *updater) Set(columns ...interface{}) Updater {
return qu
}

func (qu *updater) Amend(fn func(string) string) Updater {
qu.amendFn = fn
return qu
}

func (qu *updater) Arguments() []interface{} {
qu.mu.Lock()
defer qu.mu.Unlock()
Expand Down Expand Up @@ -99,5 +106,7 @@ func (qu *updater) statement() *exql.Statement {
stmt.Limit = exql.Limit(qu.limit)
}

stmt.SetAmendment(qu.amendFn)

return stmt
}

0 comments on commit 0b3a983

Please sign in to comment.