From 5c03b471b71fafbe7f86db0418f69e461422f3a2 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 17:45:09 -0800 Subject: [PATCH] refactor(ast): rename Formatter interface to Dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Renames the Formatter interface to Dialect and refactors Format methods to accept the dialect as a parameter instead of storing it in TrackedBuffer. **Interface Changes:** - Rename `format.Formatter` to `format.Dialect` - Change Format signature from `Format(buf *TrackedBuffer)` to `Format(buf *TrackedBuffer, d format.Dialect)` **TrackedBuffer Simplification:** - Remove `formatter` field from TrackedBuffer struct - TrackedBuffer is now a simple strings.Builder wrapper - NewTrackedBuffer() no longer takes a dialect parameter **Method Call Updates:** - `buf.astFormat(x)` → `buf.astFormat(x, d)` - `buf.join(x, sep)` → `buf.join(x, d, sep)` - Helper methods now called directly on dialect: - `buf.QuoteIdent(x)` → `d.QuoteIdent(x)` - `buf.TypeName(ns, name)` → `d.TypeName(ns, name)` - `buf.Param(n)` → `d.Param(n)` - `buf.Cast(arg, t)` → `d.Cast(arg, t)` - `buf.NamedParam(name)` → `d.NamedParam(name)` This change makes the dialect dependency explicit in Format methods and simplifies TrackedBuffer to be purely a buffer utility. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/endtoend/fmt_test.go | 2 +- internal/engine/postgresql/reserved.go | 4 +- internal/sql/ast/CLAUDE.md | 7 +- internal/sql/ast/a_array_expr.go | 6 +- internal/sql/ast/a_const.go | 8 ++- internal/sql/ast/a_expr.go | 52 +++++++------- internal/sql/ast/a_indices.go | 10 +-- internal/sql/ast/a_star.go | 4 +- internal/sql/ast/alias.go | 6 +- internal/sql/ast/alter_table_cmd.go | 6 +- internal/sql/ast/alter_table_stmt.go | 10 +-- internal/sql/ast/between_expr.go | 10 +-- internal/sql/ast/bool_expr.go | 14 ++-- internal/sql/ast/boolean.go | 8 ++- internal/sql/ast/call_stmt.go | 6 +- internal/sql/ast/case_expr.go | 10 +-- internal/sql/ast/case_when.go | 8 ++- internal/sql/ast/coalesce_expr.go | 6 +- internal/sql/ast/collate_expr.go | 8 ++- internal/sql/ast/column_def.go | 8 ++- internal/sql/ast/column_ref.go | 10 ++- internal/sql/ast/common_table_expr.go | 8 ++- internal/sql/ast/create_extension_stmt.go | 4 +- internal/sql/ast/create_function_stmt.go | 12 ++-- internal/sql/ast/create_table_stmt.go | 8 ++- internal/sql/ast/def_elem.go | 14 ++-- internal/sql/ast/delete_stmt.go | 22 +++--- internal/sql/ast/do_stmt.go | 4 +- internal/sql/ast/float.go | 4 +- internal/sql/ast/func_call.go | 16 +++-- internal/sql/ast/func_name.go | 4 +- internal/sql/ast/func_param.go | 6 +- internal/sql/ast/in.go | 10 +-- internal/sql/ast/index_elem.go | 6 +- internal/sql/ast/infer_clause.go | 8 ++- internal/sql/ast/insert_stmt.go | 36 +++++----- internal/sql/ast/integer.go | 8 ++- internal/sql/ast/interval_expr.go | 6 +- internal/sql/ast/join_expr.go | 12 ++-- internal/sql/ast/list.go | 6 +- internal/sql/ast/listen_stmt.go | 4 +- internal/sql/ast/locking_clause.go | 6 +- internal/sql/ast/multi_assign_ref.go | 6 +- internal/sql/ast/named_arg_expr.go | 6 +- internal/sql/ast/notify_stmt.go | 4 +- internal/sql/ast/null.go | 4 +- internal/sql/ast/null_test_expr.go | 6 +- internal/sql/ast/on_conflict_clause.go | 12 ++-- internal/sql/ast/on_duplicate_key_update.go | 8 ++- internal/sql/ast/param_ref.go | 6 +- internal/sql/ast/paren_expr.go | 6 +- internal/sql/ast/print.go | 76 ++++----------------- internal/sql/ast/range_function.go | 8 ++- internal/sql/ast/range_subselect.go | 8 ++- internal/sql/ast/range_var.go | 10 +-- internal/sql/ast/raw_stmt.go | 6 +- internal/sql/ast/refresh_mat_view_stmt.go | 6 +- internal/sql/ast/res_target.go | 10 +-- internal/sql/ast/row_expr.go | 10 +-- internal/sql/ast/scalar_array_op_expr.go | 8 ++- internal/sql/ast/select_stmt.go | 32 +++++---- internal/sql/ast/sort_by.go | 6 +- internal/sql/ast/sql_value_function.go | 4 +- internal/sql/ast/string.go | 4 +- internal/sql/ast/sub_link.go | 8 ++- internal/sql/ast/table_name.go | 4 +- internal/sql/ast/truncate_stmt.go | 6 +- internal/sql/ast/type_cast.go | 14 ++-- internal/sql/ast/type_name.go | 14 ++-- internal/sql/ast/typedefs.go | 6 +- internal/sql/ast/update_stmt.go | 30 ++++---- internal/sql/ast/variable_expr.go | 4 +- internal/sql/ast/window_def.go | 22 +++--- internal/sql/ast/with_clause.go | 6 +- internal/sql/format/format.go | 4 +- internal/sql/rewrite/CLAUDE.md | 2 +- 76 files changed, 435 insertions(+), 338 deletions(-) diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 550033de49..eac3fa0390 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -25,7 +25,7 @@ type sqlParser interface { // sqlFormatter is an interface for formatters type sqlFormatter interface { - format.Formatter + format.Dialect } func TestFormat(t *testing.T) { diff --git a/internal/engine/postgresql/reserved.go b/internal/engine/postgresql/reserved.go index b9ccc76d30..b03a6a7e9f 100644 --- a/internal/engine/postgresql/reserved.go +++ b/internal/engine/postgresql/reserved.go @@ -17,7 +17,7 @@ func hasMixedCase(s string) bool { } // QuoteIdent returns a quoted identifier if it needs quoting. -// This implements the format.Formatter interface. +// This implements the format.Dialect interface. func (p *Parser) QuoteIdent(s string) string { if p.IsReservedKeyword(s) || hasMixedCase(s) { return `"` + s + `"` @@ -26,7 +26,7 @@ func (p *Parser) QuoteIdent(s string) string { } // TypeName returns the SQL type name for the given namespace and name. -// This implements the format.Formatter interface. +// This implements the format.Dialect interface. func (p *Parser) TypeName(ns, name string) string { if ns == "pg_catalog" { switch name { diff --git a/internal/sql/ast/CLAUDE.md b/internal/sql/ast/CLAUDE.md index c55f1340ee..e769fbfca6 100644 --- a/internal/sql/ast/CLAUDE.md +++ b/internal/sql/ast/CLAUDE.md @@ -17,13 +17,14 @@ The `TrackedBuffer` type (`pg_query.go`) handles SQL formatting with dialect-spe - `QuoteIdent(name string)` - quotes identifiers (dialect-specific) - `TypeName(ns, name string)` - formats type names (dialect-specific) -### Formatter Interface -Dialect-specific formatting is handled via the `Formatter` interface: +### Dialect Interface +Dialect-specific formatting is handled via the `Dialect` interface: ```go -type Formatter interface { +type Dialect interface { QuoteIdent(string) string TypeName(ns, name string) string Param(int) string // $1 for PostgreSQL, ? for MySQL + NamedParam(string) string // @name for PostgreSQL, :name for SQLite Cast(string) string } ``` diff --git a/internal/sql/ast/a_array_expr.go b/internal/sql/ast/a_array_expr.go index 970e95deb1..0437dac84f 100644 --- a/internal/sql/ast/a_array_expr.go +++ b/internal/sql/ast/a_array_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_ArrayExpr struct { Elements *List Location int @@ -9,11 +11,11 @@ func (n *A_ArrayExpr) Pos() int { return n.Location } -func (n *A_ArrayExpr) Format(buf *TrackedBuffer) { +func (n *A_ArrayExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("ARRAY[") - buf.join(n.Elements, ", ") + buf.join(n.Elements, d, ", ") buf.WriteString("]") } diff --git a/internal/sql/ast/a_const.go b/internal/sql/ast/a_const.go index ec1d780945..a6b610e349 100644 --- a/internal/sql/ast/a_const.go +++ b/internal/sql/ast/a_const.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Const struct { Val Node Location int @@ -9,15 +11,15 @@ func (n *A_Const) Pos() int { return n.Location } -func (n *A_Const) Format(buf *TrackedBuffer) { +func (n *A_Const) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if _, ok := n.Val.(*String); ok { buf.WriteString("'") - buf.astFormat(n.Val) + buf.astFormat(n.Val, d) buf.WriteString("'") } else { - buf.astFormat(n.Val) + buf.astFormat(n.Val, d) } } diff --git a/internal/sql/ast/a_expr.go b/internal/sql/ast/a_expr.go index fc795a77ce..4e67967baa 100644 --- a/internal/sql/ast/a_expr.go +++ b/internal/sql/ast/a_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Expr struct { Kind A_Expr_Kind Name *List @@ -31,75 +33,75 @@ func (n *A_Expr) isNamedParam() (string, bool) { return "", false } -func (n *A_Expr) Format(buf *TrackedBuffer) { +func (n *A_Expr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } // Check for named parameter first (works regardless of Kind) if name, ok := n.isNamedParam(); ok { - buf.WriteString(buf.NamedParam(name)) + buf.WriteString(d.NamedParam(name)) return } switch n.Kind { case A_Expr_Kind_IN: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" IN (") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) buf.WriteString(")") case A_Expr_Kind_LIKE: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" LIKE ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) case A_Expr_Kind_ILIKE: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" ILIKE ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) case A_Expr_Kind_SIMILAR: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" SIMILAR TO ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) case A_Expr_Kind_BETWEEN: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" BETWEEN ") if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { - buf.astFormat(l.Items[0]) + buf.astFormat(l.Items[0], d) buf.WriteString(" AND ") - buf.astFormat(l.Items[1]) + buf.astFormat(l.Items[1], d) } case A_Expr_Kind_NOT_BETWEEN: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" NOT BETWEEN ") if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { - buf.astFormat(l.Items[0]) + buf.astFormat(l.Items[0], d) buf.WriteString(" AND ") - buf.astFormat(l.Items[1]) + buf.astFormat(l.Items[1], d) } case A_Expr_Kind_DISTINCT: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" IS DISTINCT FROM ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) case A_Expr_Kind_NOT_DISTINCT: - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" IS NOT DISTINCT FROM ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) case A_Expr_Kind_NULLIF: buf.WriteString("NULLIF(") - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(", ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) buf.WriteString(")") default: // Standard operator (including A_Expr_Kind_OP) if set(n.Lexpr) { - buf.astFormat(n.Lexpr) + buf.astFormat(n.Lexpr, d) buf.WriteString(" ") } - buf.astFormat(n.Name) + buf.astFormat(n.Name, d) if set(n.Rexpr) { buf.WriteString(" ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) } } } diff --git a/internal/sql/ast/a_indices.go b/internal/sql/ast/a_indices.go index a143ae6d05..7180f220e7 100644 --- a/internal/sql/ast/a_indices.go +++ b/internal/sql/ast/a_indices.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Indices struct { IsSlice bool Lidx Node @@ -10,21 +12,21 @@ func (n *A_Indices) Pos() int { return 0 } -func (n *A_Indices) Format(buf *TrackedBuffer) { +func (n *A_Indices) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("[") if n.IsSlice { if set(n.Lidx) { - buf.astFormat(n.Lidx) + buf.astFormat(n.Lidx, d) } buf.WriteString(":") if set(n.Uidx) { - buf.astFormat(n.Uidx) + buf.astFormat(n.Uidx, d) } } else { - buf.astFormat(n.Uidx) + buf.astFormat(n.Uidx, d) } buf.WriteString("]") } diff --git a/internal/sql/ast/a_star.go b/internal/sql/ast/a_star.go index a43b2ab5b7..7e5f07b96a 100644 --- a/internal/sql/ast/a_star.go +++ b/internal/sql/ast/a_star.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Star struct { } @@ -7,7 +9,7 @@ func (n *A_Star) Pos() int { return 0 } -func (n *A_Star) Format(buf *TrackedBuffer) { +func (n *A_Star) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/alias.go b/internal/sql/ast/alias.go index 55965b55c9..7123982305 100644 --- a/internal/sql/ast/alias.go +++ b/internal/sql/ast/alias.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type Alias struct { Aliasname *string Colnames *List @@ -9,7 +11,7 @@ func (n *Alias) Pos() int { return 0 } -func (n *Alias) Format(buf *TrackedBuffer) { +func (n *Alias) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -18,7 +20,7 @@ func (n *Alias) Format(buf *TrackedBuffer) { } if items(n.Colnames) { buf.WriteString("(") - buf.astFormat((n.Colnames)) + buf.astFormat(n.Colnames, d) buf.WriteString(")") } } diff --git a/internal/sql/ast/alter_table_cmd.go b/internal/sql/ast/alter_table_cmd.go index 80fad95eaf..90ffd891eb 100644 --- a/internal/sql/ast/alter_table_cmd.go +++ b/internal/sql/ast/alter_table_cmd.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + const ( AT_AddColumn AlterTableType = iota AT_AlterColumnType @@ -40,7 +42,7 @@ func (n *AlterTableCmd) Pos() int { return 0 } -func (n *AlterTableCmd) Format(buf *TrackedBuffer) { +func (n *AlterTableCmd) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -51,5 +53,5 @@ func (n *AlterTableCmd) Format(buf *TrackedBuffer) { buf.WriteString(" DROP COLUMN ") } - buf.astFormat(n.Def) + buf.astFormat(n.Def, d) } diff --git a/internal/sql/ast/alter_table_stmt.go b/internal/sql/ast/alter_table_stmt.go index 5d4a22f50e..4dc88707ff 100644 --- a/internal/sql/ast/alter_table_stmt.go +++ b/internal/sql/ast/alter_table_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type AlterTableStmt struct { // TODO: Only TableName or Relation should be defined Relation *RangeVar @@ -13,12 +15,12 @@ func (n *AlterTableStmt) Pos() int { return 0 } -func (n *AlterTableStmt) Format(buf *TrackedBuffer) { +func (n *AlterTableStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("ALTER TABLE ") - buf.astFormat(n.Relation) - buf.astFormat(n.Table) - buf.astFormat(n.Cmds) + buf.astFormat(n.Relation, d) + buf.astFormat(n.Table, d) + buf.astFormat(n.Cmds, d) } diff --git a/internal/sql/ast/between_expr.go b/internal/sql/ast/between_expr.go index aa18e6b82a..a160f1892c 100644 --- a/internal/sql/ast/between_expr.go +++ b/internal/sql/ast/between_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type BetweenExpr struct { // Expr is the value expression to be compared. Expr Node @@ -16,17 +18,17 @@ func (n *BetweenExpr) Pos() int { return n.Location } -func (n *BetweenExpr) Format(buf *TrackedBuffer) { +func (n *BetweenExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Expr) + buf.astFormat(n.Expr, d) if n.Not { buf.WriteString(" NOT BETWEEN ") } else { buf.WriteString(" BETWEEN ") } - buf.astFormat(n.Left) + buf.astFormat(n.Left, d) buf.WriteString(" AND ") - buf.astFormat(n.Right) + buf.astFormat(n.Right, d) } diff --git a/internal/sql/ast/bool_expr.go b/internal/sql/ast/bool_expr.go index 0241503a06..f2c0243a9c 100644 --- a/internal/sql/ast/bool_expr.go +++ b/internal/sql/ast/bool_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type BoolExpr struct { Xpr Node Boolop BoolExprType @@ -11,35 +13,35 @@ func (n *BoolExpr) Pos() int { return n.Location } -func (n *BoolExpr) Format(buf *TrackedBuffer) { +func (n *BoolExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } switch n.Boolop { case BoolExprTypeIsNull: if items(n.Args) && len(n.Args.Items) > 0 { - buf.astFormat(n.Args.Items[0]) + buf.astFormat(n.Args.Items[0], d) } buf.WriteString(" IS NULL") case BoolExprTypeIsNotNull: if items(n.Args) && len(n.Args.Items) > 0 { - buf.astFormat(n.Args.Items[0]) + buf.astFormat(n.Args.Items[0], d) } buf.WriteString(" IS NOT NULL") case BoolExprTypeNot: // NOT expression: format as NOT buf.WriteString("NOT ") if items(n.Args) && len(n.Args.Items) > 0 { - buf.astFormat(n.Args.Items[0]) + buf.astFormat(n.Args.Items[0], d) } default: buf.WriteString("(") if items(n.Args) { switch n.Boolop { case BoolExprTypeAnd: - buf.join(n.Args, " AND ") + buf.join(n.Args, d, " AND ") case BoolExprTypeOr: - buf.join(n.Args, " OR ") + buf.join(n.Args, d, " OR ") } } buf.WriteString(")") diff --git a/internal/sql/ast/boolean.go b/internal/sql/ast/boolean.go index 522af84868..16a6db54da 100644 --- a/internal/sql/ast/boolean.go +++ b/internal/sql/ast/boolean.go @@ -1,6 +1,10 @@ package ast -import "fmt" +import ( + "fmt" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type Boolean struct { Boolval bool @@ -10,7 +14,7 @@ func (n *Boolean) Pos() int { return 0 } -func (n *Boolean) Format(buf *TrackedBuffer) { +func (n *Boolean) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/call_stmt.go b/internal/sql/ast/call_stmt.go index 5267a1ff3f..6cba39986e 100644 --- a/internal/sql/ast/call_stmt.go +++ b/internal/sql/ast/call_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CallStmt struct { FuncCall *FuncCall } @@ -11,7 +13,7 @@ func (n *CallStmt) Pos() int { return n.FuncCall.Pos() } -func (n *CallStmt) Format(buf *TrackedBuffer) { +func (n *CallStmt) Format(buf *TrackedBuffer, d format.Dialect) { buf.WriteString("CALL ") - buf.astFormat(n.FuncCall) + buf.astFormat(n.FuncCall, d) } diff --git a/internal/sql/ast/case_expr.go b/internal/sql/ast/case_expr.go index 1d19dbdeec..52692b297b 100644 --- a/internal/sql/ast/case_expr.go +++ b/internal/sql/ast/case_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CaseExpr struct { Xpr Node Casetype Oid @@ -14,19 +16,19 @@ func (n *CaseExpr) Pos() int { return n.Location } -func (n *CaseExpr) Format(buf *TrackedBuffer) { +func (n *CaseExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("CASE ") if set(n.Arg) { - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) buf.WriteString(" ") } - buf.join(n.Args, " ") + buf.join(n.Args, d, " ") if set(n.Defresult) { buf.WriteString(" ELSE ") - buf.astFormat(n.Defresult) + buf.astFormat(n.Defresult, d) } buf.WriteString(" END") } diff --git a/internal/sql/ast/case_when.go b/internal/sql/ast/case_when.go index b036411d54..9636d24a97 100644 --- a/internal/sql/ast/case_when.go +++ b/internal/sql/ast/case_when.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CaseWhen struct { Xpr Node Expr Node @@ -11,12 +13,12 @@ func (n *CaseWhen) Pos() int { return n.Location } -func (n *CaseWhen) Format(buf *TrackedBuffer) { +func (n *CaseWhen) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("WHEN ") - buf.astFormat(n.Expr) + buf.astFormat(n.Expr, d) buf.WriteString(" THEN ") - buf.astFormat(n.Result) + buf.astFormat(n.Result, d) } diff --git a/internal/sql/ast/coalesce_expr.go b/internal/sql/ast/coalesce_expr.go index cbf7025748..0faee5bf4c 100644 --- a/internal/sql/ast/coalesce_expr.go +++ b/internal/sql/ast/coalesce_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CoalesceExpr struct { Xpr Node Coalescetype Oid @@ -12,11 +14,11 @@ func (n *CoalesceExpr) Pos() int { return n.Location } -func (n *CoalesceExpr) Format(buf *TrackedBuffer) { +func (n *CoalesceExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("COALESCE(") - buf.astFormat(n.Args) + buf.astFormat(n.Args, d) buf.WriteString(")") } diff --git a/internal/sql/ast/collate_expr.go b/internal/sql/ast/collate_expr.go index fd9a891e08..80483f75ce 100644 --- a/internal/sql/ast/collate_expr.go +++ b/internal/sql/ast/collate_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CollateExpr struct { Xpr Node Arg Node @@ -11,11 +13,11 @@ func (n *CollateExpr) Pos() int { return n.Location } -func (n *CollateExpr) Format(buf *TrackedBuffer) { +func (n *CollateExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Xpr) + buf.astFormat(n.Xpr, d) buf.WriteString(" COLLATE ") - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) } diff --git a/internal/sql/ast/column_def.go b/internal/sql/ast/column_def.go index cd8ba115fc..225cdd4779 100644 --- a/internal/sql/ast/column_def.go +++ b/internal/sql/ast/column_def.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ColumnDef struct { Colname string TypeName *TypeName @@ -32,13 +34,13 @@ func (n *ColumnDef) Pos() int { return n.Location } -func (n *ColumnDef) Format(buf *TrackedBuffer) { +func (n *ColumnDef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString(n.Colname) buf.WriteString(" ") - buf.astFormat(n.TypeName) + buf.astFormat(n.TypeName, d) // Use IsArray from ColumnDef since TypeName.ArrayBounds may not be set // (for type resolution compatibility) if n.IsArray && !items(n.TypeName.ArrayBounds) { @@ -49,5 +51,5 @@ func (n *ColumnDef) Format(buf *TrackedBuffer) { } else if n.IsNotNull { buf.WriteString(" NOT NULL") } - buf.astFormat(n.Constraints) + buf.astFormat(n.Constraints, d) } diff --git a/internal/sql/ast/column_ref.go b/internal/sql/ast/column_ref.go index 97ea3ab20a..943311799d 100644 --- a/internal/sql/ast/column_ref.go +++ b/internal/sql/ast/column_ref.go @@ -1,6 +1,10 @@ package ast -import "strings" +import ( + "strings" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type ColumnRef struct { Name string @@ -14,7 +18,7 @@ func (n *ColumnRef) Pos() int { return n.Location } -func (n *ColumnRef) Format(buf *TrackedBuffer) { +func (n *ColumnRef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -24,7 +28,7 @@ func (n *ColumnRef) Format(buf *TrackedBuffer) { for _, item := range n.Fields.Items { switch nn := item.(type) { case *String: - items = append(items, buf.QuoteIdent(nn.Str)) + items = append(items, d.QuoteIdent(nn.Str)) case *A_Star: items = append(items, "*") } diff --git a/internal/sql/ast/common_table_expr.go b/internal/sql/ast/common_table_expr.go index b36b3f23d3..aa334167ce 100644 --- a/internal/sql/ast/common_table_expr.go +++ b/internal/sql/ast/common_table_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CommonTableExpr struct { Ctename *string Aliascolnames *List @@ -17,7 +19,7 @@ func (n *CommonTableExpr) Pos() int { return n.Location } -func (n *CommonTableExpr) Format(buf *TrackedBuffer) { +func (n *CommonTableExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -26,10 +28,10 @@ func (n *CommonTableExpr) Format(buf *TrackedBuffer) { } if items(n.Aliascolnames) { buf.WriteString("(") - buf.join(n.Aliascolnames, ", ") + buf.join(n.Aliascolnames, d, ", ") buf.WriteString(")") } buf.WriteString(" AS (") - buf.astFormat(n.Ctequery) + buf.astFormat(n.Ctequery, d) buf.WriteString(")") } diff --git a/internal/sql/ast/create_extension_stmt.go b/internal/sql/ast/create_extension_stmt.go index cd12e7505b..140a10da4c 100644 --- a/internal/sql/ast/create_extension_stmt.go +++ b/internal/sql/ast/create_extension_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CreateExtensionStmt struct { Extname *string IfNotExists bool @@ -10,7 +12,7 @@ func (n *CreateExtensionStmt) Pos() int { return 0 } -func (n *CreateExtensionStmt) Format(buf *TrackedBuffer) { +func (n *CreateExtensionStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/create_function_stmt.go b/internal/sql/ast/create_function_stmt.go index e070a8720b..f5200085ee 100644 --- a/internal/sql/ast/create_function_stmt.go +++ b/internal/sql/ast/create_function_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CreateFunctionStmt struct { Replace bool Params *List @@ -14,7 +16,7 @@ func (n *CreateFunctionStmt) Pos() int { return 0 } -func (n *CreateFunctionStmt) Format(buf *TrackedBuffer) { +func (n *CreateFunctionStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -23,21 +25,21 @@ func (n *CreateFunctionStmt) Format(buf *TrackedBuffer) { buf.WriteString("OR REPLACE ") } buf.WriteString("FUNCTION ") - buf.astFormat(n.Func) + buf.astFormat(n.Func, d) buf.WriteString("(") if items(n.Params) { - buf.join(n.Params, ", ") + buf.join(n.Params, d, ", ") } buf.WriteString(")") if n.ReturnType != nil { buf.WriteString(" RETURNS ") - buf.astFormat(n.ReturnType) + buf.astFormat(n.ReturnType, d) } // Format options (AS, LANGUAGE, etc.) if items(n.Options) { for _, opt := range n.Options.Items { buf.WriteString(" ") - buf.astFormat(opt) + buf.astFormat(opt, d) } } } diff --git a/internal/sql/ast/create_table_stmt.go b/internal/sql/ast/create_table_stmt.go index ce88a1b244..f7ab2f9f60 100644 --- a/internal/sql/ast/create_table_stmt.go +++ b/internal/sql/ast/create_table_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CreateTableStmt struct { IfNotExists bool Name *TableName @@ -13,19 +15,19 @@ func (n *CreateTableStmt) Pos() int { return 0 } -func (n *CreateTableStmt) Format(buf *TrackedBuffer) { +func (n *CreateTableStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("CREATE TABLE ") - buf.astFormat(n.Name) + buf.astFormat(n.Name, d) buf.WriteString("(") for i, col := range n.Cols { if i > 0 { buf.WriteString(", ") } - buf.astFormat(col) + buf.astFormat(col, d) } buf.WriteString(")") } diff --git a/internal/sql/ast/def_elem.go b/internal/sql/ast/def_elem.go index d70090339d..33aacaaa03 100644 --- a/internal/sql/ast/def_elem.go +++ b/internal/sql/ast/def_elem.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type DefElem struct { Defnamespace *string Defname *string @@ -12,7 +14,7 @@ func (n *DefElem) Pos() int { return n.Location } -func (n *DefElem) Format(buf *TrackedBuffer) { +func (n *DefElem) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -31,18 +33,18 @@ func (n *DefElem) Format(buf *TrackedBuffer) { buf.WriteString(s.Str) buf.WriteString("'") } else { - buf.astFormat(item) + buf.astFormat(item, d) } } } else { - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) } case "language": buf.WriteString("LANGUAGE ") - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) case "volatility": // VOLATILE, STABLE, IMMUTABLE - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) case "strict": if s, ok := n.Arg.(*Boolean); ok && s.Boolval { buf.WriteString("STRICT") @@ -59,7 +61,7 @@ func (n *DefElem) Format(buf *TrackedBuffer) { buf.WriteString(*n.Defname) if n.Arg != nil { buf.WriteString(" ") - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) } } } diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index 828274978e..d23617881a 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type DeleteStmt struct { Relations *List UsingClause *List @@ -16,13 +18,13 @@ func (n *DeleteStmt) Pos() int { return 0 } -func (n *DeleteStmt) Format(buf *TrackedBuffer) { +func (n *DeleteStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } @@ -30,37 +32,37 @@ func (n *DeleteStmt) Format(buf *TrackedBuffer) { // MySQL multi-table DELETE: DELETE t1.*, t2.* FROM t1 JOIN t2 ... if items(n.Targets) { - buf.join(n.Targets, ", ") + buf.join(n.Targets, d, ", ") buf.WriteString(" FROM ") if set(n.FromClause) { - buf.astFormat(n.FromClause) + buf.astFormat(n.FromClause, d) } else if items(n.Relations) { - buf.astFormat(n.Relations) + buf.astFormat(n.Relations, d) } } else { buf.WriteString("FROM ") if items(n.Relations) { - buf.astFormat(n.Relations) + buf.astFormat(n.Relations, d) } } if items(n.UsingClause) { buf.WriteString(" USING ") - buf.join(n.UsingClause, ", ") + buf.join(n.UsingClause, d, ", ") } if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } if set(n.LimitCount) { buf.WriteString(" LIMIT ") - buf.astFormat(n.LimitCount) + buf.astFormat(n.LimitCount, d) } if items(n.ReturningList) { buf.WriteString(" RETURNING ") - buf.astFormat(n.ReturningList) + buf.astFormat(n.ReturningList, d) } } diff --git a/internal/sql/ast/do_stmt.go b/internal/sql/ast/do_stmt.go index a14ddfd537..9becfb8e64 100644 --- a/internal/sql/ast/do_stmt.go +++ b/internal/sql/ast/do_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type DoStmt struct { Args *List } @@ -8,7 +10,7 @@ func (n *DoStmt) Pos() int { return 0 } -func (n *DoStmt) Format(buf *TrackedBuffer) { +func (n *DoStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/float.go b/internal/sql/ast/float.go index fee8655bbe..94e8c2652f 100644 --- a/internal/sql/ast/float.go +++ b/internal/sql/ast/float.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type Float struct { Str string } @@ -8,7 +10,7 @@ func (n *Float) Pos() int { return 0 } -func (n *Float) Format(buf *TrackedBuffer) { +func (n *Float) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/func_call.go b/internal/sql/ast/func_call.go index 5f4857a679..cb4f210fe4 100644 --- a/internal/sql/ast/func_call.go +++ b/internal/sql/ast/func_call.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type FuncCall struct { Func *FuncName Funcname *List @@ -19,11 +21,11 @@ func (n *FuncCall) Pos() int { return n.Location } -func (n *FuncCall) Format(buf *TrackedBuffer) { +func (n *FuncCall) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Func) + buf.astFormat(n.Func, d) buf.WriteString("(") if n.AggDistinct { buf.WriteString("DISTINCT ") @@ -31,12 +33,12 @@ func (n *FuncCall) Format(buf *TrackedBuffer) { if n.AggStar { buf.WriteString("*") } else { - buf.astFormat(n.Args) + buf.astFormat(n.Args, d) } // ORDER BY inside function call (not WITHIN GROUP) if items(n.AggOrder) && !n.AggWithinGroup { buf.WriteString(" ORDER BY ") - buf.join(n.AggOrder, ", ") + buf.join(n.AggOrder, d, ", ") } // SEPARATOR for GROUP_CONCAT (MySQL) if n.Separator != nil { @@ -49,16 +51,16 @@ func (n *FuncCall) Format(buf *TrackedBuffer) { // WITHIN GROUP clause for ordered-set aggregates if items(n.AggOrder) && n.AggWithinGroup { buf.WriteString(" WITHIN GROUP (ORDER BY ") - buf.join(n.AggOrder, ", ") + buf.join(n.AggOrder, d, ", ") buf.WriteString(")") } if set(n.AggFilter) { buf.WriteString(" FILTER (WHERE ") - buf.astFormat(n.AggFilter) + buf.astFormat(n.AggFilter, d) buf.WriteString(")") } if n.Over != nil { buf.WriteString(" OVER ") - buf.astFormat(n.Over) + buf.astFormat(n.Over, d) } } diff --git a/internal/sql/ast/func_name.go b/internal/sql/ast/func_name.go index 29b8e0fa61..cdf3e23d33 100644 --- a/internal/sql/ast/func_name.go +++ b/internal/sql/ast/func_name.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type FuncName struct { Catalog string Schema string @@ -10,7 +12,7 @@ func (n *FuncName) Pos() int { return 0 } -func (n *FuncName) Format(buf *TrackedBuffer) { +func (n *FuncName) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/func_param.go b/internal/sql/ast/func_param.go index 812d9c629a..5881a1441f 100644 --- a/internal/sql/ast/func_param.go +++ b/internal/sql/ast/func_param.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type FuncParamMode int const ( @@ -22,7 +24,7 @@ func (n *FuncParam) Pos() int { return 0 } -func (n *FuncParam) Format(buf *TrackedBuffer) { +func (n *FuncParam) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -41,5 +43,5 @@ func (n *FuncParam) Format(buf *TrackedBuffer) { buf.WriteString(" ") } // Parameter type - buf.astFormat(n.Type) + buf.astFormat(n.Type, d) } diff --git a/internal/sql/ast/in.go b/internal/sql/ast/in.go index 68bd038ad3..9bdad67eeb 100644 --- a/internal/sql/ast/in.go +++ b/internal/sql/ast/in.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + // In describes a 'select foo in (bar, baz)' type statement, though there are multiple important variants handled. type In struct { // Expr is the value expression to be compared. @@ -19,11 +21,11 @@ func (n *In) Pos() int { } // Format formats the In expression. -func (n *In) Format(buf *TrackedBuffer) { +func (n *In) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Expr) + buf.astFormat(n.Expr, d) if n.Not { buf.WriteString(" NOT IN ") } else { @@ -31,7 +33,7 @@ func (n *In) Format(buf *TrackedBuffer) { } if n.Sel != nil { buf.WriteString("(") - buf.astFormat(n.Sel) + buf.astFormat(n.Sel, d) buf.WriteString(")") } else if len(n.List) > 0 { buf.WriteString("(") @@ -39,7 +41,7 @@ func (n *In) Format(buf *TrackedBuffer) { if i > 0 { buf.WriteString(", ") } - buf.astFormat(item) + buf.astFormat(item, d) } buf.WriteString(")") } diff --git a/internal/sql/ast/index_elem.go b/internal/sql/ast/index_elem.go index d1400699ee..acc2a7fc23 100644 --- a/internal/sql/ast/index_elem.go +++ b/internal/sql/ast/index_elem.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type IndexElem struct { Name *string Expr Node @@ -14,13 +16,13 @@ func (n *IndexElem) Pos() int { return 0 } -func (n *IndexElem) Format(buf *TrackedBuffer) { +func (n *IndexElem) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.Name != nil && *n.Name != "" { buf.WriteString(*n.Name) } else if set(n.Expr) { - buf.astFormat(n.Expr) + buf.astFormat(n.Expr, d) } } diff --git a/internal/sql/ast/infer_clause.go b/internal/sql/ast/infer_clause.go index ff3855cae5..6df0db4a86 100644 --- a/internal/sql/ast/infer_clause.go +++ b/internal/sql/ast/infer_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type InferClause struct { IndexElems *List WhereClause Node @@ -11,7 +13,7 @@ func (n *InferClause) Pos() int { return n.Location } -func (n *InferClause) Format(buf *TrackedBuffer) { +func (n *InferClause) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -20,11 +22,11 @@ func (n *InferClause) Format(buf *TrackedBuffer) { buf.WriteString(*n.Conname) } else if items(n.IndexElems) { buf.WriteString("(") - buf.join(n.IndexElems, ", ") + buf.join(n.IndexElems, d, ", ") buf.WriteString(")") if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } } } diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 75ef44863a..4d5c8d1df2 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -1,38 +1,40 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type InsertStmt struct { - Relation *RangeVar - Cols *List - SelectStmt Node - OnConflictClause *OnConflictClause - OnDuplicateKeyUpdate *OnDuplicateKeyUpdate // MySQL-specific - ReturningList *List - WithClause *WithClause - Override OverridingKind - DefaultValues bool // SQLite-specific: INSERT INTO ... DEFAULT VALUES + Relation *RangeVar + Cols *List + SelectStmt Node + OnConflictClause *OnConflictClause + OnDuplicateKeyUpdate *OnDuplicateKeyUpdate // MySQL-specific + ReturningList *List + WithClause *WithClause + Override OverridingKind + DefaultValues bool // SQLite-specific: INSERT INTO ... DEFAULT VALUES } func (n *InsertStmt) Pos() int { return 0 } -func (n *InsertStmt) Format(buf *TrackedBuffer) { +func (n *InsertStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } buf.WriteString("INSERT INTO ") if n.Relation != nil { - buf.astFormat(n.Relation) + buf.astFormat(n.Relation, d) } if items(n.Cols) { buf.WriteString(" (") - buf.astFormat(n.Cols) + buf.astFormat(n.Cols, d) buf.WriteString(")") } @@ -40,21 +42,21 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { buf.WriteString(" DEFAULT VALUES") } else if set(n.SelectStmt) { buf.WriteString(" ") - buf.astFormat(n.SelectStmt) + buf.astFormat(n.SelectStmt, d) } if n.OnConflictClause != nil { buf.WriteString(" ") - buf.astFormat(n.OnConflictClause) + buf.astFormat(n.OnConflictClause, d) } if n.OnDuplicateKeyUpdate != nil { buf.WriteString(" ") - buf.astFormat(n.OnDuplicateKeyUpdate) + buf.astFormat(n.OnDuplicateKeyUpdate, d) } if items(n.ReturningList) { buf.WriteString(" RETURNING ") - buf.astFormat(n.ReturningList) + buf.astFormat(n.ReturningList, d) } } diff --git a/internal/sql/ast/integer.go b/internal/sql/ast/integer.go index e9f911add2..c0c360f2f2 100644 --- a/internal/sql/ast/integer.go +++ b/internal/sql/ast/integer.go @@ -1,6 +1,10 @@ package ast -import "strconv" +import ( + "strconv" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type Integer struct { Ival int64 @@ -10,7 +14,7 @@ func (n *Integer) Pos() int { return 0 } -func (n *Integer) Format(buf *TrackedBuffer) { +func (n *Integer) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/interval_expr.go b/internal/sql/ast/interval_expr.go index 0572dc6d70..dac73a0557 100644 --- a/internal/sql/ast/interval_expr.go +++ b/internal/sql/ast/interval_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + // IntervalExpr represents a MySQL INTERVAL expression like "INTERVAL 1 DAY" type IntervalExpr struct { Value Node @@ -11,12 +13,12 @@ func (n *IntervalExpr) Pos() int { return n.Location } -func (n *IntervalExpr) Format(buf *TrackedBuffer) { +func (n *IntervalExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("INTERVAL ") - buf.astFormat(n.Value) + buf.astFormat(n.Value, d) buf.WriteString(" ") buf.WriteString(n.Unit) } diff --git a/internal/sql/ast/join_expr.go b/internal/sql/ast/join_expr.go index 69c3089b1b..8ac059d006 100644 --- a/internal/sql/ast/join_expr.go +++ b/internal/sql/ast/join_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type JoinExpr struct { Jointype JoinType IsNatural bool @@ -15,11 +17,11 @@ func (n *JoinExpr) Pos() int { return 0 } -func (n *JoinExpr) Format(buf *TrackedBuffer) { +func (n *JoinExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Larg) + buf.astFormat(n.Larg, d) if n.IsNatural { buf.WriteString(" NATURAL") } @@ -40,13 +42,13 @@ func (n *JoinExpr) Format(buf *TrackedBuffer) { default: buf.WriteString(" JOIN ") } - buf.astFormat(n.Rarg) + buf.astFormat(n.Rarg, d) if items(n.UsingClause) { buf.WriteString(" USING (") - buf.join(n.UsingClause, ", ") + buf.join(n.UsingClause, d, ", ") buf.WriteString(")") } else if set(n.Quals) { buf.WriteString(" ON ") - buf.astFormat(n.Quals) + buf.astFormat(n.Quals, d) } } diff --git a/internal/sql/ast/list.go b/internal/sql/ast/list.go index 1c89d55339..38be310e3c 100644 --- a/internal/sql/ast/list.go +++ b/internal/sql/ast/list.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type List struct { Items []Node } @@ -8,9 +10,9 @@ func (n *List) Pos() int { return 0 } -func (n *List) Format(buf *TrackedBuffer) { +func (n *List) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.join(n, ",") + buf.join(n, d, ",") } diff --git a/internal/sql/ast/listen_stmt.go b/internal/sql/ast/listen_stmt.go index 79c1b132c1..48c38419a8 100644 --- a/internal/sql/ast/listen_stmt.go +++ b/internal/sql/ast/listen_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ListenStmt struct { Conditionname *string } @@ -8,7 +10,7 @@ func (n *ListenStmt) Pos() int { return 0 } -func (n *ListenStmt) Format(buf *TrackedBuffer) { +func (n *ListenStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/locking_clause.go b/internal/sql/ast/locking_clause.go index 286d726edd..6202b4ae02 100644 --- a/internal/sql/ast/locking_clause.go +++ b/internal/sql/ast/locking_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type LockingClause struct { LockedRels *List Strength LockClauseStrength @@ -27,7 +29,7 @@ const ( LockWaitPolicyError LockWaitPolicy = 3 ) -func (n *LockingClause) Format(buf *TrackedBuffer) { +func (n *LockingClause) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -44,7 +46,7 @@ func (n *LockingClause) Format(buf *TrackedBuffer) { } if items(n.LockedRels) { buf.WriteString(" OF ") - buf.join(n.LockedRels, ", ") + buf.join(n.LockedRels, d, ", ") } switch n.WaitPolicy { case LockWaitPolicySkip: diff --git a/internal/sql/ast/multi_assign_ref.go b/internal/sql/ast/multi_assign_ref.go index 16302b4e4c..94b783bcc1 100644 --- a/internal/sql/ast/multi_assign_ref.go +++ b/internal/sql/ast/multi_assign_ref.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type MultiAssignRef struct { Source Node Colno int @@ -10,9 +12,9 @@ func (n *MultiAssignRef) Pos() int { return 0 } -func (n *MultiAssignRef) Format(buf *TrackedBuffer) { +func (n *MultiAssignRef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Source) + buf.astFormat(n.Source, d) } diff --git a/internal/sql/ast/named_arg_expr.go b/internal/sql/ast/named_arg_expr.go index e37427826e..a711fd2712 100644 --- a/internal/sql/ast/named_arg_expr.go +++ b/internal/sql/ast/named_arg_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type NamedArgExpr struct { Xpr Node Arg Node @@ -12,7 +14,7 @@ func (n *NamedArgExpr) Pos() int { return n.Location } -func (n *NamedArgExpr) Format(buf *TrackedBuffer) { +func (n *NamedArgExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -20,5 +22,5 @@ func (n *NamedArgExpr) Format(buf *TrackedBuffer) { buf.WriteString(*n.Name) } buf.WriteString(" => ") - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) } diff --git a/internal/sql/ast/notify_stmt.go b/internal/sql/ast/notify_stmt.go index 0c50a11123..abecb94360 100644 --- a/internal/sql/ast/notify_stmt.go +++ b/internal/sql/ast/notify_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type NotifyStmt struct { Conditionname *string Payload *string @@ -9,7 +11,7 @@ func (n *NotifyStmt) Pos() int { return 0 } -func (n *NotifyStmt) Format(buf *TrackedBuffer) { +func (n *NotifyStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/null.go b/internal/sql/ast/null.go index 380c8e7372..e3606e2d7f 100644 --- a/internal/sql/ast/null.go +++ b/internal/sql/ast/null.go @@ -1,11 +1,13 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type Null struct { } func (n *Null) Pos() int { return 0 } -func (n *Null) Format(buf *TrackedBuffer) { +func (n *Null) Format(buf *TrackedBuffer, d format.Dialect) { buf.WriteString("NULL") } diff --git a/internal/sql/ast/null_test_expr.go b/internal/sql/ast/null_test_expr.go index 42059bca6e..3436bff0a5 100644 --- a/internal/sql/ast/null_test_expr.go +++ b/internal/sql/ast/null_test_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type NullTest struct { Xpr Node Arg Node @@ -18,11 +20,11 @@ const ( NullTestTypeIsNotNull NullTestType = 2 ) -func (n *NullTest) Format(buf *TrackedBuffer) { +func (n *NullTest) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) switch n.Nulltesttype { case NullTestTypeIsNull: buf.WriteString(" IS NULL") diff --git a/internal/sql/ast/on_conflict_clause.go b/internal/sql/ast/on_conflict_clause.go index 055532fb3c..a71bae0a23 100644 --- a/internal/sql/ast/on_conflict_clause.go +++ b/internal/sql/ast/on_conflict_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type OnConflictClause struct { Action OnConflictAction Infer *InferClause @@ -20,13 +22,13 @@ const ( OnConflictActionUpdate OnConflictAction = 3 ) -func (n *OnConflictClause) Format(buf *TrackedBuffer) { +func (n *OnConflictClause) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("ON CONFLICT ") if n.Infer != nil { - buf.astFormat(n.Infer) + buf.astFormat(n.Infer, d) buf.WriteString(" ") } switch n.Action { @@ -45,15 +47,15 @@ func (n *OnConflictClause) Format(buf *TrackedBuffer) { buf.WriteString(*rt.Name) } buf.WriteString(" = ") - buf.astFormat(rt.Val) + buf.astFormat(rt.Val, d) } else { - buf.astFormat(item) + buf.astFormat(item, d) } } } if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } } } diff --git a/internal/sql/ast/on_duplicate_key_update.go b/internal/sql/ast/on_duplicate_key_update.go index ad5b7672d1..a11ce1ab18 100644 --- a/internal/sql/ast/on_duplicate_key_update.go +++ b/internal/sql/ast/on_duplicate_key_update.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + // OnDuplicateKeyUpdate represents MySQL's ON DUPLICATE KEY UPDATE clause type OnDuplicateKeyUpdate struct { // TargetList contains the assignments (column = value pairs) @@ -11,7 +13,7 @@ func (n *OnDuplicateKeyUpdate) Pos() int { return n.Location } -func (n *OnDuplicateKeyUpdate) Format(buf *TrackedBuffer) { +func (n *OnDuplicateKeyUpdate) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -26,9 +28,9 @@ func (n *OnDuplicateKeyUpdate) Format(buf *TrackedBuffer) { buf.WriteString(*rt.Name) } buf.WriteString(" = ") - buf.astFormat(rt.Val) + buf.astFormat(rt.Val, d) } else { - buf.astFormat(item) + buf.astFormat(item, d) } } } diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index 0558f78bdf..7ebc897a95 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ParamRef struct { Number int Location int @@ -10,9 +12,9 @@ func (n *ParamRef) Pos() int { return n.Location } -func (n *ParamRef) Format(buf *TrackedBuffer) { +func (n *ParamRef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.WriteString(buf.Param(n.Number)) + buf.WriteString(d.Param(n.Number)) } diff --git a/internal/sql/ast/paren_expr.go b/internal/sql/ast/paren_expr.go index ee57ac55d7..831d461f3e 100644 --- a/internal/sql/ast/paren_expr.go +++ b/internal/sql/ast/paren_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + // ParenExpr represents a parenthesized expression type ParenExpr struct { Expr Node @@ -10,11 +12,11 @@ func (n *ParenExpr) Pos() int { return n.Location } -func (n *ParenExpr) Format(buf *TrackedBuffer) { +func (n *ParenExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("(") - buf.astFormat(n.Expr) + buf.astFormat(n.Expr, d) buf.WriteString(")") } diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go index 6335846946..87f6107622 100644 --- a/internal/sql/ast/print.go +++ b/internal/sql/ast/print.go @@ -1,7 +1,6 @@ package ast import ( - "fmt" "strings" "github.com/sqlc-dev/sqlc/internal/debug" @@ -9,80 +8,29 @@ import ( ) type nodeFormatter interface { - Format(*TrackedBuffer) + Format(*TrackedBuffer, format.Dialect) } type TrackedBuffer struct { *strings.Builder - formatter format.Formatter } -// NewTrackedBuffer creates a new TrackedBuffer with the given formatter. -func NewTrackedBuffer(f format.Formatter) *TrackedBuffer { - buf := &TrackedBuffer{ - Builder: new(strings.Builder), - formatter: f, +// NewTrackedBuffer creates a new TrackedBuffer. +func NewTrackedBuffer() *TrackedBuffer { + return &TrackedBuffer{ + Builder: new(strings.Builder), } - return buf } -// QuoteIdent returns a quoted identifier if it needs quoting. -// If no formatter is set, it returns the identifier unchanged. -func (t *TrackedBuffer) QuoteIdent(s string) string { - if t.formatter != nil { - return t.formatter.QuoteIdent(s) - } - return s -} - -// TypeName returns the SQL type name for the given namespace and name. -// If no formatter is set, it returns "ns.name" or just "name". -func (t *TrackedBuffer) TypeName(ns, name string) string { - if t.formatter != nil { - return t.formatter.TypeName(ns, name) - } - if ns != "" { - return ns + "." + name - } - return name -} - -// Param returns the parameter placeholder for the given number. -// If no formatter is set, it returns PostgreSQL-style $n. -func (t *TrackedBuffer) Param(n int) string { - if t.formatter != nil { - return t.formatter.Param(n) - } - return fmt.Sprintf("$%d", n) -} - -// Cast returns a type cast expression. -// If no formatter is set, it returns PostgreSQL-style expr::type. -func (t *TrackedBuffer) Cast(arg, typeName string) string { - if t.formatter != nil { - return t.formatter.Cast(arg, typeName) - } - return arg + "::" + typeName -} - -// NamedParam returns the named parameter placeholder for the given name. -// If no formatter is set, it returns PostgreSQL-style @name. -func (t *TrackedBuffer) NamedParam(name string) string { - if t.formatter != nil { - return t.formatter.NamedParam(name) - } - return "@" + name -} - -func (t *TrackedBuffer) astFormat(n Node) { +func (t *TrackedBuffer) astFormat(n Node, d format.Dialect) { if ft, ok := n.(nodeFormatter); ok { - ft.Format(t) + ft.Format(t, d) } else { debug.Dump(n) } } -func (t *TrackedBuffer) join(n *List, sep string) { +func (t *TrackedBuffer) join(n *List, d format.Dialect, sep string) { if n == nil { return } @@ -93,14 +41,14 @@ func (t *TrackedBuffer) join(n *List, sep string) { if i > 0 { t.WriteString(sep) } - t.astFormat(item) + t.astFormat(item, d) } } -func Format(n Node, f format.Formatter) string { - tb := NewTrackedBuffer(f) +func Format(n Node, d format.Dialect) string { + tb := NewTrackedBuffer() if ft, ok := n.(nodeFormatter); ok { - ft.Format(tb) + ft.Format(tb, d) } return tb.String() } diff --git a/internal/sql/ast/range_function.go b/internal/sql/ast/range_function.go index 6a95388fd1..dca63595d8 100644 --- a/internal/sql/ast/range_function.go +++ b/internal/sql/ast/range_function.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RangeFunction struct { Lateral bool Ordinality bool @@ -13,19 +15,19 @@ func (n *RangeFunction) Pos() int { return 0 } -func (n *RangeFunction) Format(buf *TrackedBuffer) { +func (n *RangeFunction) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.Lateral { buf.WriteString("LATERAL ") } - buf.astFormat(n.Functions) + buf.astFormat(n.Functions, d) if n.Ordinality { buf.WriteString(" WITH ORDINALITY") } if n.Alias != nil { buf.WriteString(" AS ") - buf.astFormat(n.Alias) + buf.astFormat(n.Alias, d) } } diff --git a/internal/sql/ast/range_subselect.go b/internal/sql/ast/range_subselect.go index a5d63235d3..51a8825e2b 100644 --- a/internal/sql/ast/range_subselect.go +++ b/internal/sql/ast/range_subselect.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RangeSubselect struct { Lateral bool Subquery Node @@ -10,7 +12,7 @@ func (n *RangeSubselect) Pos() int { return 0 } -func (n *RangeSubselect) Format(buf *TrackedBuffer) { +func (n *RangeSubselect) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -18,10 +20,10 @@ func (n *RangeSubselect) Format(buf *TrackedBuffer) { buf.WriteString("LATERAL ") } buf.WriteString("(") - buf.astFormat(n.Subquery) + buf.astFormat(n.Subquery, d) buf.WriteString(")") if n.Alias != nil { buf.WriteString(" AS ") - buf.astFormat(n.Alias) + buf.astFormat(n.Alias, d) } } diff --git a/internal/sql/ast/range_var.go b/internal/sql/ast/range_var.go index 5fd6db535f..250b2b3bbf 100644 --- a/internal/sql/ast/range_var.go +++ b/internal/sql/ast/range_var.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RangeVar struct { Catalogname *string Schemaname *string @@ -14,19 +16,19 @@ func (n *RangeVar) Pos() int { return n.Location } -func (n *RangeVar) Format(buf *TrackedBuffer) { +func (n *RangeVar) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.Schemaname != nil && *n.Schemaname != "" { - buf.WriteString(buf.QuoteIdent(*n.Schemaname)) + buf.WriteString(d.QuoteIdent(*n.Schemaname)) buf.WriteString(".") } if n.Relname != nil { - buf.WriteString(buf.QuoteIdent(*n.Relname)) + buf.WriteString(d.QuoteIdent(*n.Relname)) } if n.Alias != nil { buf.WriteString(" AS ") - buf.astFormat(n.Alias) + buf.astFormat(n.Alias, d) } } diff --git a/internal/sql/ast/raw_stmt.go b/internal/sql/ast/raw_stmt.go index 55192d2eec..fe02bed803 100644 --- a/internal/sql/ast/raw_stmt.go +++ b/internal/sql/ast/raw_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RawStmt struct { Stmt Node StmtLocation int @@ -10,9 +12,9 @@ func (n *RawStmt) Pos() int { return n.StmtLocation } -func (n *RawStmt) Format(buf *TrackedBuffer) { +func (n *RawStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n.Stmt != nil { - buf.astFormat(n.Stmt) + buf.astFormat(n.Stmt, d) } buf.WriteString(";") } diff --git a/internal/sql/ast/refresh_mat_view_stmt.go b/internal/sql/ast/refresh_mat_view_stmt.go index e9b3e26bfa..f627e7bf21 100644 --- a/internal/sql/ast/refresh_mat_view_stmt.go +++ b/internal/sql/ast/refresh_mat_view_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RefreshMatViewStmt struct { Concurrent bool SkipData bool @@ -10,10 +12,10 @@ func (n *RefreshMatViewStmt) Pos() int { return 0 } -func (n *RefreshMatViewStmt) Format(buf *TrackedBuffer) { +func (n *RefreshMatViewStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("REFRESH MATERIALIZED VIEW ") - buf.astFormat(n.Relation) + buf.astFormat(n.Relation, d) } diff --git a/internal/sql/ast/res_target.go b/internal/sql/ast/res_target.go index b652c2293e..dc34879942 100644 --- a/internal/sql/ast/res_target.go +++ b/internal/sql/ast/res_target.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ResTarget struct { Name *string Indirection *List @@ -11,19 +13,19 @@ func (n *ResTarget) Pos() int { return n.Location } -func (n *ResTarget) Format(buf *TrackedBuffer) { +func (n *ResTarget) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if set(n.Val) { - buf.astFormat(n.Val) + buf.astFormat(n.Val, d) if n.Name != nil { buf.WriteString(" AS ") - buf.WriteString(buf.QuoteIdent(*n.Name)) + buf.WriteString(d.QuoteIdent(*n.Name)) } } else { if n.Name != nil { - buf.WriteString(buf.QuoteIdent(*n.Name)) + buf.WriteString(d.QuoteIdent(*n.Name)) } } } diff --git a/internal/sql/ast/row_expr.go b/internal/sql/ast/row_expr.go index 14804f5821..0f8578355a 100644 --- a/internal/sql/ast/row_expr.go +++ b/internal/sql/ast/row_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RowExpr struct { Xpr Node Args *List @@ -13,17 +15,17 @@ func (n *RowExpr) Pos() int { return n.Location } -func (n *RowExpr) Format(buf *TrackedBuffer) { +func (n *RowExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if items(n.Args) { buf.WriteString("args") - buf.astFormat(n.Args) + buf.astFormat(n.Args, d) } - buf.astFormat(n.Xpr) + buf.astFormat(n.Xpr, d) if items(n.Colnames) { buf.WriteString("cols") - buf.astFormat(n.Colnames) + buf.astFormat(n.Colnames, d) } } diff --git a/internal/sql/ast/scalar_array_op_expr.go b/internal/sql/ast/scalar_array_op_expr.go index f887bf6508..b4f36548b3 100644 --- a/internal/sql/ast/scalar_array_op_expr.go +++ b/internal/sql/ast/scalar_array_op_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ScalarArrayOpExpr struct { Xpr Node Opno Oid @@ -13,21 +15,21 @@ func (n *ScalarArrayOpExpr) Pos() int { return n.Location } -func (n *ScalarArrayOpExpr) Format(buf *TrackedBuffer) { +func (n *ScalarArrayOpExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } // ScalarArrayOpExpr represents "scalar op ANY/ALL (array)" // Args[0] is the left operand, Args[1] is the array if n.Args != nil && len(n.Args.Items) >= 2 { - buf.astFormat(n.Args.Items[0]) + buf.astFormat(n.Args.Items[0], d) buf.WriteString(" = ") // TODO: Use actual operator based on Opno if n.UseOr { buf.WriteString("ANY(") } else { buf.WriteString("ALL(") } - buf.astFormat(n.Args.Items[1]) + buf.astFormat(n.Args.Items[1], d) buf.WriteString(")") } } diff --git a/internal/sql/ast/select_stmt.go b/internal/sql/ast/select_stmt.go index a0f0fd4f43..8c3606dd4d 100644 --- a/internal/sql/ast/select_stmt.go +++ b/internal/sql/ast/select_stmt.go @@ -2,6 +2,8 @@ package ast import ( "fmt" + + "github.com/sqlc-dev/sqlc/internal/sql/format" ) type SelectStmt struct { @@ -29,25 +31,25 @@ func (n *SelectStmt) Pos() int { return 0 } -func (n *SelectStmt) Format(buf *TrackedBuffer) { +func (n *SelectStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if items(n.ValuesLists) { buf.WriteString("VALUES (") - buf.astFormat(n.ValuesLists) + buf.astFormat(n.ValuesLists, d) buf.WriteString(")") return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } if n.Larg != nil && n.Rarg != nil { - buf.astFormat(n.Larg) + buf.astFormat(n.Larg, d) switch n.Op { case Union: buf.WriteString(" UNION ") @@ -59,7 +61,7 @@ func (n *SelectStmt) Format(buf *TrackedBuffer) { if n.All { buf.WriteString("ALL ") } - buf.astFormat(n.Rarg) + buf.astFormat(n.Rarg, d) } else { buf.WriteString("SELECT ") } @@ -68,50 +70,50 @@ func (n *SelectStmt) Format(buf *TrackedBuffer) { buf.WriteString("DISTINCT ") if !todo(n.DistinctClause) { fmt.Fprintf(buf, "ON (") - buf.astFormat(n.DistinctClause) + buf.astFormat(n.DistinctClause, d) fmt.Fprintf(buf, ")") } } - buf.astFormat(n.TargetList) + buf.astFormat(n.TargetList, d) if items(n.FromClause) { buf.WriteString(" FROM ") - buf.astFormat(n.FromClause) + buf.astFormat(n.FromClause, d) } if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } if items(n.GroupClause) { buf.WriteString(" GROUP BY ") - buf.astFormat(n.GroupClause) + buf.astFormat(n.GroupClause, d) } if set(n.HavingClause) { buf.WriteString(" HAVING ") - buf.astFormat(n.HavingClause) + buf.astFormat(n.HavingClause, d) } if items(n.SortClause) { buf.WriteString(" ORDER BY ") - buf.astFormat(n.SortClause) + buf.astFormat(n.SortClause, d) } if set(n.LimitCount) { buf.WriteString(" LIMIT ") - buf.astFormat(n.LimitCount) + buf.astFormat(n.LimitCount, d) } if set(n.LimitOffset) { buf.WriteString(" OFFSET ") - buf.astFormat(n.LimitOffset) + buf.astFormat(n.LimitOffset, d) } if items(n.LockingClause) { buf.WriteString(" ") - buf.astFormat(n.LockingClause) + buf.astFormat(n.LockingClause, d) } } diff --git a/internal/sql/ast/sort_by.go b/internal/sql/ast/sort_by.go index 6d43f541a1..b8634b7d6d 100644 --- a/internal/sql/ast/sort_by.go +++ b/internal/sql/ast/sort_by.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type SortBy struct { Node Node SortbyDir SortByDir @@ -12,11 +14,11 @@ func (n *SortBy) Pos() int { return n.Location } -func (n *SortBy) Format(buf *TrackedBuffer) { +func (n *SortBy) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Node) + buf.astFormat(n.Node, d) switch n.SortbyDir { case SortByDirAsc: buf.WriteString(" ASC") diff --git a/internal/sql/ast/sql_value_function.go b/internal/sql/ast/sql_value_function.go index 0bd0777374..31bd008245 100644 --- a/internal/sql/ast/sql_value_function.go +++ b/internal/sql/ast/sql_value_function.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type SQLValueFunction struct { Xpr Node Op SQLValueFunctionOp @@ -12,7 +14,7 @@ func (n *SQLValueFunction) Pos() int { return n.Location } -func (n *SQLValueFunction) Format(buf *TrackedBuffer) { +func (n *SQLValueFunction) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/string.go b/internal/sql/ast/string.go index 977fc19a2f..d167ef4575 100644 --- a/internal/sql/ast/string.go +++ b/internal/sql/ast/string.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type String struct { Str string } @@ -8,7 +10,7 @@ func (n *String) Pos() int { return 0 } -func (n *String) Format(buf *TrackedBuffer) { +func (n *String) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/sub_link.go b/internal/sql/ast/sub_link.go index 369b41ed86..99b8458afe 100644 --- a/internal/sql/ast/sub_link.go +++ b/internal/sql/ast/sub_link.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type SubLinkType uint const ( @@ -27,14 +29,14 @@ func (n *SubLink) Pos() int { return n.Location } -func (n *SubLink) Format(buf *TrackedBuffer) { +func (n *SubLink) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } // Format the test expression if present (for IN subqueries etc.) hasTestExpr := n.Testexpr != nil if hasTestExpr { - buf.astFormat(n.Testexpr) + buf.astFormat(n.Testexpr, d) } switch n.SubLinkType { case EXISTS_SUBLINK: @@ -52,6 +54,6 @@ func (n *SubLink) Format(buf *TrackedBuffer) { buf.WriteString("(") } } - buf.astFormat(n.Subselect) + buf.astFormat(n.Subselect, d) buf.WriteString(")") } diff --git a/internal/sql/ast/table_name.go b/internal/sql/ast/table_name.go index a95a510c83..4f494a67e0 100644 --- a/internal/sql/ast/table_name.go +++ b/internal/sql/ast/table_name.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TableName struct { Catalog string Schema string @@ -10,7 +12,7 @@ func (n *TableName) Pos() int { return 0 } -func (n *TableName) Format(buf *TrackedBuffer) { +func (n *TableName) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/truncate_stmt.go b/internal/sql/ast/truncate_stmt.go index f23a5bbcb3..6636e9f9e8 100644 --- a/internal/sql/ast/truncate_stmt.go +++ b/internal/sql/ast/truncate_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TruncateStmt struct { Relations *List RestartSeqs bool @@ -10,10 +12,10 @@ func (n *TruncateStmt) Pos() int { return 0 } -func (n *TruncateStmt) Format(buf *TrackedBuffer) { +func (n *TruncateStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("TRUNCATE ") - buf.astFormat(n.Relations) + buf.astFormat(n.Relations, d) } diff --git a/internal/sql/ast/type_cast.go b/internal/sql/ast/type_cast.go index 163d145dbc..fe5b321abf 100644 --- a/internal/sql/ast/type_cast.go +++ b/internal/sql/ast/type_cast.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TypeCast struct { Arg Node TypeName *TypeName @@ -10,16 +12,16 @@ func (n *TypeCast) Pos() int { return n.Location } -func (n *TypeCast) Format(buf *TrackedBuffer) { +func (n *TypeCast) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } // Format the arg and type to strings first - argBuf := NewTrackedBuffer(buf.formatter) - argBuf.astFormat(n.Arg) + argBuf := NewTrackedBuffer() + argBuf.astFormat(n.Arg, d) - typeBuf := NewTrackedBuffer(buf.formatter) - typeBuf.astFormat(n.TypeName) + typeBuf := NewTrackedBuffer() + typeBuf.astFormat(n.TypeName, d) - buf.WriteString(buf.Cast(argBuf.String(), typeBuf.String())) + buf.WriteString(d.Cast(argBuf.String(), typeBuf.String())) } diff --git a/internal/sql/ast/type_name.go b/internal/sql/ast/type_name.go index 5979d7a90d..d8d91f4f87 100644 --- a/internal/sql/ast/type_name.go +++ b/internal/sql/ast/type_name.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TypeName struct { Catalog string Schema string @@ -20,7 +22,7 @@ func (n *TypeName) Pos() int { return n.Location } -func (n *TypeName) Format(buf *TrackedBuffer) { +func (n *TypeName) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -30,26 +32,26 @@ func (n *TypeName) Format(buf *TrackedBuffer) { first, _ := n.Names.Items[0].(*String) second, _ := n.Names.Items[1].(*String) if first != nil && second != nil { - buf.WriteString(buf.TypeName(first.Str, second.Str)) + buf.WriteString(d.TypeName(first.Str, second.Str)) goto addMods } } // For single name types, just output as-is if len(n.Names.Items) == 1 { if s, ok := n.Names.Items[0].(*String); ok { - buf.WriteString(buf.TypeName("", s.Str)) + buf.WriteString(d.TypeName("", s.Str)) goto addMods } } - buf.join(n.Names, ".") + buf.join(n.Names, d, ".") } else { - buf.WriteString(buf.TypeName(n.Schema, n.Name)) + buf.WriteString(d.TypeName(n.Schema, n.Name)) } addMods: // Add type modifiers (e.g., varchar(255)) if items(n.Typmods) { buf.WriteString("(") - buf.join(n.Typmods, ", ") + buf.join(n.Typmods, d, ", ") buf.WriteString(")") } if items(n.ArrayBounds) { diff --git a/internal/sql/ast/typedefs.go b/internal/sql/ast/typedefs.go index 46b0e66120..924fad767b 100644 --- a/internal/sql/ast/typedefs.go +++ b/internal/sql/ast/typedefs.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type AclMode uint32 func (n *AclMode) Pos() int { @@ -18,12 +20,12 @@ func (n *NullIfExpr) Pos() int { return 0 } -func (n *NullIfExpr) Format(buf *TrackedBuffer) { +func (n *NullIfExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("NULLIF(") - buf.join(n.Args, ", ") + buf.join(n.Args, d, ", ") buf.WriteString(")") } diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index c98d422130..5376a8c6ce 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -1,6 +1,10 @@ package ast -import "strings" +import ( + "strings" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type UpdateStmt struct { Relations *List @@ -16,18 +20,18 @@ func (n *UpdateStmt) Pos() int { return 0 } -func (n *UpdateStmt) Format(buf *TrackedBuffer) { +func (n *UpdateStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } buf.WriteString("UPDATE ") if items(n.Relations) { - buf.astFormat(n.Relations) + buf.astFormat(n.Relations, d) } if items(n.TargetList) { @@ -69,7 +73,7 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { buf.WriteString("(") buf.WriteString(strings.Join(names, ",")) buf.WriteString(") = (") - buf.join(vals, ",") + buf.join(vals, d, ",") buf.WriteString(")") } else { for i, item := range n.TargetList.Items { @@ -79,18 +83,18 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { switch nn := item.(type) { case *ResTarget: if nn.Name != nil { - buf.WriteString(buf.QuoteIdent(*nn.Name)) + buf.WriteString(d.QuoteIdent(*nn.Name)) } // Handle array subscript indirection (e.g., names[$1]) if items(nn.Indirection) { for _, ind := range nn.Indirection.Items { - buf.astFormat(ind) + buf.astFormat(ind, d) } } buf.WriteString(" = ") - buf.astFormat(nn.Val) + buf.astFormat(nn.Val, d) default: - buf.astFormat(item) + buf.astFormat(item, d) } } } @@ -98,21 +102,21 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { if items(n.FromClause) { buf.WriteString(" FROM ") - buf.astFormat(n.FromClause) + buf.astFormat(n.FromClause, d) } if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } if set(n.LimitCount) { buf.WriteString(" LIMIT ") - buf.astFormat(n.LimitCount) + buf.astFormat(n.LimitCount, d) } if items(n.ReturningList) { buf.WriteString(" RETURNING ") - buf.astFormat(n.ReturningList) + buf.astFormat(n.ReturningList, d) } } diff --git a/internal/sql/ast/variable_expr.go b/internal/sql/ast/variable_expr.go index 63afdf3d99..83223b482b 100644 --- a/internal/sql/ast/variable_expr.go +++ b/internal/sql/ast/variable_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + // VariableExpr represents a MySQL user variable (e.g., @user_id) // This is distinct from sqlc's @param named parameter syntax. type VariableExpr struct { @@ -11,7 +13,7 @@ func (n *VariableExpr) Pos() int { return n.Location } -func (n *VariableExpr) Format(buf *TrackedBuffer) { +func (n *VariableExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/window_def.go b/internal/sql/ast/window_def.go index 7e9db4aeef..caba3e643c 100644 --- a/internal/sql/ast/window_def.go +++ b/internal/sql/ast/window_def.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type WindowDef struct { Name *string Refname *string @@ -17,11 +19,11 @@ func (n *WindowDef) Pos() int { // Frame option constants (from PostgreSQL's parsenodes.h) const ( - FrameOptionNonDefault = 0x00001 - FrameOptionRange = 0x00002 - FrameOptionRows = 0x00004 - FrameOptionGroups = 0x00008 - FrameOptionBetween = 0x00010 + FrameOptionNonDefault = 0x00001 + FrameOptionRange = 0x00002 + FrameOptionRows = 0x00004 + FrameOptionGroups = 0x00008 + FrameOptionBetween = 0x00010 FrameOptionStartUnboundedPreceding = 0x00020 FrameOptionEndUnboundedPreceding = 0x00040 FrameOptionStartUnboundedFollowing = 0x00080 @@ -35,7 +37,7 @@ const ( FrameOptionExcludeTies = 0x08000 ) -func (n *WindowDef) Format(buf *TrackedBuffer) { +func (n *WindowDef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -51,7 +53,7 @@ func (n *WindowDef) Format(buf *TrackedBuffer) { if items(n.PartitionClause) { buf.WriteString("PARTITION BY ") - buf.join(n.PartitionClause, ", ") + buf.join(n.PartitionClause, d, ", ") needSpace = true } @@ -60,7 +62,7 @@ func (n *WindowDef) Format(buf *TrackedBuffer) { buf.WriteString(" ") } buf.WriteString("ORDER BY ") - buf.join(n.OrderClause, ", ") + buf.join(n.OrderClause, d, ", ") needSpace = true } @@ -89,7 +91,7 @@ func (n *WindowDef) Format(buf *TrackedBuffer) { } else if n.FrameOptions&FrameOptionStartCurrentRow != 0 { buf.WriteString("CURRENT ROW") } else if n.FrameOptions&FrameOptionStartOffset != 0 { - buf.astFormat(n.StartOffset) + buf.astFormat(n.StartOffset, d) buf.WriteString(" PRECEDING") } @@ -102,7 +104,7 @@ func (n *WindowDef) Format(buf *TrackedBuffer) { } else if n.FrameOptions&FrameOptionEndCurrentRow != 0 { buf.WriteString("CURRENT ROW") } else if n.FrameOptions&FrameOptionEndOffset != 0 { - buf.astFormat(n.EndOffset) + buf.astFormat(n.EndOffset, d) buf.WriteString(" FOLLOWING") } } diff --git a/internal/sql/ast/with_clause.go b/internal/sql/ast/with_clause.go index 86c53fb544..0def53d382 100644 --- a/internal/sql/ast/with_clause.go +++ b/internal/sql/ast/with_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type WithClause struct { Ctes *List Recursive bool @@ -10,7 +12,7 @@ func (n *WithClause) Pos() int { return n.Location } -func (n *WithClause) Format(buf *TrackedBuffer) { +func (n *WithClause) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -18,5 +20,5 @@ func (n *WithClause) Format(buf *TrackedBuffer) { if n.Recursive { buf.WriteString("RECURSIVE ") } - buf.join(n.Ctes, ", ") + buf.join(n.Ctes, d, ", ") } diff --git a/internal/sql/format/format.go b/internal/sql/format/format.go index 02140757f7..b900c227ed 100644 --- a/internal/sql/format/format.go +++ b/internal/sql/format/format.go @@ -1,7 +1,7 @@ package format -// Formatter provides SQL dialect-specific formatting behavior -type Formatter interface { +// Dialect provides SQL dialect-specific formatting behavior +type Dialect interface { // QuoteIdent returns a quoted identifier if it needs quoting // (e.g., reserved words, mixed case identifiers) QuoteIdent(s string) string diff --git a/internal/sql/rewrite/CLAUDE.md b/internal/sql/rewrite/CLAUDE.md index dd6459029f..6ea885016e 100644 --- a/internal/sql/rewrite/CLAUDE.md +++ b/internal/sql/rewrite/CLAUDE.md @@ -101,4 +101,4 @@ case *ast.YourType: - MySQL: `?`, `?`, `?`, ... - SQLite: `?`, `?`, `?`, ... -The format is determined by the `Formatter.Param()` method in each engine. +The format is determined by the `Dialect.Param()` method in each engine.