diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 739cd07993..1cc7807386 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -140,7 +140,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(validate.In(c.catalog, raw)); err != nil { return nil, err } - rvs := rangeVars(raw.Stmt) + rvs, rss, rfs := rawRangeTblRefs(raw.Stmt) refs, errs := findParameters(raw.Stmt) if len(errs) > 0 { if failfast { @@ -160,7 +160,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + params, err := c.resolveCatalogRefs(qc, rvs, rss, rfs, refs, namedParams, embeds) if err := check(err); err != nil { return nil, err } diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 53e3043c7d..8336274445 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -108,16 +108,22 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, }, nil } -func rangeVars(root ast.Node) []*ast.RangeVar { +func rawRangeTblRefs(root ast.Node) ([]*ast.RangeVar, []*ast.RangeSubselect, []*ast.RangeFunction) { var vars []*ast.RangeVar - find := astutils.VisitorFunc(func(node ast.Node) { + var subs []*ast.RangeSubselect + var funs []*ast.RangeFunction + find := astutils.SingleQueryVisitorFunc(func(node ast.Node) { switch n := node.(type) { case *ast.RangeVar: vars = append(vars, n) + case *ast.RangeSubselect: + subs = append(subs, n) + case *ast.RangeFunction: + funs = append(funs, n) } }) astutils.Walk(find, root) - return vars + return vars, subs, funs } func uniqueParamRefs(in []paramRef, dollar bool) []paramRef { diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 4624c5a45d..ce5af9fcec 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -21,7 +21,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, rss []*ast.RangeSubselect, rfs []*ast.RangeFunction, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -78,6 +78,19 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } + nonTableAliases := make(map[string]bool) + for _, rs := range rss { + if rs != nil && rs.Alias != nil && rs.Alias.Aliasname != nil { + nonTableAliases[*rs.Alias.Aliasname] = true + } + } + + for _, rf := range rfs { + if rf != nil && rf.Alias != nil && rf.Alias.Aliasname != nil { + nonTableAliases[*rf.Alias.Aliasname] = true + } + } + // resolve a table for an embed for _, embed := range embeds { table, err := c.GetTable(embed.Table) @@ -91,6 +104,10 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } + if nonTableAliases[embed.Table.Name] { + return nil, fmt.Errorf("the embed macro can only be used with tables in models, not subqueries or function-defined tables: %q", embed.Orig()) + } + return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) } diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 0943379f03..30471ded25 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -17,6 +17,19 @@ func (vf VisitorFunc) Visit(node ast.Node) Visitor { return vf } +type SingleQueryVisitorFunc func(ast.Node) + +func (vf SingleQueryVisitorFunc) Visit(node ast.Node) Visitor { + switch node.(type) { + case *ast.RangeSubselect, *ast.RangeTblEntry: + vf(node) + return nil + default: + vf(node) + return vf + } +} + func Walk(f Visitor, node ast.Node) { if f = f.Visit(node); f == nil { return