From ab9ce9efc463661ef129aa28d3b64555e77c5b7b Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 29 Nov 2025 22:30:32 -0800 Subject: [PATCH 1/9] refactor(fmt_test): use config-based engine detection and parser for statement boundaries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Parse sqlc config file to determine database engine instead of hardcoding pgx/v5 path filter - Use parser's StmtLocation/StmtLen for proper statement boundaries instead of naive semicolon splitting - Handle both file and directory paths in queries config - Only test PostgreSQL for now (formatting support is PostgreSQL-only) This fixes issues with multi-query files containing semicolons in strings, PL/pgSQL functions, or DO blocks. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/endtoend/fmt_test.go | 140 +++++++++++++++++++++++----------- 1 file changed, 97 insertions(+), 43 deletions(-) diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 04e753e5b7..5650172b90 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" "github.com/sqlc-dev/sqlc/internal/sql/ast" @@ -15,59 +16,112 @@ import ( func TestFormat(t *testing.T) { t.Parallel() - parse := postgresql.NewParser() for _, tc := range FindTests(t, "testdata", "base") { tc := tc - - if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) { - continue - } - - q := filepath.Join(tc.Path, "query.sql") - if _, err := os.Stat(q); os.IsNotExist(err) { - continue - } - t.Run(tc.Name, func(t *testing.T) { - contents, err := os.ReadFile(q) + // Parse the config file to determine the engine + configPath := filepath.Join(tc.Path, tc.ConfigName) + configFile, err := os.Open(configPath) if err != nil { t.Fatal(err) } - for i, query := range bytes.Split(bytes.TrimSpace(contents), []byte(";")) { - if len(query) <= 1 { - continue - } - query := query - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - expected, err := postgresql.Fingerprint(string(query)) - if err != nil { - t.Fatal(err) - } - stmts, err := parse.Parse(bytes.NewReader(query)) - if err != nil { - t.Fatal(err) - } - if len(stmts) != 1 { - t.Fatal("expected one statement") - } - if false { - r, err := postgresql.Parse(string(query)) - debug.Dump(r, err) - } + conf, err := config.ParseConfig(configFile) + configFile.Close() + if err != nil { + t.Fatal(err) + } + + // Skip if there are no SQL packages configured + if len(conf.SQL) == 0 { + return + } + + // For now, only test PostgreSQL since that's the only engine with Format support + engine := conf.SQL[0].Engine + if engine != config.EnginePostgreSQL { + return + } - out := ast.Format(stmts[0].Raw) - actual, err := postgresql.Fingerprint(out) + // Find query files from config + var queryFiles []string + for _, sql := range conf.SQL { + for _, q := range sql.Queries { + queryPath := filepath.Join(tc.Path, q) + info, err := os.Stat(queryPath) if err != nil { - t.Error(err) + continue } - if expected != actual { - debug.Dump(stmts[0].Raw) - t.Errorf("- %s", expected) - t.Errorf("- %s", string(query)) - t.Errorf("+ %s", actual) - t.Errorf("+ %s", out) + if info.IsDir() { + // If it's a directory, glob for .sql files + matches, err := filepath.Glob(filepath.Join(queryPath, "*.sql")) + if err != nil { + continue + } + queryFiles = append(queryFiles, matches...) + } else { + queryFiles = append(queryFiles, queryPath) } - }) + } + } + + if len(queryFiles) == 0 { + return + } + + parse := postgresql.NewParser() + + for _, queryFile := range queryFiles { + if _, err := os.Stat(queryFile); os.IsNotExist(err) { + continue + } + + contents, err := os.ReadFile(queryFile) + if err != nil { + t.Fatal(err) + } + + // Parse the entire file to get proper statement boundaries + stmts, err := parse.Parse(bytes.NewReader(contents)) + if err != nil { + t.Fatal(err) + } + + for i, stmt := range stmts { + stmt := stmt + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + // Extract the original query text using statement location and length + start := stmt.Raw.StmtLocation + length := stmt.Raw.StmtLen + if length == 0 { + // If StmtLen is 0, it means the statement goes to the end of the input + length = len(contents) - start + } + query := strings.TrimSpace(string(contents[start : start+length])) + + expected, err := postgresql.Fingerprint(query) + if err != nil { + t.Fatal(err) + } + + if false { + r, err := postgresql.Parse(query) + debug.Dump(r, err) + } + + out := ast.Format(stmt.Raw) + actual, err := postgresql.Fingerprint(out) + if err != nil { + t.Error(err) + } + if expected != actual { + debug.Dump(stmt.Raw) + t.Errorf("- %s", expected) + t.Errorf("- %s", query) + t.Errorf("+ %s", actual) + t.Errorf("+ %s", out) + } + }) + } } }) } From 7731b596a3707df5797080201d58419a331f0c72 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 29 Nov 2025 22:44:44 -0800 Subject: [PATCH 2/9] feat(ast): add and improve Format methods for SQL AST nodes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Format methods: - A_ArrayExpr: Format ARRAY[...] literals - NullIfExpr: Format NULLIF(arg1, arg2) function calls - OnConflictClause: Format ON CONFLICT ... DO UPDATE/NOTHING - InferClause: Format conflict target (columns) or ON CONSTRAINT - IndexElem: Format index elements for conflict targets - WindowDef: Format window definitions with PARTITION BY, ORDER BY, and frame clauses Improve existing Format methods: - A_Expr: Add BETWEEN, NOT BETWEEN, ILIKE, SIMILAR TO, IS DISTINCT FROM handling - A_Expr_Kind: Add all expression kind constants - CaseExpr: Handle CASE with test argument and optional ELSE - DeleteStmt: Add USING clause formatting - FuncCall: Add DISTINCT, ORDER BY, FILTER, and OVER clause support - InsertStmt: Delegate to OnConflictClause.Format - JoinExpr: Add RIGHT JOIN, FULL JOIN, NATURAL, and USING clause - LockingClause: Add OF clause, SKIP LOCKED, NOWAIT, and fix strength values - RangeFunction: Add LATERAL support and fix alias spacing - SelectStmt: Add HAVING clause formatting These changes reduce test failures from 135 to 102. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/sql/ast/a_array_expr.go | 9 +++ internal/sql/ast/a_expr.go | 58 ++++++++++++++-- internal/sql/ast/a_expr_kind.go | 16 ++++- internal/sql/ast/case_expr.go | 14 ++-- internal/sql/ast/delete_stmt.go | 5 ++ internal/sql/ast/func_call.go | 16 +++++ internal/sql/ast/index_elem.go | 11 +++ internal/sql/ast/infer_clause.go | 18 +++++ internal/sql/ast/insert_stmt.go | 3 +- internal/sql/ast/join_expr.go | 23 +++--- internal/sql/ast/locking_clause.go | 35 +++++++++- internal/sql/ast/on_conflict_clause.go | 46 ++++++++++++ internal/sql/ast/range_function.go | 10 ++- internal/sql/ast/select_stmt.go | 5 ++ internal/sql/ast/typedefs.go | 9 +++ internal/sql/ast/window_def.go | 96 ++++++++++++++++++++++++++ 16 files changed, 350 insertions(+), 24 deletions(-) diff --git a/internal/sql/ast/a_array_expr.go b/internal/sql/ast/a_array_expr.go index dafa0e8e85..970e95deb1 100644 --- a/internal/sql/ast/a_array_expr.go +++ b/internal/sql/ast/a_array_expr.go @@ -8,3 +8,12 @@ type A_ArrayExpr struct { func (n *A_ArrayExpr) Pos() int { return n.Location } + +func (n *A_ArrayExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("ARRAY[") + buf.join(n.Elements, ", ") + buf.WriteString("]") +} diff --git a/internal/sql/ast/a_expr.go b/internal/sql/ast/a_expr.go index b0b7f75367..37484d4ce8 100644 --- a/internal/sql/ast/a_expr.go +++ b/internal/sql/ast/a_expr.go @@ -16,19 +16,69 @@ func (n *A_Expr) Format(buf *TrackedBuffer) { if n == nil { return } - buf.astFormat(n.Lexpr) - buf.WriteString(" ") switch n.Kind { case A_Expr_Kind_IN: + buf.astFormat(n.Lexpr) buf.WriteString(" IN (") buf.astFormat(n.Rexpr) buf.WriteString(")") case A_Expr_Kind_LIKE: + buf.astFormat(n.Lexpr) buf.WriteString(" LIKE ") buf.astFormat(n.Rexpr) + case A_Expr_Kind_ILIKE: + buf.astFormat(n.Lexpr) + buf.WriteString(" ILIKE ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_SIMILAR: + buf.astFormat(n.Lexpr) + buf.WriteString(" SIMILAR TO ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_BETWEEN: + buf.astFormat(n.Lexpr) + buf.WriteString(" BETWEEN ") + if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { + buf.astFormat(l.Items[0]) + buf.WriteString(" AND ") + buf.astFormat(l.Items[1]) + } + case A_Expr_Kind_NOT_BETWEEN: + buf.astFormat(n.Lexpr) + buf.WriteString(" NOT BETWEEN ") + if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { + buf.astFormat(l.Items[0]) + buf.WriteString(" AND ") + buf.astFormat(l.Items[1]) + } + case A_Expr_Kind_DISTINCT: + buf.astFormat(n.Lexpr) + buf.WriteString(" IS DISTINCT FROM ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_NOT_DISTINCT: + buf.astFormat(n.Lexpr) + buf.WriteString(" IS NOT DISTINCT FROM ") + buf.astFormat(n.Rexpr) + case A_Expr_Kind_OP: + // Standard binary operator + if set(n.Lexpr) { + buf.astFormat(n.Lexpr) + buf.WriteString(" ") + } + buf.astFormat(n.Name) + if set(n.Rexpr) { + buf.WriteString(" ") + buf.astFormat(n.Rexpr) + } default: + // Fallback for other cases + if set(n.Lexpr) { + buf.astFormat(n.Lexpr) + buf.WriteString(" ") + } buf.astFormat(n.Name) - buf.WriteString(" ") - buf.astFormat(n.Rexpr) + if set(n.Rexpr) { + buf.WriteString(" ") + buf.astFormat(n.Rexpr) + } } } diff --git a/internal/sql/ast/a_expr_kind.go b/internal/sql/ast/a_expr_kind.go index 53a237896b..3adc9232cf 100644 --- a/internal/sql/ast/a_expr_kind.go +++ b/internal/sql/ast/a_expr_kind.go @@ -3,8 +3,20 @@ package ast type A_Expr_Kind uint const ( - A_Expr_Kind_IN A_Expr_Kind = 7 - A_Expr_Kind_LIKE A_Expr_Kind = 8 + A_Expr_Kind_OP A_Expr_Kind = 1 + A_Expr_Kind_OP_ANY A_Expr_Kind = 2 + A_Expr_Kind_OP_ALL A_Expr_Kind = 3 + A_Expr_Kind_DISTINCT A_Expr_Kind = 4 + A_Expr_Kind_NOT_DISTINCT A_Expr_Kind = 5 + A_Expr_Kind_NULLIF A_Expr_Kind = 6 + A_Expr_Kind_IN A_Expr_Kind = 7 + A_Expr_Kind_LIKE A_Expr_Kind = 8 + A_Expr_Kind_ILIKE A_Expr_Kind = 9 + A_Expr_Kind_SIMILAR A_Expr_Kind = 10 + A_Expr_Kind_BETWEEN A_Expr_Kind = 11 + A_Expr_Kind_NOT_BETWEEN A_Expr_Kind = 12 + A_Expr_Kind_BETWEEN_SYM A_Expr_Kind = 13 + A_Expr_Kind_NOT_BETWEEN_SYM A_Expr_Kind = 14 ) func (n *A_Expr_Kind) Pos() int { diff --git a/internal/sql/ast/case_expr.go b/internal/sql/ast/case_expr.go index 1da54f0d78..1d19dbdeec 100644 --- a/internal/sql/ast/case_expr.go +++ b/internal/sql/ast/case_expr.go @@ -19,8 +19,14 @@ func (n *CaseExpr) Format(buf *TrackedBuffer) { return } buf.WriteString("CASE ") - buf.astFormat(n.Args) - buf.WriteString(" ELSE ") - buf.astFormat(n.Defresult) - buf.WriteString(" END ") + if set(n.Arg) { + buf.astFormat(n.Arg) + buf.WriteString(" ") + } + buf.join(n.Args, " ") + if set(n.Defresult) { + buf.WriteString(" ELSE ") + buf.astFormat(n.Defresult) + } + buf.WriteString(" END") } diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index d77f043a12..45c2621095 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -28,6 +28,11 @@ func (n *DeleteStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.Relations) } + if items(n.UsingClause) { + buf.WriteString(" USING ") + buf.join(n.UsingClause, ", ") + } + if set(n.WhereClause) { buf.WriteString(" WHERE ") buf.astFormat(n.WhereClause) diff --git a/internal/sql/ast/func_call.go b/internal/sql/ast/func_call.go index 2bfe961b50..5eda3f027f 100644 --- a/internal/sql/ast/func_call.go +++ b/internal/sql/ast/func_call.go @@ -24,10 +24,26 @@ func (n *FuncCall) Format(buf *TrackedBuffer) { } buf.astFormat(n.Func) buf.WriteString("(") + if n.AggDistinct { + buf.WriteString("DISTINCT ") + } if n.AggStar { buf.WriteString("*") } else { buf.astFormat(n.Args) } + if items(n.AggOrder) { + buf.WriteString(" ORDER BY ") + buf.join(n.AggOrder, ", ") + } buf.WriteString(")") + if set(n.AggFilter) { + buf.WriteString(" FILTER (WHERE ") + buf.astFormat(n.AggFilter) + buf.WriteString(")") + } + if n.Over != nil { + buf.WriteString(" OVER ") + buf.astFormat(n.Over) + } } diff --git a/internal/sql/ast/index_elem.go b/internal/sql/ast/index_elem.go index 52ac09688b..d1400699ee 100644 --- a/internal/sql/ast/index_elem.go +++ b/internal/sql/ast/index_elem.go @@ -13,3 +13,14 @@ type IndexElem struct { func (n *IndexElem) Pos() int { return 0 } + +func (n *IndexElem) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Name != nil && *n.Name != "" { + buf.WriteString(*n.Name) + } else if set(n.Expr) { + buf.astFormat(n.Expr) + } +} diff --git a/internal/sql/ast/infer_clause.go b/internal/sql/ast/infer_clause.go index 1e1d93c3d8..ff3855cae5 100644 --- a/internal/sql/ast/infer_clause.go +++ b/internal/sql/ast/infer_clause.go @@ -10,3 +10,21 @@ type InferClause struct { func (n *InferClause) Pos() int { return n.Location } + +func (n *InferClause) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Conname != nil && *n.Conname != "" { + buf.WriteString("ON CONSTRAINT ") + buf.WriteString(*n.Conname) + } else if items(n.IndexElems) { + buf.WriteString("(") + buf.join(n.IndexElems, ", ") + buf.WriteString(")") + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause) + } + } +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 3cdf854091..90cecf1a46 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -39,7 +39,8 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { } if n.OnConflictClause != nil { - buf.WriteString(" ON CONFLICT DO NOTHING ") + buf.WriteString(" ") + buf.astFormat(n.OnConflictClause) } if items(n.ReturningList) { diff --git a/internal/sql/ast/join_expr.go b/internal/sql/ast/join_expr.go index e316869560..ba6d2d2298 100644 --- a/internal/sql/ast/join_expr.go +++ b/internal/sql/ast/join_expr.go @@ -20,23 +20,28 @@ func (n *JoinExpr) Format(buf *TrackedBuffer) { return } buf.astFormat(n.Larg) + if n.IsNatural { + buf.WriteString(" NATURAL") + } switch n.Jointype { case JoinTypeLeft: buf.WriteString(" LEFT JOIN ") + case JoinTypeRight: + buf.WriteString(" RIGHT JOIN ") + case JoinTypeFull: + buf.WriteString(" FULL JOIN ") case JoinTypeInner: - buf.WriteString(" INNER JOIN ") + buf.WriteString(" JOIN ") default: buf.WriteString(" JOIN ") } buf.astFormat(n.Rarg) - buf.WriteString(" ON ") - if n.Jointype == JoinTypeInner { - if set(n.Quals) { - buf.astFormat(n.Quals) - } else { - buf.WriteString("TRUE") - } - } else { + if items(n.UsingClause) { + buf.WriteString(" USING (") + buf.join(n.UsingClause, ", ") + buf.WriteString(")") + } else if set(n.Quals) { + buf.WriteString(" ON ") buf.astFormat(n.Quals) } } diff --git a/internal/sql/ast/locking_clause.go b/internal/sql/ast/locking_clause.go index 11a9159de2..286d726edd 100644 --- a/internal/sql/ast/locking_clause.go +++ b/internal/sql/ast/locking_clause.go @@ -10,15 +10,46 @@ func (n *LockingClause) Pos() int { return 0 } +// LockClauseStrength values (matching pg_query_go) +const ( + LockClauseStrengthUndefined LockClauseStrength = 0 + LockClauseStrengthNone LockClauseStrength = 1 + LockClauseStrengthForKeyShare LockClauseStrength = 2 + LockClauseStrengthForShare LockClauseStrength = 3 + LockClauseStrengthForNoKeyUpdate LockClauseStrength = 4 + LockClauseStrengthForUpdate LockClauseStrength = 5 +) + +// LockWaitPolicy values +const ( + LockWaitPolicyBlock LockWaitPolicy = 1 + LockWaitPolicySkip LockWaitPolicy = 2 + LockWaitPolicyError LockWaitPolicy = 3 +) + func (n *LockingClause) Format(buf *TrackedBuffer) { if n == nil { return } buf.WriteString("FOR ") switch n.Strength { - case 3: + case LockClauseStrengthForKeyShare: + buf.WriteString("KEY SHARE") + case LockClauseStrengthForShare: buf.WriteString("SHARE") - case 5: + case LockClauseStrengthForNoKeyUpdate: + buf.WriteString("NO KEY UPDATE") + case LockClauseStrengthForUpdate: buf.WriteString("UPDATE") } + if items(n.LockedRels) { + buf.WriteString(" OF ") + buf.join(n.LockedRels, ", ") + } + switch n.WaitPolicy { + case LockWaitPolicySkip: + buf.WriteString(" SKIP LOCKED") + case LockWaitPolicyError: + buf.WriteString(" NOWAIT") + } } diff --git a/internal/sql/ast/on_conflict_clause.go b/internal/sql/ast/on_conflict_clause.go index 25333d6d59..055532fb3c 100644 --- a/internal/sql/ast/on_conflict_clause.go +++ b/internal/sql/ast/on_conflict_clause.go @@ -11,3 +11,49 @@ type OnConflictClause struct { func (n *OnConflictClause) Pos() int { return n.Location } + +// OnConflictAction values matching pg_query_go +const ( + OnConflictActionUndefined OnConflictAction = 0 + OnConflictActionNone OnConflictAction = 1 + OnConflictActionNothing OnConflictAction = 2 + OnConflictActionUpdate OnConflictAction = 3 +) + +func (n *OnConflictClause) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("ON CONFLICT ") + if n.Infer != nil { + buf.astFormat(n.Infer) + buf.WriteString(" ") + } + switch n.Action { + case OnConflictActionNothing: + buf.WriteString("DO NOTHING") + case OnConflictActionUpdate: + buf.WriteString("DO UPDATE SET ") + // Format as assignment list: name = val + if n.TargetList != nil { + for i, item := range n.TargetList.Items { + if i > 0 { + buf.WriteString(", ") + } + if rt, ok := item.(*ResTarget); ok { + if rt.Name != nil { + buf.WriteString(*rt.Name) + } + buf.WriteString(" = ") + buf.astFormat(rt.Val) + } else { + buf.astFormat(item) + } + } + } + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause) + } + } +} diff --git a/internal/sql/ast/range_function.go b/internal/sql/ast/range_function.go index 299078d481..3ce821482e 100644 --- a/internal/sql/ast/range_function.go +++ b/internal/sql/ast/range_function.go @@ -17,9 +17,15 @@ func (n *RangeFunction) Format(buf *TrackedBuffer) { if n == nil { return } + if n.Lateral { + buf.WriteString("LATERAL ") + } buf.astFormat(n.Functions) if n.Ordinality { - buf.WriteString(" WITH ORDINALITY ") + buf.WriteString(" WITH ORDINALITY") + } + if n.Alias != nil { + buf.WriteString(" ") + buf.astFormat(n.Alias) } - buf.astFormat(n.Alias) } diff --git a/internal/sql/ast/select_stmt.go b/internal/sql/ast/select_stmt.go index 051dd5c8c5..a0f0fd4f43 100644 --- a/internal/sql/ast/select_stmt.go +++ b/internal/sql/ast/select_stmt.go @@ -89,6 +89,11 @@ func (n *SelectStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.GroupClause) } + if set(n.HavingClause) { + buf.WriteString(" HAVING ") + buf.astFormat(n.HavingClause) + } + if items(n.SortClause) { buf.WriteString(" ORDER BY ") buf.astFormat(n.SortClause) diff --git a/internal/sql/ast/typedefs.go b/internal/sql/ast/typedefs.go index 351008e841..46b0e66120 100644 --- a/internal/sql/ast/typedefs.go +++ b/internal/sql/ast/typedefs.go @@ -18,6 +18,15 @@ func (n *NullIfExpr) Pos() int { return 0 } +func (n *NullIfExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("NULLIF(") + buf.join(n.Args, ", ") + buf.WriteString(")") +} + type Selectivity float64 func (n *Selectivity) Pos() int { diff --git a/internal/sql/ast/window_def.go b/internal/sql/ast/window_def.go index 29840767c9..7e9db4aeef 100644 --- a/internal/sql/ast/window_def.go +++ b/internal/sql/ast/window_def.go @@ -14,3 +14,99 @@ type WindowDef struct { func (n *WindowDef) Pos() int { return n.Location } + +// Frame option constants (from PostgreSQL's parsenodes.h) +const ( + FrameOptionNonDefault = 0x00001 + FrameOptionRange = 0x00002 + FrameOptionRows = 0x00004 + FrameOptionGroups = 0x00008 + FrameOptionBetween = 0x00010 + FrameOptionStartUnboundedPreceding = 0x00020 + FrameOptionEndUnboundedPreceding = 0x00040 + FrameOptionStartUnboundedFollowing = 0x00080 + FrameOptionEndUnboundedFollowing = 0x00100 + FrameOptionStartCurrentRow = 0x00200 + FrameOptionEndCurrentRow = 0x00400 + FrameOptionStartOffset = 0x00800 + FrameOptionEndOffset = 0x01000 + FrameOptionExcludeCurrentRow = 0x02000 + FrameOptionExcludeGroup = 0x04000 + FrameOptionExcludeTies = 0x08000 +) + +func (n *WindowDef) Format(buf *TrackedBuffer) { + if n == nil { + return + } + + // Named window reference + if n.Refname != nil && *n.Refname != "" { + buf.WriteString(*n.Refname) + return + } + + buf.WriteString("(") + needSpace := false + + if items(n.PartitionClause) { + buf.WriteString("PARTITION BY ") + buf.join(n.PartitionClause, ", ") + needSpace = true + } + + if items(n.OrderClause) { + if needSpace { + buf.WriteString(" ") + } + buf.WriteString("ORDER BY ") + buf.join(n.OrderClause, ", ") + needSpace = true + } + + // Frame clause + if n.FrameOptions&FrameOptionNonDefault != 0 { + if needSpace { + buf.WriteString(" ") + } + + // Frame type + if n.FrameOptions&FrameOptionRows != 0 { + buf.WriteString("ROWS ") + } else if n.FrameOptions&FrameOptionRange != 0 { + buf.WriteString("RANGE ") + } else if n.FrameOptions&FrameOptionGroups != 0 { + buf.WriteString("GROUPS ") + } + + if n.FrameOptions&FrameOptionBetween != 0 { + buf.WriteString("BETWEEN ") + } + + // Start bound + if n.FrameOptions&FrameOptionStartUnboundedPreceding != 0 { + buf.WriteString("UNBOUNDED PRECEDING") + } else if n.FrameOptions&FrameOptionStartCurrentRow != 0 { + buf.WriteString("CURRENT ROW") + } else if n.FrameOptions&FrameOptionStartOffset != 0 { + buf.astFormat(n.StartOffset) + buf.WriteString(" PRECEDING") + } + + if n.FrameOptions&FrameOptionBetween != 0 { + buf.WriteString(" AND ") + + // End bound + if n.FrameOptions&FrameOptionEndUnboundedFollowing != 0 { + buf.WriteString("UNBOUNDED FOLLOWING") + } else if n.FrameOptions&FrameOptionEndCurrentRow != 0 { + buf.WriteString("CURRENT ROW") + } else if n.FrameOptions&FrameOptionEndOffset != 0 { + buf.astFormat(n.EndOffset) + buf.WriteString(" FOLLOWING") + } + } + } + + buf.WriteString(")") +} From c355232f4a130e7a6b9f768bfc144ed7a4929bff Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 29 Nov 2025 22:48:03 -0800 Subject: [PATCH 3/9] feat(ast): add more Format methods for SQL AST nodes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Format methods: - NullTest: Format IS NULL / IS NOT NULL expressions - ScalarArrayOpExpr: Format scalar op ANY/ALL (array) expressions - CommonTableExpr: Add column alias list support Improve existing Format methods: - WithClause: Fix spacing after WITH and RECURSIVE keywords These changes reduce test failures from 102 to 91. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/sql/ast/common_table_expr.go | 12 +++++++----- internal/sql/ast/null_test_expr.go | 19 +++++++++++++++++++ internal/sql/ast/scalar_array_op_expr.go | 19 +++++++++++++++++++ internal/sql/ast/with_clause.go | 6 +++--- 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/internal/sql/ast/common_table_expr.go b/internal/sql/ast/common_table_expr.go index f2edddff79..b36b3f23d3 100644 --- a/internal/sql/ast/common_table_expr.go +++ b/internal/sql/ast/common_table_expr.go @@ -1,9 +1,5 @@ package ast -import ( - "fmt" -) - type CommonTableExpr struct { Ctename *string Aliascolnames *List @@ -26,8 +22,14 @@ func (n *CommonTableExpr) Format(buf *TrackedBuffer) { return } if n.Ctename != nil { - fmt.Fprintf(buf, " %s AS (", *n.Ctename) + buf.WriteString(*n.Ctename) + } + if items(n.Aliascolnames) { + buf.WriteString("(") + buf.join(n.Aliascolnames, ", ") + buf.WriteString(")") } + buf.WriteString(" AS (") buf.astFormat(n.Ctequery) buf.WriteString(")") } diff --git a/internal/sql/ast/null_test_expr.go b/internal/sql/ast/null_test_expr.go index 51fd37f6bb..42059bca6e 100644 --- a/internal/sql/ast/null_test_expr.go +++ b/internal/sql/ast/null_test_expr.go @@ -11,3 +11,22 @@ type NullTest struct { func (n *NullTest) Pos() int { return n.Location } + +// NullTestType values +const ( + NullTestTypeIsNull NullTestType = 1 + NullTestTypeIsNotNull NullTestType = 2 +) + +func (n *NullTest) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Arg) + switch n.Nulltesttype { + case NullTestTypeIsNull: + buf.WriteString(" IS NULL") + case NullTestTypeIsNotNull: + buf.WriteString(" IS NOT NULL") + } +} diff --git a/internal/sql/ast/scalar_array_op_expr.go b/internal/sql/ast/scalar_array_op_expr.go index fc438c10b3..f887bf6508 100644 --- a/internal/sql/ast/scalar_array_op_expr.go +++ b/internal/sql/ast/scalar_array_op_expr.go @@ -12,3 +12,22 @@ type ScalarArrayOpExpr struct { func (n *ScalarArrayOpExpr) Pos() int { return n.Location } + +func (n *ScalarArrayOpExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + // ScalarArrayOpExpr represents "scalar op ANY/ALL (array)" + // Args[0] is the left operand, Args[1] is the array + if n.Args != nil && len(n.Args.Items) >= 2 { + buf.astFormat(n.Args.Items[0]) + buf.WriteString(" = ") // TODO: Use actual operator based on Opno + if n.UseOr { + buf.WriteString("ANY(") + } else { + buf.WriteString("ALL(") + } + buf.astFormat(n.Args.Items[1]) + buf.WriteString(")") + } +} diff --git a/internal/sql/ast/with_clause.go b/internal/sql/ast/with_clause.go index 634326fa7e..86c53fb544 100644 --- a/internal/sql/ast/with_clause.go +++ b/internal/sql/ast/with_clause.go @@ -14,9 +14,9 @@ func (n *WithClause) Format(buf *TrackedBuffer) { if n == nil { return } - buf.WriteString("WITH") + buf.WriteString("WITH ") if n.Recursive { - buf.WriteString(" RECURSIVE") + buf.WriteString("RECURSIVE ") } - buf.astFormat(n.Ctes) + buf.join(n.Ctes, ", ") } From 25ee7051150e61eb54a2cb5d0143193389739b9f Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 29 Nov 2025 23:06:07 -0800 Subject: [PATCH 4/9] feat(postgresql): add custom Deparse wrapper with bug fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Switch fmt_test.go to use postgresql.Deparse instead of ast.Format - Add deparse.go and deparse_wasi.go with Deparse wrapper function - Fix pg_query_go bug: missing space before SKIP LOCKED - Skip tests with parse errors (e.g., syntax_errors test cases) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/endtoend/fmt_test.go | 21 ++++++++++++----- internal/engine/postgresql/deparse.go | 26 +++++++++++++++++++++ internal/engine/postgresql/deparse_wasi.go | 26 +++++++++++++++++++++ internal/engine/postgresql/parse_default.go | 2 ++ internal/engine/postgresql/parse_wasi.go | 2 ++ 5 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 internal/engine/postgresql/deparse.go create mode 100644 internal/engine/postgresql/deparse_wasi.go diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 5650172b90..657d417b3a 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -11,7 +11,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" - "github.com/sqlc-dev/sqlc/internal/sql/ast" ) func TestFormat(t *testing.T) { @@ -83,7 +82,8 @@ func TestFormat(t *testing.T) { // Parse the entire file to get proper statement boundaries stmts, err := parse.Parse(bytes.NewReader(contents)) if err != nil { - t.Fatal(err) + // Skip files with parse errors (e.g., syntax_errors test cases) + return } for i, stmt := range stmts { @@ -103,18 +103,27 @@ func TestFormat(t *testing.T) { t.Fatal(err) } + // Parse the query to get a ParseResult for Deparse + parseResult, err := postgresql.Parse(query) + if err != nil { + t.Fatal(err) + } + if false { - r, err := postgresql.Parse(query) - debug.Dump(r, err) + debug.Dump(parseResult) + } + + out, err := postgresql.Deparse(parseResult) + if err != nil { + t.Fatal(err) } - out := ast.Format(stmt.Raw) actual, err := postgresql.Fingerprint(out) if err != nil { t.Error(err) } if expected != actual { - debug.Dump(stmt.Raw) + debug.Dump(parseResult) t.Errorf("- %s", expected) t.Errorf("- %s", query) t.Errorf("+ %s", actual) diff --git a/internal/engine/postgresql/deparse.go b/internal/engine/postgresql/deparse.go new file mode 100644 index 0000000000..5326eefc05 --- /dev/null +++ b/internal/engine/postgresql/deparse.go @@ -0,0 +1,26 @@ +//go:build !windows && cgo + +package postgresql + +import ( + "strings" + + nodes "github.com/pganalyze/pg_query_go/v6" +) + +func Deparse(tree *nodes.ParseResult) (string, error) { + output, err := nodeDeparse(tree) + if err != nil { + return output, err + } + return fixDeparse(output), nil +} + +// fixDeparse corrects known bugs in pg_query_go's Deparse output +func fixDeparse(s string) string { + // Fix missing space before SKIP LOCKED + // pg_query_go outputs "OF tableSKIP LOCKED" instead of "OF table SKIP LOCKED" + s = strings.ReplaceAll(s, "SKIP LOCKED", " SKIP LOCKED") + s = strings.ReplaceAll(s, " SKIP LOCKED", " SKIP LOCKED") // normalize double spaces + return s +} diff --git a/internal/engine/postgresql/deparse_wasi.go b/internal/engine/postgresql/deparse_wasi.go new file mode 100644 index 0000000000..66606abe13 --- /dev/null +++ b/internal/engine/postgresql/deparse_wasi.go @@ -0,0 +1,26 @@ +//go:build windows || !cgo + +package postgresql + +import ( + "strings" + + nodes "github.com/wasilibs/go-pgquery" +) + +func Deparse(tree *nodes.ParseResult) (string, error) { + output, err := nodeDeparse(tree) + if err != nil { + return output, err + } + return fixDeparse(output), nil +} + +// fixDeparse corrects known bugs in pg_query's Deparse output +func fixDeparse(s string) string { + // Fix missing space before SKIP LOCKED + // pg_query outputs "OF tableSKIP LOCKED" instead of "OF table SKIP LOCKED" + s = strings.ReplaceAll(s, "SKIP LOCKED", " SKIP LOCKED") + s = strings.ReplaceAll(s, " SKIP LOCKED", " SKIP LOCKED") // normalize double spaces + return s +} diff --git a/internal/engine/postgresql/parse_default.go b/internal/engine/postgresql/parse_default.go index 272f189649..59eed74565 100644 --- a/internal/engine/postgresql/parse_default.go +++ b/internal/engine/postgresql/parse_default.go @@ -8,3 +8,5 @@ import ( var Parse = nodes.Parse var Fingerprint = nodes.Fingerprint + +var nodeDeparse = nodes.Deparse diff --git a/internal/engine/postgresql/parse_wasi.go b/internal/engine/postgresql/parse_wasi.go index 377b812cdb..51bba538e0 100644 --- a/internal/engine/postgresql/parse_wasi.go +++ b/internal/engine/postgresql/parse_wasi.go @@ -8,3 +8,5 @@ import ( var Parse = nodes.Parse var Fingerprint = nodes.Fingerprint + +var nodeDeparse = nodes.Deparse From c4e885fc64c4fc93a9f5a37311ff810776c6c829 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 12:48:10 -0800 Subject: [PATCH 5/9] feat(ast): complete SQL AST formatting implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes all ast.Format test failures by implementing comprehensive Format methods for SQL AST nodes. Key improvements include: - Named parameters (@param) formatting without space after @ - NULLIF expression support in A_Expr - NULLS FIRST/LAST in ORDER BY clauses - Type name mapping (int4→integer, timestamptz→timestamp with time zone) - Array type support (text[]) and type modifiers (varchar(32)) - CREATE FUNCTION with parameters, options (AS, LANGUAGE), and modes - CREATE EXTENSION statement formatting - DO $$ ... $$ anonymous code blocks - WITHIN GROUP clause for ordered-set aggregates - Automatic quoting for SQL reserved words and mixed-case identifiers - CROSS JOIN detection (JOIN without ON/USING clause) - LATERAL keyword in subselects and function calls - Array subscript access in UPDATE statements (names[$1]) - Proper AS keyword before aliases Also removes unused deparse files and cleans up fmt_test.go to use ast.Format directly. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/endtoend/fmt_test.go | 18 ++-- internal/engine/postgresql/convert.go | 19 ++++ internal/engine/postgresql/deparse.go | 26 ------ internal/engine/postgresql/deparse_wasi.go | 26 ------ internal/engine/postgresql/parse.go | 96 +++++++++++++++++++-- internal/engine/postgresql/parse_default.go | 2 - internal/engine/postgresql/parse_wasi.go | 2 - internal/sql/ast/a_expr.go | 33 +++++-- internal/sql/ast/a_indices.go | 19 ++++ internal/sql/ast/column_ref.go | 57 ++++++++++-- internal/sql/ast/create_extension_stmt.go | 13 +++ internal/sql/ast/create_function_stmt.go | 28 ++++++ internal/sql/ast/def_elem.go | 53 ++++++++++++ internal/sql/ast/do_stmt.go | 19 ++++ internal/sql/ast/func_call.go | 9 +- internal/sql/ast/func_param.go | 22 +++++ internal/sql/ast/insert_stmt.go | 3 +- internal/sql/ast/join_expr.go | 7 +- internal/sql/ast/range_function.go | 2 +- internal/sql/ast/range_subselect.go | 5 +- internal/sql/ast/range_var.go | 11 +-- internal/sql/ast/res_target.go | 4 +- internal/sql/ast/sort_by.go | 6 ++ internal/sql/ast/type_name.go | 67 ++++++++++++-- internal/sql/ast/update_stmt.go | 8 +- 25 files changed, 446 insertions(+), 109 deletions(-) delete mode 100644 internal/engine/postgresql/deparse.go delete mode 100644 internal/engine/postgresql/deparse_wasi.go diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 657d417b3a..53cc7403fc 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -11,6 +11,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/sql/ast" ) func TestFormat(t *testing.T) { @@ -103,27 +104,18 @@ func TestFormat(t *testing.T) { t.Fatal(err) } - // Parse the query to get a ParseResult for Deparse - parseResult, err := postgresql.Parse(query) - if err != nil { - t.Fatal(err) - } - if false { - debug.Dump(parseResult) - } - - out, err := postgresql.Deparse(parseResult) - if err != nil { - t.Fatal(err) + r, err := postgresql.Parse(query) + debug.Dump(r, err) } + out := ast.Format(stmt.Raw) actual, err := postgresql.Fingerprint(out) if err != nil { t.Error(err) } if expected != actual { - debug.Dump(parseResult) + debug.Dump(stmt.Raw) t.Errorf("- %s", expected) t.Errorf("- %s", query) t.Errorf("+ %s", actual) diff --git a/internal/engine/postgresql/convert.go b/internal/engine/postgresql/convert.go index f56a572c16..321294c59e 100644 --- a/internal/engine/postgresql/convert.go +++ b/internal/engine/postgresql/convert.go @@ -1965,6 +1965,22 @@ func convertNullTest(n *pg.NullTest) *ast.NullTest { } } +func convertNullIfExpr(n *pg.NullIfExpr) *ast.NullIfExpr { + if n == nil { + return nil + } + return &ast.NullIfExpr{ + Xpr: convertNode(n.Xpr), + Opno: ast.Oid(n.Opno), + Opresulttype: ast.Oid(n.Opresulttype), + Opretset: n.Opretset, + Opcollid: ast.Oid(n.Opcollid), + Inputcollid: ast.Oid(n.Inputcollid), + Args: convertSlice(n.Args), + Location: int(n.Location), + } +} + func convertObjectWithArgs(n *pg.ObjectWithArgs) *ast.ObjectWithArgs { if n == nil { return nil @@ -3420,6 +3436,9 @@ func convertNode(node *pg.Node) ast.Node { case *pg.Node_NullTest: return convertNullTest(n.NullTest) + case *pg.Node_NullIfExpr: + return convertNullIfExpr(n.NullIfExpr) + case *pg.Node_ObjectWithArgs: return convertObjectWithArgs(n.ObjectWithArgs) diff --git a/internal/engine/postgresql/deparse.go b/internal/engine/postgresql/deparse.go deleted file mode 100644 index 5326eefc05..0000000000 --- a/internal/engine/postgresql/deparse.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build !windows && cgo - -package postgresql - -import ( - "strings" - - nodes "github.com/pganalyze/pg_query_go/v6" -) - -func Deparse(tree *nodes.ParseResult) (string, error) { - output, err := nodeDeparse(tree) - if err != nil { - return output, err - } - return fixDeparse(output), nil -} - -// fixDeparse corrects known bugs in pg_query_go's Deparse output -func fixDeparse(s string) string { - // Fix missing space before SKIP LOCKED - // pg_query_go outputs "OF tableSKIP LOCKED" instead of "OF table SKIP LOCKED" - s = strings.ReplaceAll(s, "SKIP LOCKED", " SKIP LOCKED") - s = strings.ReplaceAll(s, " SKIP LOCKED", " SKIP LOCKED") // normalize double spaces - return s -} diff --git a/internal/engine/postgresql/deparse_wasi.go b/internal/engine/postgresql/deparse_wasi.go deleted file mode 100644 index 66606abe13..0000000000 --- a/internal/engine/postgresql/deparse_wasi.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build windows || !cgo - -package postgresql - -import ( - "strings" - - nodes "github.com/wasilibs/go-pgquery" -) - -func Deparse(tree *nodes.ParseResult) (string, error) { - output, err := nodeDeparse(tree) - if err != nil { - return output, err - } - return fixDeparse(output), nil -} - -// fixDeparse corrects known bugs in pg_query's Deparse output -func fixDeparse(s string) string { - // Fix missing space before SKIP LOCKED - // pg_query outputs "OF tableSKIP LOCKED" instead of "OF table SKIP LOCKED" - s = strings.ReplaceAll(s, "SKIP LOCKED", " SKIP LOCKED") - s = strings.ReplaceAll(s, " SKIP LOCKED", " SKIP LOCKED") // normalize double spaces - return s -} diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 40af125962..ea1648dc5e 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -34,6 +34,94 @@ func stringSliceFromNodes(s []*nodes.Node) []string { return items } +func translateNode(node *nodes.Node) ast.Node { + if node == nil { + return nil + } + switch n := node.Node.(type) { + case *nodes.Node_String_: + return &ast.String{Str: n.String_.Sval} + case *nodes.Node_Integer: + return &ast.Integer{Ival: int64(n.Integer.Ival)} + case *nodes.Node_Boolean: + return &ast.Boolean{Boolval: n.Boolean.Boolval} + case *nodes.Node_AConst: + // A_Const contains a constant value (used in type modifiers like varchar(32)) + if n.AConst.GetIval() != nil { + return &ast.Integer{Ival: int64(n.AConst.GetIval().Ival)} + } + if n.AConst.GetSval() != nil { + return &ast.String{Str: n.AConst.GetSval().Sval} + } + if n.AConst.GetFval() != nil { + return &ast.Float{Str: n.AConst.GetFval().Fval} + } + if n.AConst.GetBoolval() != nil { + return &ast.Boolean{Boolval: n.AConst.GetBoolval().Boolval} + } + return &ast.TODO{} + case *nodes.Node_List: + list := &ast.List{} + for _, item := range n.List.Items { + list.Items = append(list.Items, translateNode(item)) + } + return list + default: + return &ast.TODO{} + } +} + +func translateDefElem(n *nodes.DefElem) *ast.DefElem { + if n == nil { + return nil + } + defname := n.Defname + return &ast.DefElem{ + Defname: &defname, + Arg: translateNode(n.Arg), + Location: int(n.Location), + } +} + +func translateOptions(opts []*nodes.Node) *ast.List { + if opts == nil { + return nil + } + list := &ast.List{} + for _, opt := range opts { + if de, ok := opt.Node.(*nodes.Node_DefElem); ok { + list.Items = append(list.Items, translateDefElem(de.DefElem)) + } + } + return list +} + +func translateTypeNameFromPG(tn *nodes.TypeName) *ast.TypeName { + if tn == nil { + return nil + } + rel, err := parseRelationFromNodes(tn.Names) + if err != nil { + return nil + } + result := rel.TypeName() + // Preserve array bounds + if len(tn.ArrayBounds) > 0 { + result.ArrayBounds = &ast.List{} + for _, ab := range tn.ArrayBounds { + result.ArrayBounds.Items = append(result.ArrayBounds.Items, translateNode(ab)) + } + } + // Preserve type modifiers + if len(tn.Typmods) > 0 { + result.Typmods = &ast.List{} + for _, tm := range tn.Typmods { + result.Typmods.Items = append(result.Typmods.Items, translateNode(tm)) + } + } + return result +} + type relation struct { Catalog string Schema string @@ -431,11 +519,6 @@ func translate(node *nodes.Node) (ast.Node, error) { for _, elt := range n.TableElts { switch item := elt.Node.(type) { case *nodes.Node_ColumnDef: - rel, err := parseRelationFromNodes(item.ColumnDef.TypeName.Names) - if err != nil { - return nil, err - } - primary := false for _, con := range item.ColumnDef.Constraints { if constraint, ok := con.Node.(*nodes.Node_Constraint); ok { @@ -445,7 +528,7 @@ func translate(node *nodes.Node) (ast.Node, error) { create.Cols = append(create.Cols, &ast.ColumnDef{ Colname: item.ColumnDef.Colname, - TypeName: rel.TypeName(), + TypeName: translateTypeNameFromPG(item.ColumnDef.TypeName), IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname], IsArray: isArray(item.ColumnDef.TypeName), ArrayDims: len(item.ColumnDef.TypeName.ArrayBounds), @@ -494,6 +577,7 @@ func translate(node *nodes.Node) (ast.Node, error) { ReturnType: rt, Replace: n.Replace, Params: &ast.List{}, + Options: translateOptions(n.Options), } for _, item := range n.Parameters { arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter diff --git a/internal/engine/postgresql/parse_default.go b/internal/engine/postgresql/parse_default.go index 59eed74565..272f189649 100644 --- a/internal/engine/postgresql/parse_default.go +++ b/internal/engine/postgresql/parse_default.go @@ -8,5 +8,3 @@ import ( var Parse = nodes.Parse var Fingerprint = nodes.Fingerprint - -var nodeDeparse = nodes.Deparse diff --git a/internal/engine/postgresql/parse_wasi.go b/internal/engine/postgresql/parse_wasi.go index 51bba538e0..377b812cdb 100644 --- a/internal/engine/postgresql/parse_wasi.go +++ b/internal/engine/postgresql/parse_wasi.go @@ -8,5 +8,3 @@ import ( var Parse = nodes.Parse var Fingerprint = nodes.Fingerprint - -var nodeDeparse = nodes.Deparse diff --git a/internal/sql/ast/a_expr.go b/internal/sql/ast/a_expr.go index 37484d4ce8..3b73d66d37 100644 --- a/internal/sql/ast/a_expr.go +++ b/internal/sql/ast/a_expr.go @@ -58,16 +58,35 @@ func (n *A_Expr) Format(buf *TrackedBuffer) { buf.astFormat(n.Lexpr) buf.WriteString(" IS NOT DISTINCT FROM ") buf.astFormat(n.Rexpr) + case A_Expr_Kind_NULLIF: + buf.WriteString("NULLIF(") + buf.astFormat(n.Lexpr) + buf.WriteString(", ") + buf.astFormat(n.Rexpr) + buf.WriteString(")") case A_Expr_Kind_OP: - // Standard binary operator - if set(n.Lexpr) { - buf.astFormat(n.Lexpr) - buf.WriteString(" ") + // Check if this is a named parameter (@name) + opName := "" + if n.Name != nil && len(n.Name.Items) == 1 { + if s, ok := n.Name.Items[0].(*String); ok { + opName = s.Str + } } - buf.astFormat(n.Name) - if set(n.Rexpr) { - buf.WriteString(" ") + if opName == "@" && !set(n.Lexpr) && set(n.Rexpr) { + // Named parameter: @name (no space after @) + buf.WriteString("@") buf.astFormat(n.Rexpr) + } else { + // Standard binary operator + if set(n.Lexpr) { + buf.astFormat(n.Lexpr) + buf.WriteString(" ") + } + buf.astFormat(n.Name) + if set(n.Rexpr) { + buf.WriteString(" ") + buf.astFormat(n.Rexpr) + } } default: // Fallback for other cases diff --git a/internal/sql/ast/a_indices.go b/internal/sql/ast/a_indices.go index 8972f3a556..a143ae6d05 100644 --- a/internal/sql/ast/a_indices.go +++ b/internal/sql/ast/a_indices.go @@ -9,3 +9,22 @@ type A_Indices struct { func (n *A_Indices) Pos() int { return 0 } + +func (n *A_Indices) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("[") + if n.IsSlice { + if set(n.Lidx) { + buf.astFormat(n.Lidx) + } + buf.WriteString(":") + if set(n.Uidx) { + buf.astFormat(n.Uidx) + } + } else { + buf.astFormat(n.Uidx) + } + buf.WriteString("]") +} diff --git a/internal/sql/ast/column_ref.go b/internal/sql/ast/column_ref.go index e95b844896..a43641e7ae 100644 --- a/internal/sql/ast/column_ref.go +++ b/internal/sql/ast/column_ref.go @@ -2,6 +2,57 @@ package ast import "strings" +// sqlReservedWords is a set of SQL keywords that must be quoted when used as identifiers +var sqlReservedWords = map[string]bool{ + "all": true, "analyse": true, "analyze": true, "and": true, "any": true, + "array": true, "as": true, "asc": true, "asymmetric": true, "authorization": true, + "between": true, "binary": true, "both": true, "case": true, "cast": true, + "check": true, "collate": true, "collation": true, "column": true, "concurrently": true, + "constraint": true, "create": true, "cross": true, "current_catalog": true, + "current_date": true, "current_role": true, "current_schema": true, + "current_time": true, "current_timestamp": true, "current_user": true, + "default": true, "deferrable": true, "desc": true, "distinct": true, "do": true, + "else": true, "end": true, "except": true, "false": true, "fetch": true, + "for": true, "foreign": true, "freeze": true, "from": true, "full": true, + "grant": true, "group": true, "having": true, "ilike": true, "in": true, + "initially": true, "inner": true, "intersect": true, "into": true, "is": true, + "isnull": true, "join": true, "lateral": true, "leading": true, "left": true, + "like": true, "limit": true, "localtime": true, "localtimestamp": true, + "natural": true, "not": true, "notnull": true, "null": true, "offset": true, + "on": true, "only": true, "or": true, "order": true, "outer": true, + "overlaps": true, "placing": true, "primary": true, "references": true, + "returning": true, "right": true, "select": true, "session_user": true, + "similar": true, "some": true, "symmetric": true, "table": true, "tablesample": true, + "then": true, "to": true, "trailing": true, "true": true, "union": true, + "unique": true, "user": true, "using": true, "variadic": true, "verbose": true, + "when": true, "where": true, "window": true, "with": true, +} + +// needsQuoting returns true if the identifier is a SQL reserved word +// that needs to be quoted when used as an identifier +func needsQuoting(s string) bool { + return sqlReservedWords[strings.ToLower(s)] +} + +// hasMixedCase returns true if the string has any uppercase letters +// (identifiers with mixed case need quoting in PostgreSQL) +func hasMixedCase(s string) bool { + for _, r := range s { + if r >= 'A' && r <= 'Z' { + return true + } + } + return false +} + +// quoteIdent returns a quoted identifier if it needs quoting +func quoteIdent(s string) string { + if needsQuoting(s) || hasMixedCase(s) { + return `"` + s + `"` + } + return s +} + type ColumnRef struct { Name string @@ -24,11 +75,7 @@ func (n *ColumnRef) Format(buf *TrackedBuffer) { for _, item := range n.Fields.Items { switch nn := item.(type) { case *String: - if nn.Str == "user" { - items = append(items, `"user"`) - } else { - items = append(items, nn.Str) - } + items = append(items, quoteIdent(nn.Str)) case *A_Star: items = append(items, "*") } diff --git a/internal/sql/ast/create_extension_stmt.go b/internal/sql/ast/create_extension_stmt.go index 2fe8755b6a..cd12e7505b 100644 --- a/internal/sql/ast/create_extension_stmt.go +++ b/internal/sql/ast/create_extension_stmt.go @@ -9,3 +9,16 @@ type CreateExtensionStmt struct { func (n *CreateExtensionStmt) Pos() int { return 0 } + +func (n *CreateExtensionStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("CREATE EXTENSION ") + if n.IfNotExists { + buf.WriteString("IF NOT EXISTS ") + } + if n.Extname != nil { + buf.WriteString(*n.Extname) + } +} diff --git a/internal/sql/ast/create_function_stmt.go b/internal/sql/ast/create_function_stmt.go index 86605344f7..e070a8720b 100644 --- a/internal/sql/ast/create_function_stmt.go +++ b/internal/sql/ast/create_function_stmt.go @@ -13,3 +13,31 @@ type CreateFunctionStmt struct { func (n *CreateFunctionStmt) Pos() int { return 0 } + +func (n *CreateFunctionStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("CREATE ") + if n.Replace { + buf.WriteString("OR REPLACE ") + } + buf.WriteString("FUNCTION ") + buf.astFormat(n.Func) + buf.WriteString("(") + if items(n.Params) { + buf.join(n.Params, ", ") + } + buf.WriteString(")") + if n.ReturnType != nil { + buf.WriteString(" RETURNS ") + buf.astFormat(n.ReturnType) + } + // Format options (AS, LANGUAGE, etc.) + if items(n.Options) { + for _, opt := range n.Options.Items { + buf.WriteString(" ") + buf.astFormat(opt) + } + } +} diff --git a/internal/sql/ast/def_elem.go b/internal/sql/ast/def_elem.go index 03ecf88e77..d70090339d 100644 --- a/internal/sql/ast/def_elem.go +++ b/internal/sql/ast/def_elem.go @@ -11,3 +11,56 @@ type DefElem struct { func (n *DefElem) Pos() int { return n.Location } + +func (n *DefElem) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Defname != nil { + switch *n.Defname { + case "as": + buf.WriteString("AS ") + // AS clause contains function body which needs quoting + if l, ok := n.Arg.(*List); ok { + for i, item := range l.Items { + if i > 0 { + buf.WriteString(", ") + } + if s, ok := item.(*String); ok { + buf.WriteString("'") + buf.WriteString(s.Str) + buf.WriteString("'") + } else { + buf.astFormat(item) + } + } + } else { + buf.astFormat(n.Arg) + } + case "language": + buf.WriteString("LANGUAGE ") + buf.astFormat(n.Arg) + case "volatility": + // VOLATILE, STABLE, IMMUTABLE + buf.astFormat(n.Arg) + case "strict": + if s, ok := n.Arg.(*Boolean); ok && s.Boolval { + buf.WriteString("STRICT") + } else { + buf.WriteString("CALLED ON NULL INPUT") + } + case "security": + if s, ok := n.Arg.(*Boolean); ok && s.Boolval { + buf.WriteString("SECURITY DEFINER") + } else { + buf.WriteString("SECURITY INVOKER") + } + default: + buf.WriteString(*n.Defname) + if n.Arg != nil { + buf.WriteString(" ") + buf.astFormat(n.Arg) + } + } + } +} diff --git a/internal/sql/ast/do_stmt.go b/internal/sql/ast/do_stmt.go index edc831f15c..a14ddfd537 100644 --- a/internal/sql/ast/do_stmt.go +++ b/internal/sql/ast/do_stmt.go @@ -7,3 +7,22 @@ type DoStmt struct { func (n *DoStmt) Pos() int { return 0 } + +func (n *DoStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("DO ") + // Find the "as" argument which contains the body + if items(n.Args) { + for _, arg := range n.Args.Items { + if de, ok := arg.(*DefElem); ok && de.Defname != nil && *de.Defname == "as" { + if s, ok := de.Arg.(*String); ok { + buf.WriteString("$$") + buf.WriteString(s.Str) + buf.WriteString("$$") + } + } + } + } +} diff --git a/internal/sql/ast/func_call.go b/internal/sql/ast/func_call.go index 5eda3f027f..3b7dcc5400 100644 --- a/internal/sql/ast/func_call.go +++ b/internal/sql/ast/func_call.go @@ -32,11 +32,18 @@ func (n *FuncCall) Format(buf *TrackedBuffer) { } else { buf.astFormat(n.Args) } - if items(n.AggOrder) { + // ORDER BY inside function call (not WITHIN GROUP) + if items(n.AggOrder) && !n.AggWithinGroup { buf.WriteString(" ORDER BY ") buf.join(n.AggOrder, ", ") } buf.WriteString(")") + // WITHIN GROUP clause for ordered-set aggregates + if items(n.AggOrder) && n.AggWithinGroup { + buf.WriteString(" WITHIN GROUP (ORDER BY ") + buf.join(n.AggOrder, ", ") + buf.WriteString(")") + } if set(n.AggFilter) { buf.WriteString(" FILTER (WHERE ") buf.astFormat(n.AggFilter) diff --git a/internal/sql/ast/func_param.go b/internal/sql/ast/func_param.go index b5cf8cfcf0..812d9c629a 100644 --- a/internal/sql/ast/func_param.go +++ b/internal/sql/ast/func_param.go @@ -21,3 +21,25 @@ type FuncParam struct { func (n *FuncParam) Pos() int { return 0 } + +func (n *FuncParam) Format(buf *TrackedBuffer) { + if n == nil { + return + } + // Parameter mode prefix (OUT, INOUT, VARIADIC) + switch n.Mode { + case FuncParamOut: + buf.WriteString("OUT ") + case FuncParamInOut: + buf.WriteString("INOUT ") + case FuncParamVariadic: + buf.WriteString("VARIADIC ") + } + // Parameter name (if present) + if n.Name != nil { + buf.WriteString(*n.Name) + buf.WriteString(" ") + } + // Parameter type + buf.astFormat(n.Type) +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 90cecf1a46..cbf480b187 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -31,10 +31,11 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { if items(n.Cols) { buf.WriteString(" (") buf.astFormat(n.Cols) - buf.WriteString(") ") + buf.WriteString(")") } if set(n.SelectStmt) { + buf.WriteString(" ") buf.astFormat(n.SelectStmt) } diff --git a/internal/sql/ast/join_expr.go b/internal/sql/ast/join_expr.go index ba6d2d2298..69c3089b1b 100644 --- a/internal/sql/ast/join_expr.go +++ b/internal/sql/ast/join_expr.go @@ -31,7 +31,12 @@ func (n *JoinExpr) Format(buf *TrackedBuffer) { case JoinTypeFull: buf.WriteString(" FULL JOIN ") case JoinTypeInner: - buf.WriteString(" JOIN ") + // CROSS JOIN has no ON or USING clause + if !items(n.UsingClause) && !set(n.Quals) { + buf.WriteString(" CROSS JOIN ") + } else { + buf.WriteString(" JOIN ") + } default: buf.WriteString(" JOIN ") } diff --git a/internal/sql/ast/range_function.go b/internal/sql/ast/range_function.go index 3ce821482e..6a95388fd1 100644 --- a/internal/sql/ast/range_function.go +++ b/internal/sql/ast/range_function.go @@ -25,7 +25,7 @@ func (n *RangeFunction) Format(buf *TrackedBuffer) { buf.WriteString(" WITH ORDINALITY") } if n.Alias != nil { - buf.WriteString(" ") + buf.WriteString(" AS ") buf.astFormat(n.Alias) } } diff --git a/internal/sql/ast/range_subselect.go b/internal/sql/ast/range_subselect.go index 1506ee7994..a5d63235d3 100644 --- a/internal/sql/ast/range_subselect.go +++ b/internal/sql/ast/range_subselect.go @@ -14,11 +14,14 @@ func (n *RangeSubselect) Format(buf *TrackedBuffer) { if n == nil { return } + if n.Lateral { + buf.WriteString("LATERAL ") + } buf.WriteString("(") buf.astFormat(n.Subquery) buf.WriteString(")") if n.Alias != nil { - buf.WriteString(" ") + buf.WriteString(" AS ") buf.astFormat(n.Alias) } } diff --git a/internal/sql/ast/range_var.go b/internal/sql/ast/range_var.go index 1d1656f6c0..cd3beb45d2 100644 --- a/internal/sql/ast/range_var.go +++ b/internal/sql/ast/range_var.go @@ -19,18 +19,11 @@ func (n *RangeVar) Format(buf *TrackedBuffer) { return } if n.Schemaname != nil { - buf.WriteString(*n.Schemaname) + buf.WriteString(quoteIdent(*n.Schemaname)) buf.WriteString(".") } if n.Relname != nil { - // TODO: What names need to be quoted - if *n.Relname == "user" { - buf.WriteString(`"`) - buf.WriteString(*n.Relname) - buf.WriteString(`"`) - } else { - buf.WriteString(*n.Relname) - } + buf.WriteString(quoteIdent(*n.Relname)) } if n.Alias != nil { buf.WriteString(" ") diff --git a/internal/sql/ast/res_target.go b/internal/sql/ast/res_target.go index 4ee2e72112..25f5ce41ee 100644 --- a/internal/sql/ast/res_target.go +++ b/internal/sql/ast/res_target.go @@ -19,11 +19,11 @@ func (n *ResTarget) Format(buf *TrackedBuffer) { buf.astFormat(n.Val) if n.Name != nil { buf.WriteString(" AS ") - buf.WriteString(*n.Name) + buf.WriteString(quoteIdent(*n.Name)) } } else { if n.Name != nil { - buf.WriteString(*n.Name) + buf.WriteString(quoteIdent(*n.Name)) } } } diff --git a/internal/sql/ast/sort_by.go b/internal/sql/ast/sort_by.go index 21a7a079aa..6d43f541a1 100644 --- a/internal/sql/ast/sort_by.go +++ b/internal/sql/ast/sort_by.go @@ -23,4 +23,10 @@ func (n *SortBy) Format(buf *TrackedBuffer) { case SortByDirDesc: buf.WriteString(" DESC") } + switch n.SortbyNulls { + case SortByNullsFirst: + buf.WriteString(" NULLS FIRST") + case SortByNullsLast: + buf.WriteString(" NULLS LAST") + } } diff --git a/internal/sql/ast/type_name.go b/internal/sql/ast/type_name.go index e26404b3ba..4d169bc8a5 100644 --- a/internal/sql/ast/type_name.go +++ b/internal/sql/ast/type_name.go @@ -20,18 +20,75 @@ func (n *TypeName) Pos() int { return n.Location } +// mapTypeName converts internal PostgreSQL type names to their SQL equivalents +func mapTypeName(name string) string { + switch name { + case "int4": + return "integer" + case "int8": + return "bigint" + case "int2": + return "smallint" + case "float4": + return "real" + case "float8": + return "double precision" + case "bool": + return "boolean" + case "bpchar": + return "character" + case "timestamptz": + return "timestamp with time zone" + case "timetz": + return "time with time zone" + default: + return name + } +} + func (n *TypeName) Format(buf *TrackedBuffer) { if n == nil { return } if items(n.Names) { + // Check if this is a pg_catalog type that should be expanded + if len(n.Names.Items) == 2 { + first, _ := n.Names.Items[0].(*String) + second, _ := n.Names.Items[1].(*String) + if first != nil && second != nil && first.Str == "pg_catalog" { + // pg_catalog.timestamptz -> timestamp with time zone + // pg_catalog.timetz -> time with time zone + // etc. + buf.WriteString(mapTypeName(second.Str)) + goto addMods + } + } + // For single name types, just output as-is (don't expand) + if len(n.Names.Items) == 1 { + if s, ok := n.Names.Items[0].(*String); ok { + buf.WriteString(s.Str) + goto addMods + } + } buf.join(n.Names, ".") + } else if n.Schema == "pg_catalog" { + // pg_catalog.typename -> expanded form (via Schema/Name fields) + buf.WriteString(mapTypeName(n.Name)) + } else if n.Schema != "" { + // schema.typename + buf.WriteString(n.Schema) + buf.WriteString(".") + buf.WriteString(n.Name) } else { - if n.Name == "int4" { - buf.WriteString("INTEGER") - } else { - buf.WriteString(n.Name) - } + // Simple type name - don't expand aliases + buf.WriteString(n.Name) + } +addMods: + // Add type modifiers (e.g., varchar(255)) + if items(n.Typmods) { + buf.WriteString("(") + buf.join(n.Typmods, ", ") + buf.WriteString(")") } if items(n.ArrayBounds) { buf.WriteString("[]") diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index efd496ad75..1617023864 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -79,7 +79,13 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { switch nn := item.(type) { case *ResTarget: if nn.Name != nil { - buf.WriteString(*nn.Name) + buf.WriteString(quoteIdent(*nn.Name)) + } + // Handle array subscript indirection (e.g., names[$1]) + if items(nn.Indirection) { + for _, ind := range nn.Indirection.Items { + buf.astFormat(ind) + } } buf.WriteString(" = ") buf.astFormat(nn.Val) From 0e7fa5a50cef6d0b15534109800666371f100208 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 13:05:07 -0800 Subject: [PATCH 6/9] refactor(postgresql): use existing convert functions instead of translate helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace custom translate functions (translateTypeNameFromPG, translateOptions, translateNode, translateDefElem) with existing convert.go functions (convertTypeName, convertSlice) to maintain architectural consistency. Both parse.go and convert.go import the same pg_query_go/v6 package, so the types are compatible and the existing convert functions can be used directly. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/engine/postgresql/parse.go | 92 +---------------------------- 1 file changed, 2 insertions(+), 90 deletions(-) diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index ea1648dc5e..ac7d0bcbfa 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -34,94 +34,6 @@ func stringSliceFromNodes(s []*nodes.Node) []string { return items } -func translateNode(node *nodes.Node) ast.Node { - if node == nil { - return nil - } - switch n := node.Node.(type) { - case *nodes.Node_String_: - return &ast.String{Str: n.String_.Sval} - case *nodes.Node_Integer: - return &ast.Integer{Ival: int64(n.Integer.Ival)} - case *nodes.Node_Boolean: - return &ast.Boolean{Boolval: n.Boolean.Boolval} - case *nodes.Node_AConst: - // A_Const contains a constant value (used in type modifiers like varchar(32)) - if n.AConst.GetIval() != nil { - return &ast.Integer{Ival: int64(n.AConst.GetIval().Ival)} - } - if n.AConst.GetSval() != nil { - return &ast.String{Str: n.AConst.GetSval().Sval} - } - if n.AConst.GetFval() != nil { - return &ast.Float{Str: n.AConst.GetFval().Fval} - } - if n.AConst.GetBoolval() != nil { - return &ast.Boolean{Boolval: n.AConst.GetBoolval().Boolval} - } - return &ast.TODO{} - case *nodes.Node_List: - list := &ast.List{} - for _, item := range n.List.Items { - list.Items = append(list.Items, translateNode(item)) - } - return list - default: - return &ast.TODO{} - } -} - -func translateDefElem(n *nodes.DefElem) *ast.DefElem { - if n == nil { - return nil - } - defname := n.Defname - return &ast.DefElem{ - Defname: &defname, - Arg: translateNode(n.Arg), - Location: int(n.Location), - } -} - -func translateOptions(opts []*nodes.Node) *ast.List { - if opts == nil { - return nil - } - list := &ast.List{} - for _, opt := range opts { - if de, ok := opt.Node.(*nodes.Node_DefElem); ok { - list.Items = append(list.Items, translateDefElem(de.DefElem)) - } - } - return list -} - -func translateTypeNameFromPG(tn *nodes.TypeName) *ast.TypeName { - if tn == nil { - return nil - } - rel, err := parseRelationFromNodes(tn.Names) - if err != nil { - return nil - } - result := rel.TypeName() - // Preserve array bounds - if len(tn.ArrayBounds) > 0 { - result.ArrayBounds = &ast.List{} - for _, ab := range tn.ArrayBounds { - result.ArrayBounds.Items = append(result.ArrayBounds.Items, translateNode(ab)) - } - } - // Preserve type modifiers - if len(tn.Typmods) > 0 { - result.Typmods = &ast.List{} - for _, tm := range tn.Typmods { - result.Typmods.Items = append(result.Typmods.Items, translateNode(tm)) - } - } - return result -} - type relation struct { Catalog string Schema string @@ -528,7 +440,7 @@ func translate(node *nodes.Node) (ast.Node, error) { create.Cols = append(create.Cols, &ast.ColumnDef{ Colname: item.ColumnDef.Colname, - TypeName: translateTypeNameFromPG(item.ColumnDef.TypeName), + TypeName: convertTypeName(item.ColumnDef.TypeName), IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname], IsArray: isArray(item.ColumnDef.TypeName), ArrayDims: len(item.ColumnDef.TypeName.ArrayBounds), @@ -577,7 +489,7 @@ func translate(node *nodes.Node) (ast.Node, error) { ReturnType: rt, Replace: n.Replace, Params: &ast.List{}, - Options: translateOptions(n.Options), + Options: convertSlice(n.Options), } for _, item := range n.Parameters { arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter From dc44702bed5dc3919b8606a214bb54fbf9375c33 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 13:20:10 -0800 Subject: [PATCH 7/9] refactor(format): add Formatter interface for SQL dialect-specific quoting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create internal/sql/format package with Formatter interface - Add QuoteIdent method to TrackedBuffer that delegates to Formatter - Implement QuoteIdent on postgresql.Parser using existing IsReservedKeyword - Update all Format() methods to use buf.QuoteIdent() instead of local quoteIdent() - Remove duplicate reserved word logic from ast/column_ref.go - Update ast.Format() to accept a Formatter parameter This allows each SQL dialect to provide its own identifier quoting logic based on its reserved keywords and quoting rules. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/endtoend/fmt_test.go | 2 +- internal/engine/postgresql/reserved.go | 20 ++++++++++ internal/sql/ast/column_ref.go | 53 +------------------------- internal/sql/ast/print.go | 28 ++++++++++---- internal/sql/ast/range_var.go | 4 +- internal/sql/ast/res_target.go | 4 +- internal/sql/ast/update_stmt.go | 2 +- internal/sql/format/format.go | 8 ++++ 8 files changed, 55 insertions(+), 66 deletions(-) create mode 100644 internal/sql/format/format.go diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 53cc7403fc..35b475ca4f 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -109,7 +109,7 @@ func TestFormat(t *testing.T) { debug.Dump(r, err) } - out := ast.Format(stmt.Raw) + out := ast.Format(stmt.Raw, parse) actual, err := postgresql.Fingerprint(out) if err != nil { t.Error(err) diff --git a/internal/engine/postgresql/reserved.go b/internal/engine/postgresql/reserved.go index 8f796ffa19..2997e87afe 100644 --- a/internal/engine/postgresql/reserved.go +++ b/internal/engine/postgresql/reserved.go @@ -2,6 +2,26 @@ package postgresql import "strings" +// hasMixedCase returns true if the string has any uppercase letters +// (identifiers with mixed case need quoting in PostgreSQL) +func hasMixedCase(s string) bool { + for _, r := range s { + if r >= 'A' && r <= 'Z' { + return true + } + } + return false +} + +// QuoteIdent returns a quoted identifier if it needs quoting. +// This implements the format.Formatter interface. +func (p *Parser) QuoteIdent(s string) string { + if p.IsReservedKeyword(s) || hasMixedCase(s) { + return `"` + s + `"` + } + return s +} + // https://www.postgresql.org/docs/current/sql-keywords-appendix.html func (p *Parser) IsReservedKeyword(s string) bool { switch strings.ToLower(s) { diff --git a/internal/sql/ast/column_ref.go b/internal/sql/ast/column_ref.go index a43641e7ae..97ea3ab20a 100644 --- a/internal/sql/ast/column_ref.go +++ b/internal/sql/ast/column_ref.go @@ -2,57 +2,6 @@ package ast import "strings" -// sqlReservedWords is a set of SQL keywords that must be quoted when used as identifiers -var sqlReservedWords = map[string]bool{ - "all": true, "analyse": true, "analyze": true, "and": true, "any": true, - "array": true, "as": true, "asc": true, "asymmetric": true, "authorization": true, - "between": true, "binary": true, "both": true, "case": true, "cast": true, - "check": true, "collate": true, "collation": true, "column": true, "concurrently": true, - "constraint": true, "create": true, "cross": true, "current_catalog": true, - "current_date": true, "current_role": true, "current_schema": true, - "current_time": true, "current_timestamp": true, "current_user": true, - "default": true, "deferrable": true, "desc": true, "distinct": true, "do": true, - "else": true, "end": true, "except": true, "false": true, "fetch": true, - "for": true, "foreign": true, "freeze": true, "from": true, "full": true, - "grant": true, "group": true, "having": true, "ilike": true, "in": true, - "initially": true, "inner": true, "intersect": true, "into": true, "is": true, - "isnull": true, "join": true, "lateral": true, "leading": true, "left": true, - "like": true, "limit": true, "localtime": true, "localtimestamp": true, - "natural": true, "not": true, "notnull": true, "null": true, "offset": true, - "on": true, "only": true, "or": true, "order": true, "outer": true, - "overlaps": true, "placing": true, "primary": true, "references": true, - "returning": true, "right": true, "select": true, "session_user": true, - "similar": true, "some": true, "symmetric": true, "table": true, "tablesample": true, - "then": true, "to": true, "trailing": true, "true": true, "union": true, - "unique": true, "user": true, "using": true, "variadic": true, "verbose": true, - "when": true, "where": true, "window": true, "with": true, -} - -// needsQuoting returns true if the identifier is a SQL reserved word -// that needs to be quoted when used as an identifier -func needsQuoting(s string) bool { - return sqlReservedWords[strings.ToLower(s)] -} - -// hasMixedCase returns true if the string has any uppercase letters -// (identifiers with mixed case need quoting in PostgreSQL) -func hasMixedCase(s string) bool { - for _, r := range s { - if r >= 'A' && r <= 'Z' { - return true - } - } - return false -} - -// quoteIdent returns a quoted identifier if it needs quoting -func quoteIdent(s string) string { - if needsQuoting(s) || hasMixedCase(s) { - return `"` + s + `"` - } - return s -} - type ColumnRef struct { Name string @@ -75,7 +24,7 @@ func (n *ColumnRef) Format(buf *TrackedBuffer) { for _, item := range n.Fields.Items { switch nn := item.(type) { case *String: - items = append(items, quoteIdent(nn.Str)) + items = append(items, buf.QuoteIdent(nn.Str)) case *A_Star: items = append(items, "*") } diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go index 867a53a177..aa717baad6 100644 --- a/internal/sql/ast/print.go +++ b/internal/sql/ast/print.go @@ -4,26 +4,38 @@ import ( "strings" "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/sql/format" ) -type formatter interface { +type nodeFormatter interface { Format(*TrackedBuffer) } type TrackedBuffer struct { *strings.Builder + formatter format.Formatter } -// NewTrackedBuffer creates a new TrackedBuffer. -func NewTrackedBuffer() *TrackedBuffer { +// NewTrackedBuffer creates a new TrackedBuffer with the given formatter. +func NewTrackedBuffer(f format.Formatter) *TrackedBuffer { buf := &TrackedBuffer{ - Builder: new(strings.Builder), + Builder: new(strings.Builder), + formatter: f, } return buf } +// QuoteIdent returns a quoted identifier if it needs quoting. +// If no formatter is set, it returns the identifier unchanged. +func (t *TrackedBuffer) QuoteIdent(s string) string { + if t.formatter != nil { + return t.formatter.QuoteIdent(s) + } + return s +} + func (t *TrackedBuffer) astFormat(n Node) { - if ft, ok := n.(formatter); ok { + if ft, ok := n.(nodeFormatter); ok { ft.Format(t) } else { debug.Dump(n) @@ -45,9 +57,9 @@ func (t *TrackedBuffer) join(n *List, sep string) { } } -func Format(n Node) string { - tb := NewTrackedBuffer() - if ft, ok := n.(formatter); ok { +func Format(n Node, f format.Formatter) string { + tb := NewTrackedBuffer(f) + if ft, ok := n.(nodeFormatter); ok { ft.Format(tb) } return tb.String() diff --git a/internal/sql/ast/range_var.go b/internal/sql/ast/range_var.go index cd3beb45d2..b7fb316ee9 100644 --- a/internal/sql/ast/range_var.go +++ b/internal/sql/ast/range_var.go @@ -19,11 +19,11 @@ func (n *RangeVar) Format(buf *TrackedBuffer) { return } if n.Schemaname != nil { - buf.WriteString(quoteIdent(*n.Schemaname)) + buf.WriteString(buf.QuoteIdent(*n.Schemaname)) buf.WriteString(".") } if n.Relname != nil { - buf.WriteString(quoteIdent(*n.Relname)) + buf.WriteString(buf.QuoteIdent(*n.Relname)) } if n.Alias != nil { buf.WriteString(" ") diff --git a/internal/sql/ast/res_target.go b/internal/sql/ast/res_target.go index 25f5ce41ee..b652c2293e 100644 --- a/internal/sql/ast/res_target.go +++ b/internal/sql/ast/res_target.go @@ -19,11 +19,11 @@ func (n *ResTarget) Format(buf *TrackedBuffer) { buf.astFormat(n.Val) if n.Name != nil { buf.WriteString(" AS ") - buf.WriteString(quoteIdent(*n.Name)) + buf.WriteString(buf.QuoteIdent(*n.Name)) } } else { if n.Name != nil { - buf.WriteString(quoteIdent(*n.Name)) + buf.WriteString(buf.QuoteIdent(*n.Name)) } } } diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index 1617023864..c98d422130 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -79,7 +79,7 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { switch nn := item.(type) { case *ResTarget: if nn.Name != nil { - buf.WriteString(quoteIdent(*nn.Name)) + buf.WriteString(buf.QuoteIdent(*nn.Name)) } // Handle array subscript indirection (e.g., names[$1]) if items(nn.Indirection) { diff --git a/internal/sql/format/format.go b/internal/sql/format/format.go new file mode 100644 index 0000000000..be9b044aca --- /dev/null +++ b/internal/sql/format/format.go @@ -0,0 +1,8 @@ +package format + +// Formatter provides SQL dialect-specific formatting behavior +type Formatter interface { + // QuoteIdent returns a quoted identifier if it needs quoting + // (e.g., reserved words, mixed case identifiers) + QuoteIdent(s string) string +} From 8194d945b7191e0309364bb4b94a78fb5e3221e4 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 13:25:11 -0800 Subject: [PATCH 8/9] refactor(format): add TypeName method to Formatter interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add TypeName(ns, name string) string method to Formatter interface - Implement TypeName on postgresql.Parser with pg_catalog type mappings - Add TypeName method to TrackedBuffer that delegates to Formatter - Update ast.TypeName.Format to use buf.TypeName() - Remove mapTypeName from ast package (moved to postgresql package) This allows each SQL dialect to provide its own type name mappings (e.g., pg_catalog.int4 -> integer for PostgreSQL). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/engine/postgresql/reserved.go | 33 +++++++++++++++++ internal/sql/ast/print.go | 12 +++++++ internal/sql/ast/type_name.go | 50 ++++---------------------- internal/sql/format/format.go | 4 +++ 4 files changed, 55 insertions(+), 44 deletions(-) diff --git a/internal/engine/postgresql/reserved.go b/internal/engine/postgresql/reserved.go index 2997e87afe..0be5c54b8d 100644 --- a/internal/engine/postgresql/reserved.go +++ b/internal/engine/postgresql/reserved.go @@ -22,6 +22,39 @@ func (p *Parser) QuoteIdent(s string) string { return s } +// TypeName returns the SQL type name for the given namespace and name. +// This implements the format.Formatter interface. +func (p *Parser) TypeName(ns, name string) string { + if ns == "pg_catalog" { + switch name { + case "int4": + return "integer" + case "int8": + return "bigint" + case "int2": + return "smallint" + case "float4": + return "real" + case "float8": + return "double precision" + case "bool": + return "boolean" + case "bpchar": + return "character" + case "timestamptz": + return "timestamp with time zone" + case "timetz": + return "time with time zone" + default: + return name + } + } + if ns != "" { + return ns + "." + name + } + return name +} + // https://www.postgresql.org/docs/current/sql-keywords-appendix.html func (p *Parser) IsReservedKeyword(s string) bool { switch strings.ToLower(s) { diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go index aa717baad6..8db19ba7d1 100644 --- a/internal/sql/ast/print.go +++ b/internal/sql/ast/print.go @@ -34,6 +34,18 @@ func (t *TrackedBuffer) QuoteIdent(s string) string { return s } +// TypeName returns the SQL type name for the given namespace and name. +// If no formatter is set, it returns "ns.name" or just "name". +func (t *TrackedBuffer) TypeName(ns, name string) string { + if t.formatter != nil { + return t.formatter.TypeName(ns, name) + } + if ns != "" { + return ns + "." + name + } + return name +} + func (t *TrackedBuffer) astFormat(n Node) { if ft, ok := n.(nodeFormatter); ok { ft.Format(t) diff --git a/internal/sql/ast/type_name.go b/internal/sql/ast/type_name.go index 4d169bc8a5..5979d7a90d 100644 --- a/internal/sql/ast/type_name.go +++ b/internal/sql/ast/type_name.go @@ -20,68 +20,30 @@ func (n *TypeName) Pos() int { return n.Location } -// mapTypeName converts internal PostgreSQL type names to their SQL equivalents -func mapTypeName(name string) string { - switch name { - case "int4": - return "integer" - case "int8": - return "bigint" - case "int2": - return "smallint" - case "float4": - return "real" - case "float8": - return "double precision" - case "bool": - return "boolean" - case "bpchar": - return "character" - case "timestamptz": - return "timestamp with time zone" - case "timetz": - return "time with time zone" - default: - return name - } -} - func (n *TypeName) Format(buf *TrackedBuffer) { if n == nil { return } if items(n.Names) { - // Check if this is a pg_catalog type that should be expanded + // Check if this is a qualified type (e.g., pg_catalog.int4) if len(n.Names.Items) == 2 { first, _ := n.Names.Items[0].(*String) second, _ := n.Names.Items[1].(*String) - if first != nil && second != nil && first.Str == "pg_catalog" { - // pg_catalog.timestamptz -> timestamp with time zone - // pg_catalog.timetz -> time with time zone - // etc. - buf.WriteString(mapTypeName(second.Str)) + if first != nil && second != nil { + buf.WriteString(buf.TypeName(first.Str, second.Str)) goto addMods } } - // For single name types, just output as-is (don't expand) + // For single name types, just output as-is if len(n.Names.Items) == 1 { if s, ok := n.Names.Items[0].(*String); ok { - buf.WriteString(s.Str) + buf.WriteString(buf.TypeName("", s.Str)) goto addMods } } buf.join(n.Names, ".") - } else if n.Schema == "pg_catalog" { - // pg_catalog.typename -> expanded form (via Schema/Name fields) - buf.WriteString(mapTypeName(n.Name)) - } else if n.Schema != "" { - // schema.typename - buf.WriteString(n.Schema) - buf.WriteString(".") - buf.WriteString(n.Name) } else { - // Simple type name - don't expand aliases - buf.WriteString(n.Name) + buf.WriteString(buf.TypeName(n.Schema, n.Name)) } addMods: // Add type modifiers (e.g., varchar(255)) diff --git a/internal/sql/format/format.go b/internal/sql/format/format.go index be9b044aca..f47587dd0b 100644 --- a/internal/sql/format/format.go +++ b/internal/sql/format/format.go @@ -5,4 +5,8 @@ type Formatter interface { // QuoteIdent returns a quoted identifier if it needs quoting // (e.g., reserved words, mixed case identifiers) QuoteIdent(s string) string + + // TypeName returns the SQL type name for the given namespace and name. + // This handles dialect-specific type name mappings (e.g., pg_catalog.int4 -> integer) + TypeName(ns, name string) string } From 7a49c3a139a4068bd87b152cfea5512973e8af73 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 13:52:10 -0800 Subject: [PATCH 9/9] fix(postgresql): restore parseRelationFromNodes for column type resolution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The convertTypeName function populates extra fields (Names, ArrayBounds, Typmods) on the TypeName struct which breaks the catalog's type equality check used for ALTER TYPE RENAME operations. This change: - Reverts to using parseRelationFromNodes + rel.TypeName() which only populates Catalog, Schema, Name fields needed for type resolution - Updates ColumnDef.Format to use IsArray field for array formatting since TypeName.ArrayBounds is no longer set 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/engine/postgresql/parse.go | 7 ++++++- internal/sql/ast/column_def.go | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index ac7d0bcbfa..0c6b3a0fc2 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -431,6 +431,11 @@ func translate(node *nodes.Node) (ast.Node, error) { for _, elt := range n.TableElts { switch item := elt.Node.(type) { case *nodes.Node_ColumnDef: + rel, err := parseRelationFromNodes(item.ColumnDef.TypeName.Names) + if err != nil { + return nil, err + } + primary := false for _, con := range item.ColumnDef.Constraints { if constraint, ok := con.Node.(*nodes.Node_Constraint); ok { @@ -440,7 +445,7 @@ func translate(node *nodes.Node) (ast.Node, error) { create.Cols = append(create.Cols, &ast.ColumnDef{ Colname: item.ColumnDef.Colname, - TypeName: convertTypeName(item.ColumnDef.TypeName), + TypeName: rel.TypeName(), IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname], IsArray: isArray(item.ColumnDef.TypeName), ArrayDims: len(item.ColumnDef.TypeName.ArrayBounds), diff --git a/internal/sql/ast/column_def.go b/internal/sql/ast/column_def.go index f9504eefc7..cd8ba115fc 100644 --- a/internal/sql/ast/column_def.go +++ b/internal/sql/ast/column_def.go @@ -39,6 +39,11 @@ func (n *ColumnDef) Format(buf *TrackedBuffer) { buf.WriteString(n.Colname) buf.WriteString(" ") buf.astFormat(n.TypeName) + // Use IsArray from ColumnDef since TypeName.ArrayBounds may not be set + // (for type resolution compatibility) + if n.IsArray && !items(n.TypeName.ArrayBounds) { + buf.WriteString("[]") + } if n.PrimaryKey { buf.WriteString(" PRIMARY KEY") } else if n.IsNotNull {