Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/endtoend/fmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type sqlParser interface {

// sqlFormatter is an interface for formatters
type sqlFormatter interface {
format.Formatter
format.Dialect
}

func TestFormat(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions internal/engine/postgresql/reserved.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func hasMixedCase(s string) bool {
}

// QuoteIdent returns a quoted identifier if it needs quoting.
// This implements the format.Formatter interface.
// This implements the format.Dialect interface.
func (p *Parser) QuoteIdent(s string) string {
if p.IsReservedKeyword(s) || hasMixedCase(s) {
return `"` + s + `"`
Expand All @@ -26,7 +26,7 @@ func (p *Parser) QuoteIdent(s string) string {
}

// TypeName returns the SQL type name for the given namespace and name.
// This implements the format.Formatter interface.
// This implements the format.Dialect interface.
func (p *Parser) TypeName(ns, name string) string {
if ns == "pg_catalog" {
switch name {
Expand Down
7 changes: 4 additions & 3 deletions internal/sql/ast/CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ The `TrackedBuffer` type (`pg_query.go`) handles SQL formatting with dialect-spe
- `QuoteIdent(name string)` - quotes identifiers (dialect-specific)
- `TypeName(ns, name string)` - formats type names (dialect-specific)

### Formatter Interface
Dialect-specific formatting is handled via the `Formatter` interface:
### Dialect Interface
Dialect-specific formatting is handled via the `Dialect` interface:
```go
type Formatter interface {
type Dialect interface {
QuoteIdent(string) string
TypeName(ns, name string) string
Param(int) string // $1 for PostgreSQL, ? for MySQL
NamedParam(string) string // @name for PostgreSQL, :name for SQLite
Cast(string) string
}
```
Expand Down
6 changes: 4 additions & 2 deletions internal/sql/ast/a_array_expr.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type A_ArrayExpr struct {
Elements *List
Location int
Expand All @@ -9,11 +11,11 @@ func (n *A_ArrayExpr) Pos() int {
return n.Location
}

func (n *A_ArrayExpr) Format(buf *TrackedBuffer) {
func (n *A_ArrayExpr) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
buf.WriteString("ARRAY[")
buf.join(n.Elements, ", ")
buf.join(n.Elements, d, ", ")
buf.WriteString("]")
}
8 changes: 5 additions & 3 deletions internal/sql/ast/a_const.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type A_Const struct {
Val Node
Location int
Expand All @@ -9,15 +11,15 @@ func (n *A_Const) Pos() int {
return n.Location
}

func (n *A_Const) Format(buf *TrackedBuffer) {
func (n *A_Const) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
if _, ok := n.Val.(*String); ok {
buf.WriteString("'")
buf.astFormat(n.Val)
buf.astFormat(n.Val, d)
buf.WriteString("'")
} else {
buf.astFormat(n.Val)
buf.astFormat(n.Val, d)
}
}
52 changes: 27 additions & 25 deletions internal/sql/ast/a_expr.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type A_Expr struct {
Kind A_Expr_Kind
Name *List
Expand Down Expand Up @@ -31,75 +33,75 @@ func (n *A_Expr) isNamedParam() (string, bool) {
return "", false
}

func (n *A_Expr) Format(buf *TrackedBuffer) {
func (n *A_Expr) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}

// Check for named parameter first (works regardless of Kind)
if name, ok := n.isNamedParam(); ok {
buf.WriteString(buf.NamedParam(name))
buf.WriteString(d.NamedParam(name))
return
}

switch n.Kind {
case A_Expr_Kind_IN:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" IN (")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
buf.WriteString(")")
case A_Expr_Kind_LIKE:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" LIKE ")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
case A_Expr_Kind_ILIKE:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" ILIKE ")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
case A_Expr_Kind_SIMILAR:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" SIMILAR TO ")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
case A_Expr_Kind_BETWEEN:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" BETWEEN ")
if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 {
buf.astFormat(l.Items[0])
buf.astFormat(l.Items[0], d)
buf.WriteString(" AND ")
buf.astFormat(l.Items[1])
buf.astFormat(l.Items[1], d)
}
case A_Expr_Kind_NOT_BETWEEN:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" NOT BETWEEN ")
if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 {
buf.astFormat(l.Items[0])
buf.astFormat(l.Items[0], d)
buf.WriteString(" AND ")
buf.astFormat(l.Items[1])
buf.astFormat(l.Items[1], d)
}
case A_Expr_Kind_DISTINCT:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" IS DISTINCT FROM ")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
case A_Expr_Kind_NOT_DISTINCT:
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" IS NOT DISTINCT FROM ")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
case A_Expr_Kind_NULLIF:
buf.WriteString("NULLIF(")
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(", ")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
buf.WriteString(")")
default:
// Standard operator (including A_Expr_Kind_OP)
if set(n.Lexpr) {
buf.astFormat(n.Lexpr)
buf.astFormat(n.Lexpr, d)
buf.WriteString(" ")
}
buf.astFormat(n.Name)
buf.astFormat(n.Name, d)
if set(n.Rexpr) {
buf.WriteString(" ")
buf.astFormat(n.Rexpr)
buf.astFormat(n.Rexpr, d)
}
}
}
10 changes: 6 additions & 4 deletions internal/sql/ast/a_indices.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type A_Indices struct {
IsSlice bool
Lidx Node
Expand All @@ -10,21 +12,21 @@ func (n *A_Indices) Pos() int {
return 0
}

func (n *A_Indices) Format(buf *TrackedBuffer) {
func (n *A_Indices) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
buf.WriteString("[")
if n.IsSlice {
if set(n.Lidx) {
buf.astFormat(n.Lidx)
buf.astFormat(n.Lidx, d)
}
buf.WriteString(":")
if set(n.Uidx) {
buf.astFormat(n.Uidx)
buf.astFormat(n.Uidx, d)
}
} else {
buf.astFormat(n.Uidx)
buf.astFormat(n.Uidx, d)
}
buf.WriteString("]")
}
4 changes: 3 additions & 1 deletion internal/sql/ast/a_star.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type A_Star struct {
}

func (n *A_Star) Pos() int {
return 0
}

func (n *A_Star) Format(buf *TrackedBuffer) {
func (n *A_Star) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
Expand Down
6 changes: 4 additions & 2 deletions internal/sql/ast/alias.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type Alias struct {
Aliasname *string
Colnames *List
Expand All @@ -9,7 +11,7 @@ func (n *Alias) Pos() int {
return 0
}

func (n *Alias) Format(buf *TrackedBuffer) {
func (n *Alias) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
Expand All @@ -18,7 +20,7 @@ func (n *Alias) Format(buf *TrackedBuffer) {
}
if items(n.Colnames) {
buf.WriteString("(")
buf.astFormat((n.Colnames))
buf.astFormat(n.Colnames, d)
buf.WriteString(")")
}
}
6 changes: 4 additions & 2 deletions internal/sql/ast/alter_table_cmd.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

const (
AT_AddColumn AlterTableType = iota
AT_AlterColumnType
Expand Down Expand Up @@ -40,7 +42,7 @@ func (n *AlterTableCmd) Pos() int {
return 0
}

func (n *AlterTableCmd) Format(buf *TrackedBuffer) {
func (n *AlterTableCmd) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
Expand All @@ -51,5 +53,5 @@ func (n *AlterTableCmd) Format(buf *TrackedBuffer) {
buf.WriteString(" DROP COLUMN ")
}

buf.astFormat(n.Def)
buf.astFormat(n.Def, d)
}
10 changes: 6 additions & 4 deletions internal/sql/ast/alter_table_stmt.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type AlterTableStmt struct {
// TODO: Only TableName or Relation should be defined
Relation *RangeVar
Expand All @@ -13,12 +15,12 @@ func (n *AlterTableStmt) Pos() int {
return 0
}

func (n *AlterTableStmt) Format(buf *TrackedBuffer) {
func (n *AlterTableStmt) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
buf.WriteString("ALTER TABLE ")
buf.astFormat(n.Relation)
buf.astFormat(n.Table)
buf.astFormat(n.Cmds)
buf.astFormat(n.Relation, d)
buf.astFormat(n.Table, d)
buf.astFormat(n.Cmds, d)
}
10 changes: 6 additions & 4 deletions internal/sql/ast/between_expr.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/sqlc-dev/sqlc/internal/sql/format"

type BetweenExpr struct {
// Expr is the value expression to be compared.
Expr Node
Expand All @@ -16,17 +18,17 @@ func (n *BetweenExpr) Pos() int {
return n.Location
}

func (n *BetweenExpr) Format(buf *TrackedBuffer) {
func (n *BetweenExpr) Format(buf *TrackedBuffer, d format.Dialect) {
if n == nil {
return
}
buf.astFormat(n.Expr)
buf.astFormat(n.Expr, d)
if n.Not {
buf.WriteString(" NOT BETWEEN ")
} else {
buf.WriteString(" BETWEEN ")
}
buf.astFormat(n.Left)
buf.astFormat(n.Left, d)
buf.WriteString(" AND ")
buf.astFormat(n.Right)
buf.astFormat(n.Right, d)
}
Loading
Loading