Skip to content

Commit

Permalink
Merge pull request #228 from stephenafamo/function-fix
Browse files Browse the repository at this point in the history
Rework function starter for more flexibility
  • Loading branch information
stephenafamo committed Jun 4, 2024
2 parents f9f01a6 + 562489b commit a85e5ea
Show file tree
Hide file tree
Showing 24 changed files with 370 additions and 238 deletions.
20 changes: 20 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
models.SelectJoins.Jets.InnerJoin.Pilots(ctx).AliasedAs("p")
```

- Add `fm` mods to all supported dialects (psql, mysql and sqlite). These are mods for functions and are used to modify the function call. For example:

```go
// import "github.com/stephenafamo/bob/dialect/psql/fm"
psql.F( "count", "*",)(fm.Filter(psql.Quote("status").EQ(psql.S("done"))))
```

### Changed

- Change the function call point for generated relationship join mods. This reduces the amount of allocations and only does the work for the relationship being used.
Expand All @@ -45,9 +52,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
items, err := query.All()
```

- Changed how functions are modified. Instead of chained methods, the `F()` starter now returns a function which can be called with mods:

```go
// Before
psql.F( "count", "*",).FilterWhere(psql.Quote("status").EQ(psql.S("done"))),
// After
// import "github.com/stephenafamo/bob/dialect/psql/fm"
psql.F( "count", "*",)(fm.Filter(psql.Quote("status").EQ(psql.S("done")))),
```

This makes it possible to support more queries.

### Removed

- Remove `TableWhere` function from the generated code. It was not used by the rest of the generated code and offered no clear benefit.
- Removed `As` starter. It takes an `Expression` and is not needed since the `Expression` has an `As` method which can be used directly.

## [v0.26.1] - 2024-05-26

Expand Down
10 changes: 5 additions & 5 deletions clause/window.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@ type IWindow interface {
SetExclusion(string)
}

type WindowDef struct {
type Window struct {
From string // an existing window name
orderBy []any
partitionBy []any
Frame
}

func (wi *WindowDef) SetFrom(from string) {
func (wi *Window) SetFrom(from string) {
wi.From = from
}

func (wi *WindowDef) AddPartitionBy(condition ...any) {
func (wi *Window) AddPartitionBy(condition ...any) {
wi.partitionBy = append(wi.partitionBy, condition...)
}

func (wi *WindowDef) AddOrderBy(order ...any) {
func (wi *Window) AddOrderBy(order ...any) {
wi.orderBy = append(wi.orderBy, order...)
}

func (wi WindowDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (wi Window) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
if wi.From != "" {
w.Write([]byte(wi.From))
w.Write([]byte(" "))
Expand Down
61 changes: 25 additions & 36 deletions dialect/mysql/dialect/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,29 @@ import (
"github.com/stephenafamo/bob/expr"
)

func NewFunction(name string, args ...any) Function {
return Function{name: name, args: args}
func NewFunction(name string, args ...any) *Function {
f := &Function{name: name, args: args}
f.Chain = expr.Chain[Expression, Expression]{Base: f}

return f
}

type Function struct {
name string
args []any

// Used in value functions. Supported by Sqlite and Postgres
filter []any
Distinct bool
clause.OrderBy
Filter []any
w *clause.Window

// For chain methods
expr.Chain[Expression, Expression]
}

// A function can be a target for a query
func (f *Function) Apply(q *clause.From) {
q.Table = f
}

func (f *Function) Filter(e ...any) *Function {
f.filter = append(f.filter, e...)

return f
}

func (f *Function) Over() *functionOver {
fo := &functionOver{
function: f,
}
fo.WindowChain = &WindowChain[*functionOver]{Wrap: fo}
fo.Base = fo
return fo
func (f *Function) SetWindow(w clause.Window) {
f.w = &w
}

func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
Expand All @@ -50,37 +40,36 @@ func (f Function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error)

w.Write([]byte(f.name))
w.Write([]byte("("))

if f.Distinct {
w.Write([]byte("DISTINCT "))
}

args, err := bob.ExpressSlice(w, d, start, f.args, "", ", ", "")
if err != nil {
return nil, err
}
w.Write([]byte(")"))

filterArgs, err := bob.ExpressSlice(w, d, start, f.filter, " FILTER (WHERE ", " AND ", ")")
orderArgs, err := bob.ExpressIf(w, d, start+len(args), f.OrderBy,
len(f.OrderBy.Expressions) > 0, " ", "")
if err != nil {
return nil, err
}
args = append(args, filterArgs...)

return args, nil
}
args = append(args, orderArgs...)

type functionOver struct {
function *Function
*WindowChain[*functionOver]
expr.Chain[Expression, Expression]
}
w.Write([]byte(")"))

func (wr *functionOver) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
fargs, err := bob.Express(w, d, start, wr.function)
filterArgs, err := bob.ExpressSlice(w, d, start, f.Filter, " FILTER (WHERE ", " AND ", ")")
if err != nil {
return nil, err
}
args = append(args, filterArgs...)

winargs, err := bob.ExpressIf(w, d, start+len(fargs), wr.def, true, "OVER (", ")")
winargs, err := bob.ExpressIf(w, d, start+len(args), f.w, f.w != nil, "OVER (", ")")
if err != nil {
return nil, err
}
args = append(args, winargs...)

return append(fargs, winargs...), nil
return args, nil
}
14 changes: 11 additions & 3 deletions dialect/mysql/dialect/mods.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,20 +328,28 @@ func (l LockChain[Q]) SkipLocked() LockChain[Q] {
})
}

type WindowMod[Q interface{ AppendWindow(clause.NamedWindow) }] struct {
Name string
type WindowMod[Q interface{ SetWindow(clause.Window) }] struct {
*WindowChain[*WindowMod[Q]]
}

func (w WindowMod[Q]) Apply(q Q) {
q.SetWindow(w.def)
}

type WindowsMod[Q interface{ AppendWindow(clause.NamedWindow) }] struct {
Name string
*WindowChain[*WindowsMod[Q]]
}

func (w WindowsMod[Q]) Apply(q Q) {
q.AppendWindow(clause.NamedWindow{
Name: w.Name,
Definition: w.def,
})
}

type WindowChain[T any] struct {
def clause.WindowDef
def clause.Window
Wrap T
}

Expand Down
36 changes: 36 additions & 0 deletions dialect/mysql/fm/qm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package fm

import (
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/clause"
"github.com/stephenafamo/bob/dialect/mysql/dialect"
"github.com/stephenafamo/bob/mods"
)

func Distinct() bob.Mod[*dialect.Function] {
return mods.QueryModFunc[*dialect.Function](func(f *dialect.Function) {
f.Distinct = true
})
}

func OrderBy(e any) dialect.OrderBy[*dialect.Function] {
return dialect.OrderBy[*dialect.Function](func() clause.OrderDef {
return clause.OrderDef{
Expression: e,
}
})
}

func Filter(e ...any) bob.Mod[*dialect.Function] {
return mods.QueryModFunc[*dialect.Function](func(f *dialect.Function) {
f.Filter = append(f.Filter, e...)
})
}

func Over() dialect.WindowMod[*dialect.Function] {
m := dialect.WindowMod[*dialect.Function]{}
m.WindowChain = &dialect.WindowChain[*dialect.WindowMod[*dialect.Function]]{
Wrap: &m,
}
return m
}
10 changes: 4 additions & 6 deletions dialect/mysql/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/stephenafamo/bob/dialect/mysql"
"github.com/stephenafamo/bob/dialect/mysql/fm"
"github.com/stephenafamo/bob/dialect/mysql/sm"
testutils "github.com/stephenafamo/bob/test/utils"
mysqlparser "github.com/stephenafamo/sqlparser/mysql"
Expand Down Expand Up @@ -49,12 +50,9 @@ func TestSelect(t *testing.T) {
sm.From(mysql.Select(
sm.Columns(
"status",
mysql.F("LEAD", "created_date", 1, mysql.F("NOW")).
Over().
PartitionBy("presale_id").
OrderBy("created_date").
Minus(mysql.Quote("created_date")).
As("difference")),
mysql.F("LEAD", "created_date", 1, mysql.F("NOW"))(
fm.Over().PartitionBy("presale_id").OrderBy("created_date"),
).Minus(mysql.Quote("created_date")).As("difference")),
sm.From("presales_presalestatus")),
).As("differnce_by_status"),
sm.Where(mysql.Quote("status").In(mysql.S("A"), mysql.S("B"), mysql.S("C"))),
Expand Down
6 changes: 3 additions & 3 deletions dialect/mysql/sm/qm.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ func WithRollup(distinct bool) bob.Mod[*dialect.SelectQuery] {
})
}

func Window(name string) dialect.WindowMod[*dialect.SelectQuery] {
m := dialect.WindowMod[*dialect.SelectQuery]{
func Window(name string) dialect.WindowsMod[*dialect.SelectQuery] {
m := dialect.WindowsMod[*dialect.SelectQuery]{
Name: name,
}

m.WindowChain = &dialect.WindowChain[*dialect.WindowMod[*dialect.SelectQuery]]{
m.WindowChain = &dialect.WindowChain[*dialect.WindowsMod[*dialect.SelectQuery]]{
Wrap: &m,
}
return m
Expand Down
19 changes: 8 additions & 11 deletions dialect/mysql/starters.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/dialect/mysql/dialect"
"github.com/stephenafamo/bob/expr"
"github.com/stephenafamo/bob/mods"
)

type Expression = dialect.Expression
Expand All @@ -15,14 +16,16 @@ var bmod = expr.Builder[Expression, Expression]{}
//
// SQL: generate_series(1, 3)
// Go: mysql.F("generate_series", 1, 3)
func F(name string, args ...any) *dialect.Function {
func F(name string, args ...any) mods.Moddable[*dialect.Function] {
f := dialect.NewFunction(name, args...)

// We have embedded the same function as the chain base
// this is so that chained methods can also be used by functions
f.Chain.Base = &f
return mods.Moddable[*dialect.Function](func(mods ...bob.Mod[*dialect.Function]) *dialect.Function {
for _, mod := range mods {
mod.Apply(f)
}

return &f
return f
})
}

// S creates a string literal
Expand Down Expand Up @@ -91,9 +94,3 @@ func Quote(ss ...string) Expression {
func Raw(query string, args ...any) Expression {
return bmod.Raw(query, args...)
}

// SQL: a as "alias"
// Go: mysql.As("a", "alias")
func As(e Expression, alias string) bob.Expression {
return expr.OP("AS", e, expr.Quote(alias))
}
Loading

0 comments on commit a85e5ea

Please sign in to comment.