diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 04e753e5b7..35b475ca4f 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" "github.com/sqlc-dev/sqlc/internal/sql/ast" @@ -15,59 +16,113 @@ import ( func TestFormat(t *testing.T) { t.Parallel() - parse := postgresql.NewParser() for _, tc := range FindTests(t, "testdata", "base") { tc := tc - - if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) { - continue - } - - q := filepath.Join(tc.Path, "query.sql") - if _, err := os.Stat(q); os.IsNotExist(err) { - continue - } - t.Run(tc.Name, func(t *testing.T) { - contents, err := os.ReadFile(q) + // Parse the config file to determine the engine + configPath := filepath.Join(tc.Path, tc.ConfigName) + configFile, err := os.Open(configPath) if err != nil { t.Fatal(err) } - for i, query := range bytes.Split(bytes.TrimSpace(contents), []byte(";")) { - if len(query) <= 1 { - continue - } - query := query - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - expected, err := postgresql.Fingerprint(string(query)) - if err != nil { - t.Fatal(err) - } - stmts, err := parse.Parse(bytes.NewReader(query)) - if err != nil { - t.Fatal(err) - } - if len(stmts) != 1 { - t.Fatal("expected one statement") - } - if false { - r, err := postgresql.Parse(string(query)) - debug.Dump(r, err) - } + conf, err := config.ParseConfig(configFile) + configFile.Close() + if err != nil { + t.Fatal(err) + } + + // Skip if there are no SQL packages configured + if len(conf.SQL) == 0 { + return + } + + // For now, only test PostgreSQL since that's the only engine with Format support + engine := conf.SQL[0].Engine + if engine != config.EnginePostgreSQL { + return + } - out := ast.Format(stmts[0].Raw) - actual, err := postgresql.Fingerprint(out) + // Find query files from config + var queryFiles []string + for _, sql := range conf.SQL { + for _, q := range sql.Queries { + queryPath := filepath.Join(tc.Path, q) + info, err := os.Stat(queryPath) if err != nil { - t.Error(err) + continue } - if expected != actual { - debug.Dump(stmts[0].Raw) - t.Errorf("- %s", expected) - t.Errorf("- %s", string(query)) - t.Errorf("+ %s", actual) - t.Errorf("+ %s", out) + if info.IsDir() { + // If it's a directory, glob for .sql files + matches, err := filepath.Glob(filepath.Join(queryPath, "*.sql")) + if err != nil { + continue + } + queryFiles = append(queryFiles, matches...) + } else { + queryFiles = append(queryFiles, queryPath) } - }) + } + } + + if len(queryFiles) == 0 { + return + } + + parse := postgresql.NewParser() + + for _, queryFile := range queryFiles { + if _, err := os.Stat(queryFile); os.IsNotExist(err) { + continue + } + + contents, err := os.ReadFile(queryFile) + if err != nil { + t.Fatal(err) + } + + // Parse the entire file to get proper statement boundaries + stmts, err := parse.Parse(bytes.NewReader(contents)) + if err != nil { + // Skip files with parse errors (e.g., syntax_errors test cases) + return + } + + for i, stmt := range stmts { + stmt := stmt + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + // Extract the original query text using statement location and length + start := stmt.Raw.StmtLocation + length := stmt.Raw.StmtLen + if length == 0 { + // If StmtLen is 0, it means the statement goes to the end of the input + length = len(contents) - start + } + query := strings.TrimSpace(string(contents[start : start+length])) + + expected, err := postgresql.Fingerprint(query) + if err != nil { + t.Fatal(err) + } + + if false { + r, err := postgresql.Parse(query) + debug.Dump(r, err) + } + + out := ast.Format(stmt.Raw, parse) + actual, err := postgresql.Fingerprint(out) + if err != nil { + t.Error(err) + } + if expected != actual { + debug.Dump(stmt.Raw) + t.Errorf("- %s", expected) + t.Errorf("- %s", query) + t.Errorf("+ %s", actual) + t.Errorf("+ %s", out) + } + }) + } } }) } diff --git a/internal/engine/postgresql/convert.go b/internal/engine/postgresql/convert.go index f56a572c16..321294c59e 100644 --- a/internal/engine/postgresql/convert.go +++ b/internal/engine/postgresql/convert.go @@ -1965,6 +1965,22 @@ func convertNullTest(n *pg.NullTest) *ast.NullTest { } } +func convertNullIfExpr(n *pg.NullIfExpr) *ast.NullIfExpr { + if n == nil { + return nil + } + return &ast.NullIfExpr{ + Xpr: convertNode(n.Xpr), + Opno: ast.Oid(n.Opno), + Opresulttype: ast.Oid(n.Opresulttype), + Opretset: n.Opretset, + Opcollid: ast.Oid(n.Opcollid), + Inputcollid: ast.Oid(n.Inputcollid), + Args: convertSlice(n.Args), + Location: int(n.Location), + } +} + func convertObjectWithArgs(n *pg.ObjectWithArgs) *ast.ObjectWithArgs { if n == nil { return nil @@ -3420,6 +3436,9 @@ func convertNode(node *pg.Node) ast.Node { case *pg.Node_NullTest: return convertNullTest(n.NullTest) + case *pg.Node_NullIfExpr: + return convertNullIfExpr(n.NullIfExpr) + case *pg.Node_ObjectWithArgs: return convertObjectWithArgs(n.ObjectWithArgs) diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 40af125962..0c6b3a0fc2 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -494,6 +494,7 @@ func translate(node *nodes.Node) (ast.Node, error) { ReturnType: rt, Replace: n.Replace, Params: &ast.List{}, + Options: convertSlice(n.Options), } for _, item := range n.Parameters { arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter diff --git a/internal/engine/postgresql/reserved.go b/internal/engine/postgresql/reserved.go index 8f796ffa19..0be5c54b8d 100644 --- a/internal/engine/postgresql/reserved.go +++ b/internal/engine/postgresql/reserved.go @@ -2,6 +2,59 @@ package postgresql import "strings" +// hasMixedCase returns true if the string has any uppercase letters +// (identifiers with mixed case need quoting in PostgreSQL) +func hasMixedCase(s string) bool { + for _, r := range s { + if r >= 'A' && r <= 'Z' { + return true + } + } + return false +} + +// QuoteIdent returns a quoted identifier if it needs quoting. +// This implements the format.Formatter interface. +func (p *Parser) QuoteIdent(s string) string { + if p.IsReservedKeyword(s) || hasMixedCase(s) { + return `"` + s + `"` + } + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +// This implements the format.Formatter interface. +func (p *Parser) TypeName(ns, name string) string { + if ns == "pg_catalog" { + switch name { + case "int4": + return "integer" + case "int8": + return "bigint" + case "int2": + return "smallint" + case "float4": + return "real" + case "float8": + return "double precision" + case "bool": + return "boolean" + case "bpchar": + return "character" + case "timestamptz": + return "timestamp with time zone" + case "timetz": + return "time with time zone" + default: + return name + } + } + if ns != "" { + return ns + "." + name + } + return name +} + // https://www.postgresql.org/docs/current/sql-keywords-appendix.html func (p *Parser) IsReservedKeyword(s string) bool { switch strings.ToLower(s) { diff --git a/internal/sql/ast/a_array_expr.go b/internal/sql/ast/a_array_expr.go index dafa0e8e85..970e95deb1 100644 --- a/internal/sql/ast/a_array_expr.go +++ b/internal/sql/ast/a_array_expr.go @@ -8,3 +8,12 @@ type A_ArrayExpr struct { func (n *A_ArrayExpr) Pos() int { return n.Location } + +func (n *A_ArrayExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("ARRAY[") + buf.join(n.Elements, ", ") + buf.WriteString("]") +} diff --git a/internal/sql/ast/a_expr.go b/internal/sql/ast/a_expr.go index b0b7f75367..3b73d66d37 100644 --- a/internal/sql/ast/a_expr.go +++ b/internal/sql/ast/a_expr.go @@ -16,19 +16,88 @@ func (n *A_Expr) Format(buf *TrackedBuffer) { if n == nil { return } - buf.astFormat(n.Lexpr) - buf.WriteString(" ") switch n.Kind { case A_Expr_Kind_IN: + buf.astFormat(n.Lexpr) buf.WriteString(" IN (") buf.astFormat(n.Rexpr) buf.WriteString(")") case A_Expr_Kind_LIKE: + buf.astFormat(n.Lexpr) buf.WriteString(" LIKE ") buf.astFormat(n.Rexpr) + case A_Expr_Kind_ILIKE: + buf.astFormat(n.Lexpr) + buf.WriteString(" ILIKE ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_SIMILAR: + buf.astFormat(n.Lexpr) + buf.WriteString(" SIMILAR TO ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_BETWEEN: + buf.astFormat(n.Lexpr) + buf.WriteString(" BETWEEN ") + if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { + buf.astFormat(l.Items[0]) + buf.WriteString(" AND ") + buf.astFormat(l.Items[1]) + } + case A_Expr_Kind_NOT_BETWEEN: + buf.astFormat(n.Lexpr) + buf.WriteString(" NOT BETWEEN ") + if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { + buf.astFormat(l.Items[0]) + buf.WriteString(" AND ") + buf.astFormat(l.Items[1]) + } + case A_Expr_Kind_DISTINCT: + buf.astFormat(n.Lexpr) + buf.WriteString(" IS DISTINCT FROM ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_NOT_DISTINCT: + buf.astFormat(n.Lexpr) + buf.WriteString(" IS NOT DISTINCT FROM ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_NULLIF: + buf.WriteString("NULLIF(") + buf.astFormat(n.Lexpr) + buf.WriteString(", ") + buf.astFormat(n.Rexpr) + buf.WriteString(")") + case A_Expr_Kind_OP: + // Check if this is a named parameter (@name) + opName := "" + if n.Name != nil && len(n.Name.Items) == 1 { + if s, ok := n.Name.Items[0].(*String); ok { + opName = s.Str + } + } + if opName == "@" && !set(n.Lexpr) && set(n.Rexpr) { + // Named parameter: @name (no space after @) + buf.WriteString("@") + buf.astFormat(n.Rexpr) + } else { + // Standard binary operator + if set(n.Lexpr) { + buf.astFormat(n.Lexpr) + buf.WriteString(" ") + } + buf.astFormat(n.Name) + if set(n.Rexpr) { + buf.WriteString(" ") + buf.astFormat(n.Rexpr) + } + } default: + // Fallback for other cases + if set(n.Lexpr) { + buf.astFormat(n.Lexpr) + buf.WriteString(" ") + } buf.astFormat(n.Name) - buf.WriteString(" ") - buf.astFormat(n.Rexpr) + if set(n.Rexpr) { + buf.WriteString(" ") + buf.astFormat(n.Rexpr) + } } } diff --git a/internal/sql/ast/a_expr_kind.go b/internal/sql/ast/a_expr_kind.go index 53a237896b..3adc9232cf 100644 --- a/internal/sql/ast/a_expr_kind.go +++ b/internal/sql/ast/a_expr_kind.go @@ -3,8 +3,20 @@ package ast type A_Expr_Kind uint const ( - A_Expr_Kind_IN A_Expr_Kind = 7 - A_Expr_Kind_LIKE A_Expr_Kind = 8 + A_Expr_Kind_OP A_Expr_Kind = 1 + A_Expr_Kind_OP_ANY A_Expr_Kind = 2 + A_Expr_Kind_OP_ALL A_Expr_Kind = 3 + A_Expr_Kind_DISTINCT A_Expr_Kind = 4 + A_Expr_Kind_NOT_DISTINCT A_Expr_Kind = 5 + A_Expr_Kind_NULLIF A_Expr_Kind = 6 + A_Expr_Kind_IN A_Expr_Kind = 7 + A_Expr_Kind_LIKE A_Expr_Kind = 8 + A_Expr_Kind_ILIKE A_Expr_Kind = 9 + A_Expr_Kind_SIMILAR A_Expr_Kind = 10 + A_Expr_Kind_BETWEEN A_Expr_Kind = 11 + A_Expr_Kind_NOT_BETWEEN A_Expr_Kind = 12 + A_Expr_Kind_BETWEEN_SYM A_Expr_Kind = 13 + A_Expr_Kind_NOT_BETWEEN_SYM A_Expr_Kind = 14 ) func (n *A_Expr_Kind) Pos() int { diff --git a/internal/sql/ast/a_indices.go b/internal/sql/ast/a_indices.go index 8972f3a556..a143ae6d05 100644 --- a/internal/sql/ast/a_indices.go +++ b/internal/sql/ast/a_indices.go @@ -9,3 +9,22 @@ type A_Indices struct { func (n *A_Indices) Pos() int { return 0 } + +func (n *A_Indices) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("[") + if n.IsSlice { + if set(n.Lidx) { + buf.astFormat(n.Lidx) + } + buf.WriteString(":") + if set(n.Uidx) { + buf.astFormat(n.Uidx) + } + } else { + buf.astFormat(n.Uidx) + } + buf.WriteString("]") +} diff --git a/internal/sql/ast/case_expr.go b/internal/sql/ast/case_expr.go index 1da54f0d78..1d19dbdeec 100644 --- a/internal/sql/ast/case_expr.go +++ b/internal/sql/ast/case_expr.go @@ -19,8 +19,14 @@ func (n *CaseExpr) Format(buf *TrackedBuffer) { return } buf.WriteString("CASE ") - buf.astFormat(n.Args) - buf.WriteString(" ELSE ") - buf.astFormat(n.Defresult) - buf.WriteString(" END ") + if set(n.Arg) { + buf.astFormat(n.Arg) + buf.WriteString(" ") + } + buf.join(n.Args, " ") + if set(n.Defresult) { + buf.WriteString(" ELSE ") + buf.astFormat(n.Defresult) + } + buf.WriteString(" END") } diff --git a/internal/sql/ast/column_def.go b/internal/sql/ast/column_def.go index f9504eefc7..cd8ba115fc 100644 --- a/internal/sql/ast/column_def.go +++ b/internal/sql/ast/column_def.go @@ -39,6 +39,11 @@ func (n *ColumnDef) Format(buf *TrackedBuffer) { buf.WriteString(n.Colname) buf.WriteString(" ") buf.astFormat(n.TypeName) + // Use IsArray from ColumnDef since TypeName.ArrayBounds may not be set + // (for type resolution compatibility) + if n.IsArray && !items(n.TypeName.ArrayBounds) { + buf.WriteString("[]") + } if n.PrimaryKey { buf.WriteString(" PRIMARY KEY") } else if n.IsNotNull { diff --git a/internal/sql/ast/column_ref.go b/internal/sql/ast/column_ref.go index e95b844896..97ea3ab20a 100644 --- a/internal/sql/ast/column_ref.go +++ b/internal/sql/ast/column_ref.go @@ -24,11 +24,7 @@ func (n *ColumnRef) Format(buf *TrackedBuffer) { for _, item := range n.Fields.Items { switch nn := item.(type) { case *String: - if nn.Str == "user" { - items = append(items, `"user"`) - } else { - items = append(items, nn.Str) - } + items = append(items, buf.QuoteIdent(nn.Str)) case *A_Star: items = append(items, "*") } diff --git a/internal/sql/ast/common_table_expr.go b/internal/sql/ast/common_table_expr.go index f2edddff79..b36b3f23d3 100644 --- a/internal/sql/ast/common_table_expr.go +++ b/internal/sql/ast/common_table_expr.go @@ -1,9 +1,5 @@ package ast -import ( - "fmt" -) - type CommonTableExpr struct { Ctename *string Aliascolnames *List @@ -26,8 +22,14 @@ func (n *CommonTableExpr) Format(buf *TrackedBuffer) { return } if n.Ctename != nil { - fmt.Fprintf(buf, " %s AS (", *n.Ctename) + buf.WriteString(*n.Ctename) + } + if items(n.Aliascolnames) { + buf.WriteString("(") + buf.join(n.Aliascolnames, ", ") + buf.WriteString(")") } + buf.WriteString(" AS (") buf.astFormat(n.Ctequery) buf.WriteString(")") } diff --git a/internal/sql/ast/create_extension_stmt.go b/internal/sql/ast/create_extension_stmt.go index 2fe8755b6a..cd12e7505b 100644 --- a/internal/sql/ast/create_extension_stmt.go +++ b/internal/sql/ast/create_extension_stmt.go @@ -9,3 +9,16 @@ type CreateExtensionStmt struct { func (n *CreateExtensionStmt) Pos() int { return 0 } + +func (n *CreateExtensionStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("CREATE EXTENSION ") + if n.IfNotExists { + buf.WriteString("IF NOT EXISTS ") + } + if n.Extname != nil { + buf.WriteString(*n.Extname) + } +} diff --git a/internal/sql/ast/create_function_stmt.go b/internal/sql/ast/create_function_stmt.go index 86605344f7..e070a8720b 100644 --- a/internal/sql/ast/create_function_stmt.go +++ b/internal/sql/ast/create_function_stmt.go @@ -13,3 +13,31 @@ type CreateFunctionStmt struct { func (n *CreateFunctionStmt) Pos() int { return 0 } + +func (n *CreateFunctionStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("CREATE ") + if n.Replace { + buf.WriteString("OR REPLACE ") + } + buf.WriteString("FUNCTION ") + buf.astFormat(n.Func) + buf.WriteString("(") + if items(n.Params) { + buf.join(n.Params, ", ") + } + buf.WriteString(")") + if n.ReturnType != nil { + buf.WriteString(" RETURNS ") + buf.astFormat(n.ReturnType) + } + // Format options (AS, LANGUAGE, etc.) + if items(n.Options) { + for _, opt := range n.Options.Items { + buf.WriteString(" ") + buf.astFormat(opt) + } + } +} diff --git a/internal/sql/ast/def_elem.go b/internal/sql/ast/def_elem.go index 03ecf88e77..d70090339d 100644 --- a/internal/sql/ast/def_elem.go +++ b/internal/sql/ast/def_elem.go @@ -11,3 +11,56 @@ type DefElem struct { func (n *DefElem) Pos() int { return n.Location } + +func (n *DefElem) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Defname != nil { + switch *n.Defname { + case "as": + buf.WriteString("AS ") + // AS clause contains function body which needs quoting + if l, ok := n.Arg.(*List); ok { + for i, item := range l.Items { + if i > 0 { + buf.WriteString(", ") + } + if s, ok := item.(*String); ok { + buf.WriteString("'") + buf.WriteString(s.Str) + buf.WriteString("'") + } else { + buf.astFormat(item) + } + } + } else { + buf.astFormat(n.Arg) + } + case "language": + buf.WriteString("LANGUAGE ") + buf.astFormat(n.Arg) + case "volatility": + // VOLATILE, STABLE, IMMUTABLE + buf.astFormat(n.Arg) + case "strict": + if s, ok := n.Arg.(*Boolean); ok && s.Boolval { + buf.WriteString("STRICT") + } else { + buf.WriteString("CALLED ON NULL INPUT") + } + case "security": + if s, ok := n.Arg.(*Boolean); ok && s.Boolval { + buf.WriteString("SECURITY DEFINER") + } else { + buf.WriteString("SECURITY INVOKER") + } + default: + buf.WriteString(*n.Defname) + if n.Arg != nil { + buf.WriteString(" ") + buf.astFormat(n.Arg) + } + } + } +} diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index d77f043a12..45c2621095 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -28,6 +28,11 @@ func (n *DeleteStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.Relations) } + if items(n.UsingClause) { + buf.WriteString(" USING ") + buf.join(n.UsingClause, ", ") + } + if set(n.WhereClause) { buf.WriteString(" WHERE ") buf.astFormat(n.WhereClause) diff --git a/internal/sql/ast/do_stmt.go b/internal/sql/ast/do_stmt.go index edc831f15c..a14ddfd537 100644 --- a/internal/sql/ast/do_stmt.go +++ b/internal/sql/ast/do_stmt.go @@ -7,3 +7,22 @@ type DoStmt struct { func (n *DoStmt) Pos() int { return 0 } + +func (n *DoStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("DO ") + // Find the "as" argument which contains the body + if items(n.Args) { + for _, arg := range n.Args.Items { + if de, ok := arg.(*DefElem); ok && de.Defname != nil && *de.Defname == "as" { + if s, ok := de.Arg.(*String); ok { + buf.WriteString("$$") + buf.WriteString(s.Str) + buf.WriteString("$$") + } + } + } + } +} diff --git a/internal/sql/ast/func_call.go b/internal/sql/ast/func_call.go index 2bfe961b50..3b7dcc5400 100644 --- a/internal/sql/ast/func_call.go +++ b/internal/sql/ast/func_call.go @@ -24,10 +24,33 @@ func (n *FuncCall) Format(buf *TrackedBuffer) { } buf.astFormat(n.Func) buf.WriteString("(") + if n.AggDistinct { + buf.WriteString("DISTINCT ") + } if n.AggStar { buf.WriteString("*") } else { buf.astFormat(n.Args) } + // ORDER BY inside function call (not WITHIN GROUP) + if items(n.AggOrder) && !n.AggWithinGroup { + buf.WriteString(" ORDER BY ") + buf.join(n.AggOrder, ", ") + } buf.WriteString(")") + // WITHIN GROUP clause for ordered-set aggregates + if items(n.AggOrder) && n.AggWithinGroup { + buf.WriteString(" WITHIN GROUP (ORDER BY ") + buf.join(n.AggOrder, ", ") + buf.WriteString(")") + } + if set(n.AggFilter) { + buf.WriteString(" FILTER (WHERE ") + buf.astFormat(n.AggFilter) + buf.WriteString(")") + } + if n.Over != nil { + buf.WriteString(" OVER ") + buf.astFormat(n.Over) + } } diff --git a/internal/sql/ast/func_param.go b/internal/sql/ast/func_param.go index b5cf8cfcf0..812d9c629a 100644 --- a/internal/sql/ast/func_param.go +++ b/internal/sql/ast/func_param.go @@ -21,3 +21,25 @@ type FuncParam struct { func (n *FuncParam) Pos() int { return 0 } + +func (n *FuncParam) Format(buf *TrackedBuffer) { + if n == nil { + return + } + // Parameter mode prefix (OUT, INOUT, VARIADIC) + switch n.Mode { + case FuncParamOut: + buf.WriteString("OUT ") + case FuncParamInOut: + buf.WriteString("INOUT ") + case FuncParamVariadic: + buf.WriteString("VARIADIC ") + } + // Parameter name (if present) + if n.Name != nil { + buf.WriteString(*n.Name) + buf.WriteString(" ") + } + // Parameter type + buf.astFormat(n.Type) +} diff --git a/internal/sql/ast/index_elem.go b/internal/sql/ast/index_elem.go index 52ac09688b..d1400699ee 100644 --- a/internal/sql/ast/index_elem.go +++ b/internal/sql/ast/index_elem.go @@ -13,3 +13,14 @@ type IndexElem struct { func (n *IndexElem) Pos() int { return 0 } + +func (n *IndexElem) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Name != nil && *n.Name != "" { + buf.WriteString(*n.Name) + } else if set(n.Expr) { + buf.astFormat(n.Expr) + } +} diff --git a/internal/sql/ast/infer_clause.go b/internal/sql/ast/infer_clause.go index 1e1d93c3d8..ff3855cae5 100644 --- a/internal/sql/ast/infer_clause.go +++ b/internal/sql/ast/infer_clause.go @@ -10,3 +10,21 @@ type InferClause struct { func (n *InferClause) Pos() int { return n.Location } + +func (n *InferClause) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Conname != nil && *n.Conname != "" { + buf.WriteString("ON CONSTRAINT ") + buf.WriteString(*n.Conname) + } else if items(n.IndexElems) { + buf.WriteString("(") + buf.join(n.IndexElems, ", ") + buf.WriteString(")") + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause) + } + } +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 3cdf854091..cbf480b187 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -31,15 +31,17 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { if items(n.Cols) { buf.WriteString(" (") buf.astFormat(n.Cols) - buf.WriteString(") ") + buf.WriteString(")") } if set(n.SelectStmt) { + buf.WriteString(" ") buf.astFormat(n.SelectStmt) } if n.OnConflictClause != nil { - buf.WriteString(" ON CONFLICT DO NOTHING ") + buf.WriteString(" ") + buf.astFormat(n.OnConflictClause) } if items(n.ReturningList) { diff --git a/internal/sql/ast/join_expr.go b/internal/sql/ast/join_expr.go index e316869560..69c3089b1b 100644 --- a/internal/sql/ast/join_expr.go +++ b/internal/sql/ast/join_expr.go @@ -20,23 +20,33 @@ func (n *JoinExpr) Format(buf *TrackedBuffer) { return } buf.astFormat(n.Larg) + if n.IsNatural { + buf.WriteString(" NATURAL") + } switch n.Jointype { case JoinTypeLeft: buf.WriteString(" LEFT JOIN ") + case JoinTypeRight: + buf.WriteString(" RIGHT JOIN ") + case JoinTypeFull: + buf.WriteString(" FULL JOIN ") case JoinTypeInner: - buf.WriteString(" INNER JOIN ") + // CROSS JOIN has no ON or USING clause + if !items(n.UsingClause) && !set(n.Quals) { + buf.WriteString(" CROSS JOIN ") + } else { + buf.WriteString(" JOIN ") + } default: buf.WriteString(" JOIN ") } buf.astFormat(n.Rarg) - buf.WriteString(" ON ") - if n.Jointype == JoinTypeInner { - if set(n.Quals) { - buf.astFormat(n.Quals) - } else { - buf.WriteString("TRUE") - } - } else { + if items(n.UsingClause) { + buf.WriteString(" USING (") + buf.join(n.UsingClause, ", ") + buf.WriteString(")") + } else if set(n.Quals) { + buf.WriteString(" ON ") buf.astFormat(n.Quals) } } diff --git a/internal/sql/ast/locking_clause.go b/internal/sql/ast/locking_clause.go index 11a9159de2..286d726edd 100644 --- a/internal/sql/ast/locking_clause.go +++ b/internal/sql/ast/locking_clause.go @@ -10,15 +10,46 @@ func (n *LockingClause) Pos() int { return 0 } +// LockClauseStrength values (matching pg_query_go) +const ( + LockClauseStrengthUndefined LockClauseStrength = 0 + LockClauseStrengthNone LockClauseStrength = 1 + LockClauseStrengthForKeyShare LockClauseStrength = 2 + LockClauseStrengthForShare LockClauseStrength = 3 + LockClauseStrengthForNoKeyUpdate LockClauseStrength = 4 + LockClauseStrengthForUpdate LockClauseStrength = 5 +) + +// LockWaitPolicy values +const ( + LockWaitPolicyBlock LockWaitPolicy = 1 + LockWaitPolicySkip LockWaitPolicy = 2 + LockWaitPolicyError LockWaitPolicy = 3 +) + func (n *LockingClause) Format(buf *TrackedBuffer) { if n == nil { return } buf.WriteString("FOR ") switch n.Strength { - case 3: + case LockClauseStrengthForKeyShare: + buf.WriteString("KEY SHARE") + case LockClauseStrengthForShare: buf.WriteString("SHARE") - case 5: + case LockClauseStrengthForNoKeyUpdate: + buf.WriteString("NO KEY UPDATE") + case LockClauseStrengthForUpdate: buf.WriteString("UPDATE") } + if items(n.LockedRels) { + buf.WriteString(" OF ") + buf.join(n.LockedRels, ", ") + } + switch n.WaitPolicy { + case LockWaitPolicySkip: + buf.WriteString(" SKIP LOCKED") + case LockWaitPolicyError: + buf.WriteString(" NOWAIT") + } } diff --git a/internal/sql/ast/null_test_expr.go b/internal/sql/ast/null_test_expr.go index 51fd37f6bb..42059bca6e 100644 --- a/internal/sql/ast/null_test_expr.go +++ b/internal/sql/ast/null_test_expr.go @@ -11,3 +11,22 @@ type NullTest struct { func (n *NullTest) Pos() int { return n.Location } + +// NullTestType values +const ( + NullTestTypeIsNull NullTestType = 1 + NullTestTypeIsNotNull NullTestType = 2 +) + +func (n *NullTest) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Arg) + switch n.Nulltesttype { + case NullTestTypeIsNull: + buf.WriteString(" IS NULL") + case NullTestTypeIsNotNull: + buf.WriteString(" IS NOT NULL") + } +} diff --git a/internal/sql/ast/on_conflict_clause.go b/internal/sql/ast/on_conflict_clause.go index 25333d6d59..055532fb3c 100644 --- a/internal/sql/ast/on_conflict_clause.go +++ b/internal/sql/ast/on_conflict_clause.go @@ -11,3 +11,49 @@ type OnConflictClause struct { func (n *OnConflictClause) Pos() int { return n.Location } + +// OnConflictAction values matching pg_query_go +const ( + OnConflictActionUndefined OnConflictAction = 0 + OnConflictActionNone OnConflictAction = 1 + OnConflictActionNothing OnConflictAction = 2 + OnConflictActionUpdate OnConflictAction = 3 +) + +func (n *OnConflictClause) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("ON CONFLICT ") + if n.Infer != nil { + buf.astFormat(n.Infer) + buf.WriteString(" ") + } + switch n.Action { + case OnConflictActionNothing: + buf.WriteString("DO NOTHING") + case OnConflictActionUpdate: + buf.WriteString("DO UPDATE SET ") + // Format as assignment list: name = val + if n.TargetList != nil { + for i, item := range n.TargetList.Items { + if i > 0 { + buf.WriteString(", ") + } + if rt, ok := item.(*ResTarget); ok { + if rt.Name != nil { + buf.WriteString(*rt.Name) + } + buf.WriteString(" = ") + buf.astFormat(rt.Val) + } else { + buf.astFormat(item) + } + } + } + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause) + } + } +} diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go index 867a53a177..8db19ba7d1 100644 --- a/internal/sql/ast/print.go +++ b/internal/sql/ast/print.go @@ -4,26 +4,50 @@ import ( "strings" "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/sql/format" ) -type formatter interface { +type nodeFormatter interface { Format(*TrackedBuffer) } type TrackedBuffer struct { *strings.Builder + formatter format.Formatter } -// NewTrackedBuffer creates a new TrackedBuffer. -func NewTrackedBuffer() *TrackedBuffer { +// NewTrackedBuffer creates a new TrackedBuffer with the given formatter. +func NewTrackedBuffer(f format.Formatter) *TrackedBuffer { buf := &TrackedBuffer{ - Builder: new(strings.Builder), + Builder: new(strings.Builder), + formatter: f, } return buf } +// QuoteIdent returns a quoted identifier if it needs quoting. +// If no formatter is set, it returns the identifier unchanged. +func (t *TrackedBuffer) QuoteIdent(s string) string { + if t.formatter != nil { + return t.formatter.QuoteIdent(s) + } + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +// If no formatter is set, it returns "ns.name" or just "name". +func (t *TrackedBuffer) TypeName(ns, name string) string { + if t.formatter != nil { + return t.formatter.TypeName(ns, name) + } + if ns != "" { + return ns + "." + name + } + return name +} + func (t *TrackedBuffer) astFormat(n Node) { - if ft, ok := n.(formatter); ok { + if ft, ok := n.(nodeFormatter); ok { ft.Format(t) } else { debug.Dump(n) @@ -45,9 +69,9 @@ func (t *TrackedBuffer) join(n *List, sep string) { } } -func Format(n Node) string { - tb := NewTrackedBuffer() - if ft, ok := n.(formatter); ok { +func Format(n Node, f format.Formatter) string { + tb := NewTrackedBuffer(f) + if ft, ok := n.(nodeFormatter); ok { ft.Format(tb) } return tb.String() diff --git a/internal/sql/ast/range_function.go b/internal/sql/ast/range_function.go index 299078d481..6a95388fd1 100644 --- a/internal/sql/ast/range_function.go +++ b/internal/sql/ast/range_function.go @@ -17,9 +17,15 @@ func (n *RangeFunction) Format(buf *TrackedBuffer) { if n == nil { return } + if n.Lateral { + buf.WriteString("LATERAL ") + } buf.astFormat(n.Functions) if n.Ordinality { - buf.WriteString(" WITH ORDINALITY ") + buf.WriteString(" WITH ORDINALITY") + } + if n.Alias != nil { + buf.WriteString(" AS ") + buf.astFormat(n.Alias) } - buf.astFormat(n.Alias) } diff --git a/internal/sql/ast/range_subselect.go b/internal/sql/ast/range_subselect.go index 1506ee7994..a5d63235d3 100644 --- a/internal/sql/ast/range_subselect.go +++ b/internal/sql/ast/range_subselect.go @@ -14,11 +14,14 @@ func (n *RangeSubselect) Format(buf *TrackedBuffer) { if n == nil { return } + if n.Lateral { + buf.WriteString("LATERAL ") + } buf.WriteString("(") buf.astFormat(n.Subquery) buf.WriteString(")") if n.Alias != nil { - buf.WriteString(" ") + buf.WriteString(" AS ") buf.astFormat(n.Alias) } } diff --git a/internal/sql/ast/range_var.go b/internal/sql/ast/range_var.go index 1d1656f6c0..b7fb316ee9 100644 --- a/internal/sql/ast/range_var.go +++ b/internal/sql/ast/range_var.go @@ -19,18 +19,11 @@ func (n *RangeVar) Format(buf *TrackedBuffer) { return } if n.Schemaname != nil { - buf.WriteString(*n.Schemaname) + buf.WriteString(buf.QuoteIdent(*n.Schemaname)) buf.WriteString(".") } if n.Relname != nil { - // TODO: What names need to be quoted - if *n.Relname == "user" { - buf.WriteString(`"`) - buf.WriteString(*n.Relname) - buf.WriteString(`"`) - } else { - buf.WriteString(*n.Relname) - } + buf.WriteString(buf.QuoteIdent(*n.Relname)) } if n.Alias != nil { buf.WriteString(" ") diff --git a/internal/sql/ast/res_target.go b/internal/sql/ast/res_target.go index 4ee2e72112..b652c2293e 100644 --- a/internal/sql/ast/res_target.go +++ b/internal/sql/ast/res_target.go @@ -19,11 +19,11 @@ func (n *ResTarget) Format(buf *TrackedBuffer) { buf.astFormat(n.Val) if n.Name != nil { buf.WriteString(" AS ") - buf.WriteString(*n.Name) + buf.WriteString(buf.QuoteIdent(*n.Name)) } } else { if n.Name != nil { - buf.WriteString(*n.Name) + buf.WriteString(buf.QuoteIdent(*n.Name)) } } } diff --git a/internal/sql/ast/scalar_array_op_expr.go b/internal/sql/ast/scalar_array_op_expr.go index fc438c10b3..f887bf6508 100644 --- a/internal/sql/ast/scalar_array_op_expr.go +++ b/internal/sql/ast/scalar_array_op_expr.go @@ -12,3 +12,22 @@ type ScalarArrayOpExpr struct { func (n *ScalarArrayOpExpr) Pos() int { return n.Location } + +func (n *ScalarArrayOpExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + // ScalarArrayOpExpr represents "scalar op ANY/ALL (array)" + // Args[0] is the left operand, Args[1] is the array + if n.Args != nil && len(n.Args.Items) >= 2 { + buf.astFormat(n.Args.Items[0]) + buf.WriteString(" = ") // TODO: Use actual operator based on Opno + if n.UseOr { + buf.WriteString("ANY(") + } else { + buf.WriteString("ALL(") + } + buf.astFormat(n.Args.Items[1]) + buf.WriteString(")") + } +} diff --git a/internal/sql/ast/select_stmt.go b/internal/sql/ast/select_stmt.go index 051dd5c8c5..a0f0fd4f43 100644 --- a/internal/sql/ast/select_stmt.go +++ b/internal/sql/ast/select_stmt.go @@ -89,6 +89,11 @@ func (n *SelectStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.GroupClause) } + if set(n.HavingClause) { + buf.WriteString(" HAVING ") + buf.astFormat(n.HavingClause) + } + if items(n.SortClause) { buf.WriteString(" ORDER BY ") buf.astFormat(n.SortClause) diff --git a/internal/sql/ast/sort_by.go b/internal/sql/ast/sort_by.go index 21a7a079aa..6d43f541a1 100644 --- a/internal/sql/ast/sort_by.go +++ b/internal/sql/ast/sort_by.go @@ -23,4 +23,10 @@ func (n *SortBy) Format(buf *TrackedBuffer) { case SortByDirDesc: buf.WriteString(" DESC") } + switch n.SortbyNulls { + case SortByNullsFirst: + buf.WriteString(" NULLS FIRST") + case SortByNullsLast: + buf.WriteString(" NULLS LAST") + } } diff --git a/internal/sql/ast/type_name.go b/internal/sql/ast/type_name.go index e26404b3ba..5979d7a90d 100644 --- a/internal/sql/ast/type_name.go +++ b/internal/sql/ast/type_name.go @@ -25,13 +25,32 @@ func (n *TypeName) Format(buf *TrackedBuffer) { return } if items(n.Names) { + // Check if this is a qualified type (e.g., pg_catalog.int4) + if len(n.Names.Items) == 2 { + first, _ := n.Names.Items[0].(*String) + second, _ := n.Names.Items[1].(*String) + if first != nil && second != nil { + buf.WriteString(buf.TypeName(first.Str, second.Str)) + goto addMods + } + } + // For single name types, just output as-is + if len(n.Names.Items) == 1 { + if s, ok := n.Names.Items[0].(*String); ok { + buf.WriteString(buf.TypeName("", s.Str)) + goto addMods + } + } buf.join(n.Names, ".") } else { - if n.Name == "int4" { - buf.WriteString("INTEGER") - } else { - buf.WriteString(n.Name) - } + buf.WriteString(buf.TypeName(n.Schema, n.Name)) + } +addMods: + // Add type modifiers (e.g., varchar(255)) + if items(n.Typmods) { + buf.WriteString("(") + buf.join(n.Typmods, ", ") + buf.WriteString(")") } if items(n.ArrayBounds) { buf.WriteString("[]") diff --git a/internal/sql/ast/typedefs.go b/internal/sql/ast/typedefs.go index 351008e841..46b0e66120 100644 --- a/internal/sql/ast/typedefs.go +++ b/internal/sql/ast/typedefs.go @@ -18,6 +18,15 @@ func (n *NullIfExpr) Pos() int { return 0 } +func (n *NullIfExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("NULLIF(") + buf.join(n.Args, ", ") + buf.WriteString(")") +} + type Selectivity float64 func (n *Selectivity) Pos() int { diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index efd496ad75..c98d422130 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -79,7 +79,13 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { switch nn := item.(type) { case *ResTarget: if nn.Name != nil { - buf.WriteString(*nn.Name) + buf.WriteString(buf.QuoteIdent(*nn.Name)) + } + // Handle array subscript indirection (e.g., names[$1]) + if items(nn.Indirection) { + for _, ind := range nn.Indirection.Items { + buf.astFormat(ind) + } } buf.WriteString(" = ") buf.astFormat(nn.Val) diff --git a/internal/sql/ast/window_def.go b/internal/sql/ast/window_def.go index 29840767c9..7e9db4aeef 100644 --- a/internal/sql/ast/window_def.go +++ b/internal/sql/ast/window_def.go @@ -14,3 +14,99 @@ type WindowDef struct { func (n *WindowDef) Pos() int { return n.Location } + +// Frame option constants (from PostgreSQL's parsenodes.h) +const ( + FrameOptionNonDefault = 0x00001 + FrameOptionRange = 0x00002 + FrameOptionRows = 0x00004 + FrameOptionGroups = 0x00008 + FrameOptionBetween = 0x00010 + FrameOptionStartUnboundedPreceding = 0x00020 + FrameOptionEndUnboundedPreceding = 0x00040 + FrameOptionStartUnboundedFollowing = 0x00080 + FrameOptionEndUnboundedFollowing = 0x00100 + FrameOptionStartCurrentRow = 0x00200 + FrameOptionEndCurrentRow = 0x00400 + FrameOptionStartOffset = 0x00800 + FrameOptionEndOffset = 0x01000 + FrameOptionExcludeCurrentRow = 0x02000 + FrameOptionExcludeGroup = 0x04000 + FrameOptionExcludeTies = 0x08000 +) + +func (n *WindowDef) Format(buf *TrackedBuffer) { + if n == nil { + return + } + + // Named window reference + if n.Refname != nil && *n.Refname != "" { + buf.WriteString(*n.Refname) + return + } + + buf.WriteString("(") + needSpace := false + + if items(n.PartitionClause) { + buf.WriteString("PARTITION BY ") + buf.join(n.PartitionClause, ", ") + needSpace = true + } + + if items(n.OrderClause) { + if needSpace { + buf.WriteString(" ") + } + buf.WriteString("ORDER BY ") + buf.join(n.OrderClause, ", ") + needSpace = true + } + + // Frame clause + if n.FrameOptions&FrameOptionNonDefault != 0 { + if needSpace { + buf.WriteString(" ") + } + + // Frame type + if n.FrameOptions&FrameOptionRows != 0 { + buf.WriteString("ROWS ") + } else if n.FrameOptions&FrameOptionRange != 0 { + buf.WriteString("RANGE ") + } else if n.FrameOptions&FrameOptionGroups != 0 { + buf.WriteString("GROUPS ") + } + + if n.FrameOptions&FrameOptionBetween != 0 { + buf.WriteString("BETWEEN ") + } + + // Start bound + if n.FrameOptions&FrameOptionStartUnboundedPreceding != 0 { + buf.WriteString("UNBOUNDED PRECEDING") + } else if n.FrameOptions&FrameOptionStartCurrentRow != 0 { + buf.WriteString("CURRENT ROW") + } else if n.FrameOptions&FrameOptionStartOffset != 0 { + buf.astFormat(n.StartOffset) + buf.WriteString(" PRECEDING") + } + + if n.FrameOptions&FrameOptionBetween != 0 { + buf.WriteString(" AND ") + + // End bound + if n.FrameOptions&FrameOptionEndUnboundedFollowing != 0 { + buf.WriteString("UNBOUNDED FOLLOWING") + } else if n.FrameOptions&FrameOptionEndCurrentRow != 0 { + buf.WriteString("CURRENT ROW") + } else if n.FrameOptions&FrameOptionEndOffset != 0 { + buf.astFormat(n.EndOffset) + buf.WriteString(" FOLLOWING") + } + } + } + + buf.WriteString(")") +} diff --git a/internal/sql/ast/with_clause.go b/internal/sql/ast/with_clause.go index 634326fa7e..86c53fb544 100644 --- a/internal/sql/ast/with_clause.go +++ b/internal/sql/ast/with_clause.go @@ -14,9 +14,9 @@ func (n *WithClause) Format(buf *TrackedBuffer) { if n == nil { return } - buf.WriteString("WITH") + buf.WriteString("WITH ") if n.Recursive { - buf.WriteString(" RECURSIVE") + buf.WriteString("RECURSIVE ") } - buf.astFormat(n.Ctes) + buf.join(n.Ctes, ", ") } diff --git a/internal/sql/format/format.go b/internal/sql/format/format.go new file mode 100644 index 0000000000..f47587dd0b --- /dev/null +++ b/internal/sql/format/format.go @@ -0,0 +1,12 @@ +package format + +// Formatter provides SQL dialect-specific formatting behavior +type Formatter interface { + // QuoteIdent returns a quoted identifier if it needs quoting + // (e.g., reserved words, mixed case identifiers) + QuoteIdent(s string) string + + // TypeName returns the SQL type name for the given namespace and name. + // This handles dialect-specific type name mappings (e.g., pg_catalog.int4 -> integer) + TypeName(ns, name string) string +}