Skip to content

Commit

Permalink
Fix issues with window mod and chaining
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenafamo committed Aug 27, 2022
1 parent f92b055 commit d89e84a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 29 deletions.
19 changes: 14 additions & 5 deletions dialect/psql/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"io"

"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/clause"
"github.com/stephenafamo/bob/expr"
)

Expand Down Expand Up @@ -59,11 +58,22 @@ func (f *function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error
return args, nil
}

func (f *function) FilterWhere(e ...any) *functionOver {
f.filter = append(f.filter, e...)

fo := &functionOver{
function: f,
}
fo.windowChain = &windowChain[*functionOver]{wrap: fo}
fo.Base = fo
return fo
}

func (f *function) Over(window string) *functionOver {
fo := &functionOver{
function: f,
}
fo.def = fo
fo.windowChain = &windowChain[*functionOver]{wrap: fo}
fo.Base = fo
return fo
}
Expand Down Expand Up @@ -97,8 +107,7 @@ func (c columnDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error

type functionOver struct {
function *function
clause.WindowDef
windowChain[*functionOver]
*windowChain[*functionOver]
expr.Chain[Expression, Expression]
}

Expand All @@ -108,7 +117,7 @@ func (wr *functionOver) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any,
return nil, err
}

winargs, err := bob.ExpressIf(w, d, start+len(fargs), wr.WindowDef, true, "OVER (", ")")
winargs, err := bob.ExpressIf(w, d, start+len(fargs), wr.def, true, "OVER (", ")")
if err != nil {
return nil, err
}
Expand Down
46 changes: 23 additions & 23 deletions dialect/psql/mods.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,54 +313,54 @@ func (l lockChain[Q]) SkipLocked() lockChain[Q] {

type windowMod[Q interface{ AppendWindow(clause.NamedWindow) }] struct {
name string
clause.WindowDef
windowChain[*windowMod[Q]]
*windowChain[*windowMod[Q]]
}

func (w windowMod[Q]) Apply(q Q) {
q.AppendWindow(clause.NamedWindow{
Name: w.name,
Definition: w.WindowDef,
Definition: w.def,
})
}

type windowChain[T clause.IWindow] struct {
def T
type windowChain[T any] struct {
def clause.WindowDef
wrap T
}

func (w *windowChain[T]) From(name string) T {
w.def.SetFrom(name)
return w.def
return w.wrap
}

func (w *windowChain[T]) PartitionBy(condition ...any) T {
w.def.AddPartitionBy(condition...)
return w.def
return w.wrap
}

func (w *windowChain[T]) OrderBy(order ...any) T {
w.def.AddOrderBy(order...)
return w.def
return w.wrap
}

func (w *windowChain[T]) Range() T {
w.def.SetMode("RANGE")
return w.def
return w.wrap
}

func (w *windowChain[T]) Rows() T {
w.def.SetMode("ROWS")
return w.def
return w.wrap
}

func (w *windowChain[T]) Groups() T {
w.def.SetMode("GROUPS")
return w.def
return w.wrap
}

func (w *windowChain[T]) FromUnboundedPreceding() T {
w.def.SetStart("UNBOUNDED PRECEDING")
return w.def
return w.wrap
}

func (w *windowChain[T]) FromPreceding(exp any) T {
Expand All @@ -369,12 +369,12 @@ func (w *windowChain[T]) FromPreceding(exp any) T {
return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING")
}),
)
return w.def
return w.wrap
}

func (w *windowChain[T]) FromCurrentRow() T {
w.def.SetStart("CURRENT ROW")
return w.def
return w.wrap
}

func (w *windowChain[T]) FromFollowing(exp any) T {
Expand All @@ -383,7 +383,7 @@ func (w *windowChain[T]) FromFollowing(exp any) T {
return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING")
}),
)
return w.def
return w.wrap
}

func (w *windowChain[T]) ToPreceding(exp any) T {
Expand All @@ -392,12 +392,12 @@ func (w *windowChain[T]) ToPreceding(exp any) T {
return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING")
}),
)
return w.def
return w.wrap
}

func (w *windowChain[T]) ToCurrentRow(count int) T {
w.def.SetEnd("CURRENT ROW")
return w.def
return w.wrap
}

func (w *windowChain[T]) ToFollowing(exp any) T {
Expand All @@ -406,30 +406,30 @@ func (w *windowChain[T]) ToFollowing(exp any) T {
return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING")
}),
)
return w.def
return w.wrap
}

func (w *windowChain[T]) ToUnboundedFollowing() T {
w.def.SetEnd("UNBOUNDED FOLLOWING")
return w.def
return w.wrap
}

func (w *windowChain[T]) ExcludeNoOthers() T {
w.def.SetExclusion("NO OTHERS")
return w.def
return w.wrap
}

func (w *windowChain[T]) ExcludeCurrentRow() T {
w.def.SetExclusion("CURRENT ROW")
return w.def
return w.wrap
}

func (w *windowChain[T]) ExcludeGroup() T {
w.def.SetExclusion("GROUP")
return w.def
return w.wrap
}

func (w *windowChain[T]) ExcludeTies() T {
w.def.SetExclusion("TIES")
return w.def
return w.wrap
}
4 changes: 3 additions & 1 deletion dialect/psql/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ func (selectQM) Window(name string) windowMod[*SelectQuery] {
name: name,
}

m.windowChain.def = &m
m.windowChain = &windowChain[*windowMod[*SelectQuery]]{
wrap: &m,
}
return m
}

Expand Down

0 comments on commit d89e84a

Please sign in to comment.