diff --git a/dialect/psql/function.go b/dialect/psql/function.go index b32e3f5c..b3883878 100644 --- a/dialect/psql/function.go +++ b/dialect/psql/function.go @@ -4,7 +4,6 @@ import ( "io" "github.com/stephenafamo/bob" - "github.com/stephenafamo/bob/clause" "github.com/stephenafamo/bob/expr" ) @@ -59,11 +58,22 @@ func (f *function) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error return args, nil } +func (f *function) FilterWhere(e ...any) *functionOver { + f.filter = append(f.filter, e...) + + fo := &functionOver{ + function: f, + } + fo.windowChain = &windowChain[*functionOver]{wrap: fo} + fo.Base = fo + return fo +} + func (f *function) Over(window string) *functionOver { fo := &functionOver{ function: f, } - fo.def = fo + fo.windowChain = &windowChain[*functionOver]{wrap: fo} fo.Base = fo return fo } @@ -97,8 +107,7 @@ func (c columnDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error type functionOver struct { function *function - clause.WindowDef - windowChain[*functionOver] + *windowChain[*functionOver] expr.Chain[Expression, Expression] } @@ -108,7 +117,7 @@ func (wr *functionOver) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, return nil, err } - winargs, err := bob.ExpressIf(w, d, start+len(fargs), wr.WindowDef, true, "OVER (", ")") + winargs, err := bob.ExpressIf(w, d, start+len(fargs), wr.def, true, "OVER (", ")") if err != nil { return nil, err } diff --git a/dialect/psql/mods.go b/dialect/psql/mods.go index f222d950..4929d81a 100644 --- a/dialect/psql/mods.go +++ b/dialect/psql/mods.go @@ -313,54 +313,54 @@ func (l lockChain[Q]) SkipLocked() lockChain[Q] { type windowMod[Q interface{ AppendWindow(clause.NamedWindow) }] struct { name string - clause.WindowDef - windowChain[*windowMod[Q]] + *windowChain[*windowMod[Q]] } func (w windowMod[Q]) Apply(q Q) { q.AppendWindow(clause.NamedWindow{ Name: w.name, - Definition: w.WindowDef, + Definition: w.def, }) } -type windowChain[T clause.IWindow] struct { - def T +type windowChain[T any] struct { + def clause.WindowDef + wrap T } func (w *windowChain[T]) From(name string) T { w.def.SetFrom(name) - return w.def + return w.wrap } func (w *windowChain[T]) PartitionBy(condition ...any) T { w.def.AddPartitionBy(condition...) - return w.def + return w.wrap } func (w *windowChain[T]) OrderBy(order ...any) T { w.def.AddOrderBy(order...) - return w.def + return w.wrap } func (w *windowChain[T]) Range() T { w.def.SetMode("RANGE") - return w.def + return w.wrap } func (w *windowChain[T]) Rows() T { w.def.SetMode("ROWS") - return w.def + return w.wrap } func (w *windowChain[T]) Groups() T { w.def.SetMode("GROUPS") - return w.def + return w.wrap } func (w *windowChain[T]) FromUnboundedPreceding() T { w.def.SetStart("UNBOUNDED PRECEDING") - return w.def + return w.wrap } func (w *windowChain[T]) FromPreceding(exp any) T { @@ -369,12 +369,12 @@ func (w *windowChain[T]) FromPreceding(exp any) T { return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") }), ) - return w.def + return w.wrap } func (w *windowChain[T]) FromCurrentRow() T { w.def.SetStart("CURRENT ROW") - return w.def + return w.wrap } func (w *windowChain[T]) FromFollowing(exp any) T { @@ -383,7 +383,7 @@ func (w *windowChain[T]) FromFollowing(exp any) T { return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") }), ) - return w.def + return w.wrap } func (w *windowChain[T]) ToPreceding(exp any) T { @@ -392,12 +392,12 @@ func (w *windowChain[T]) ToPreceding(exp any) T { return bob.ExpressIf(w, d, start, exp, true, "", " PRECEDING") }), ) - return w.def + return w.wrap } func (w *windowChain[T]) ToCurrentRow(count int) T { w.def.SetEnd("CURRENT ROW") - return w.def + return w.wrap } func (w *windowChain[T]) ToFollowing(exp any) T { @@ -406,30 +406,30 @@ func (w *windowChain[T]) ToFollowing(exp any) T { return bob.ExpressIf(w, d, start, exp, true, "", " FOLLOWING") }), ) - return w.def + return w.wrap } func (w *windowChain[T]) ToUnboundedFollowing() T { w.def.SetEnd("UNBOUNDED FOLLOWING") - return w.def + return w.wrap } func (w *windowChain[T]) ExcludeNoOthers() T { w.def.SetExclusion("NO OTHERS") - return w.def + return w.wrap } func (w *windowChain[T]) ExcludeCurrentRow() T { w.def.SetExclusion("CURRENT ROW") - return w.def + return w.wrap } func (w *windowChain[T]) ExcludeGroup() T { w.def.SetExclusion("GROUP") - return w.def + return w.wrap } func (w *windowChain[T]) ExcludeTies() T { w.def.SetExclusion("TIES") - return w.def + return w.wrap } diff --git a/dialect/psql/select.go b/dialect/psql/select.go index e2ce2f64..8f1ad8a6 100644 --- a/dialect/psql/select.go +++ b/dialect/psql/select.go @@ -208,7 +208,9 @@ func (selectQM) Window(name string) windowMod[*SelectQuery] { name: name, } - m.windowChain.def = &m + m.windowChain = &windowChain[*windowMod[*SelectQuery]]{ + wrap: &m, + } return m }