Skip to content
Merged
141 changes: 98 additions & 43 deletions internal/endtoend/fmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,66 +8,121 @@ import (
"strings"
"testing"

"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/debug"
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
)

func TestFormat(t *testing.T) {
t.Parallel()
parse := postgresql.NewParser()
for _, tc := range FindTests(t, "testdata", "base") {
tc := tc

if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) {
continue
}

q := filepath.Join(tc.Path, "query.sql")
if _, err := os.Stat(q); os.IsNotExist(err) {
continue
}

t.Run(tc.Name, func(t *testing.T) {
contents, err := os.ReadFile(q)
// Parse the config file to determine the engine
configPath := filepath.Join(tc.Path, tc.ConfigName)
configFile, err := os.Open(configPath)
if err != nil {
t.Fatal(err)
}
for i, query := range bytes.Split(bytes.TrimSpace(contents), []byte(";")) {
if len(query) <= 1 {
continue
}
query := query
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
expected, err := postgresql.Fingerprint(string(query))
if err != nil {
t.Fatal(err)
}
stmts, err := parse.Parse(bytes.NewReader(query))
if err != nil {
t.Fatal(err)
}
if len(stmts) != 1 {
t.Fatal("expected one statement")
}
if false {
r, err := postgresql.Parse(string(query))
debug.Dump(r, err)
}
conf, err := config.ParseConfig(configFile)
configFile.Close()
if err != nil {
t.Fatal(err)
}

// Skip if there are no SQL packages configured
if len(conf.SQL) == 0 {
return
}

// For now, only test PostgreSQL since that's the only engine with Format support
engine := conf.SQL[0].Engine
if engine != config.EnginePostgreSQL {
return
}

out := ast.Format(stmts[0].Raw)
actual, err := postgresql.Fingerprint(out)
// Find query files from config
var queryFiles []string
for _, sql := range conf.SQL {
for _, q := range sql.Queries {
queryPath := filepath.Join(tc.Path, q)
info, err := os.Stat(queryPath)
if err != nil {
t.Error(err)
continue
}
if expected != actual {
debug.Dump(stmts[0].Raw)
t.Errorf("- %s", expected)
t.Errorf("- %s", string(query))
t.Errorf("+ %s", actual)
t.Errorf("+ %s", out)
if info.IsDir() {
// If it's a directory, glob for .sql files
matches, err := filepath.Glob(filepath.Join(queryPath, "*.sql"))
if err != nil {
continue
}
queryFiles = append(queryFiles, matches...)
} else {
queryFiles = append(queryFiles, queryPath)
}
})
}
}

if len(queryFiles) == 0 {
return
}

parse := postgresql.NewParser()

for _, queryFile := range queryFiles {
if _, err := os.Stat(queryFile); os.IsNotExist(err) {
continue
}

contents, err := os.ReadFile(queryFile)
if err != nil {
t.Fatal(err)
}

// Parse the entire file to get proper statement boundaries
stmts, err := parse.Parse(bytes.NewReader(contents))
if err != nil {
// Skip files with parse errors (e.g., syntax_errors test cases)
return
}

for i, stmt := range stmts {
stmt := stmt
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
// Extract the original query text using statement location and length
start := stmt.Raw.StmtLocation
length := stmt.Raw.StmtLen
if length == 0 {
// If StmtLen is 0, it means the statement goes to the end of the input
length = len(contents) - start
}
query := strings.TrimSpace(string(contents[start : start+length]))

expected, err := postgresql.Fingerprint(query)
if err != nil {
t.Fatal(err)
}

if false {
r, err := postgresql.Parse(query)
debug.Dump(r, err)
}

out := ast.Format(stmt.Raw, parse)
actual, err := postgresql.Fingerprint(out)
if err != nil {
t.Error(err)
}
if expected != actual {
debug.Dump(stmt.Raw)
t.Errorf("- %s", expected)
t.Errorf("- %s", query)
t.Errorf("+ %s", actual)
t.Errorf("+ %s", out)
}
})
}
}
})
}
Expand Down
19 changes: 19 additions & 0 deletions internal/engine/postgresql/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,22 @@ func convertNullTest(n *pg.NullTest) *ast.NullTest {
}
}

func convertNullIfExpr(n *pg.NullIfExpr) *ast.NullIfExpr {
if n == nil {
return nil
}
return &ast.NullIfExpr{
Xpr: convertNode(n.Xpr),
Opno: ast.Oid(n.Opno),
Opresulttype: ast.Oid(n.Opresulttype),
Opretset: n.Opretset,
Opcollid: ast.Oid(n.Opcollid),
Inputcollid: ast.Oid(n.Inputcollid),
Args: convertSlice(n.Args),
Location: int(n.Location),
}
}

func convertObjectWithArgs(n *pg.ObjectWithArgs) *ast.ObjectWithArgs {
if n == nil {
return nil
Expand Down Expand Up @@ -3420,6 +3436,9 @@ func convertNode(node *pg.Node) ast.Node {
case *pg.Node_NullTest:
return convertNullTest(n.NullTest)

case *pg.Node_NullIfExpr:
return convertNullIfExpr(n.NullIfExpr)

case *pg.Node_ObjectWithArgs:
return convertObjectWithArgs(n.ObjectWithArgs)

Expand Down
1 change: 1 addition & 0 deletions internal/engine/postgresql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ func translate(node *nodes.Node) (ast.Node, error) {
ReturnType: rt,
Replace: n.Replace,
Params: &ast.List{},
Options: convertSlice(n.Options),
}
for _, item := range n.Parameters {
arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter
Expand Down
53 changes: 53 additions & 0 deletions internal/engine/postgresql/reserved.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,59 @@ package postgresql

import "strings"

// hasMixedCase returns true if the string has any uppercase letters
// (identifiers with mixed case need quoting in PostgreSQL)
func hasMixedCase(s string) bool {
for _, r := range s {
if r >= 'A' && r <= 'Z' {
return true
}
}
return false
}

// QuoteIdent returns a quoted identifier if it needs quoting.
// This implements the format.Formatter interface.
func (p *Parser) QuoteIdent(s string) string {
if p.IsReservedKeyword(s) || hasMixedCase(s) {
return `"` + s + `"`
}
return s
}

// TypeName returns the SQL type name for the given namespace and name.
// This implements the format.Formatter interface.
func (p *Parser) TypeName(ns, name string) string {
if ns == "pg_catalog" {
switch name {
case "int4":
return "integer"
case "int8":
return "bigint"
case "int2":
return "smallint"
case "float4":
return "real"
case "float8":
return "double precision"
case "bool":
return "boolean"
case "bpchar":
return "character"
case "timestamptz":
return "timestamp with time zone"
case "timetz":
return "time with time zone"
default:
return name
}
}
if ns != "" {
return ns + "." + name
}
return name
}

// https://www.postgresql.org/docs/current/sql-keywords-appendix.html
func (p *Parser) IsReservedKeyword(s string) bool {
switch strings.ToLower(s) {
Expand Down
9 changes: 9 additions & 0 deletions internal/sql/ast/a_array_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@ type A_ArrayExpr struct {
func (n *A_ArrayExpr) Pos() int {
return n.Location
}

func (n *A_ArrayExpr) Format(buf *TrackedBuffer) {
if n == nil {
return
}
buf.WriteString("ARRAY[")
buf.join(n.Elements, ", ")
buf.WriteString("]")
}
77 changes: 73 additions & 4 deletions internal/sql/ast/a_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,88 @@ func (n *A_Expr) Format(buf *TrackedBuffer) {
if n == nil {
return
}
buf.astFormat(n.Lexpr)
buf.WriteString(" ")
switch n.Kind {
case A_Expr_Kind_IN:
buf.astFormat(n.Lexpr)
buf.WriteString(" IN (")
buf.astFormat(n.Rexpr)
buf.WriteString(")")
case A_Expr_Kind_LIKE:
buf.astFormat(n.Lexpr)
buf.WriteString(" LIKE ")
buf.astFormat(n.Rexpr)
case A_Expr_Kind_ILIKE:
buf.astFormat(n.Lexpr)
buf.WriteString(" ILIKE ")
buf.astFormat(n.Rexpr)
case A_Expr_Kind_SIMILAR:
buf.astFormat(n.Lexpr)
buf.WriteString(" SIMILAR TO ")
buf.astFormat(n.Rexpr)
case A_Expr_Kind_BETWEEN:
buf.astFormat(n.Lexpr)
buf.WriteString(" BETWEEN ")
if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 {
buf.astFormat(l.Items[0])
buf.WriteString(" AND ")
buf.astFormat(l.Items[1])
}
case A_Expr_Kind_NOT_BETWEEN:
buf.astFormat(n.Lexpr)
buf.WriteString(" NOT BETWEEN ")
if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 {
buf.astFormat(l.Items[0])
buf.WriteString(" AND ")
buf.astFormat(l.Items[1])
}
case A_Expr_Kind_DISTINCT:
buf.astFormat(n.Lexpr)
buf.WriteString(" IS DISTINCT FROM ")
buf.astFormat(n.Rexpr)
case A_Expr_Kind_NOT_DISTINCT:
buf.astFormat(n.Lexpr)
buf.WriteString(" IS NOT DISTINCT FROM ")
buf.astFormat(n.Rexpr)
case A_Expr_Kind_NULLIF:
buf.WriteString("NULLIF(")
buf.astFormat(n.Lexpr)
buf.WriteString(", ")
buf.astFormat(n.Rexpr)
buf.WriteString(")")
case A_Expr_Kind_OP:
// Check if this is a named parameter (@name)
opName := ""
if n.Name != nil && len(n.Name.Items) == 1 {
if s, ok := n.Name.Items[0].(*String); ok {
opName = s.Str
}
}
if opName == "@" && !set(n.Lexpr) && set(n.Rexpr) {
// Named parameter: @name (no space after @)
buf.WriteString("@")
buf.astFormat(n.Rexpr)
} else {
// Standard binary operator
if set(n.Lexpr) {
buf.astFormat(n.Lexpr)
buf.WriteString(" ")
}
buf.astFormat(n.Name)
if set(n.Rexpr) {
buf.WriteString(" ")
buf.astFormat(n.Rexpr)
}
}
default:
// Fallback for other cases
if set(n.Lexpr) {
buf.astFormat(n.Lexpr)
buf.WriteString(" ")
}
buf.astFormat(n.Name)
buf.WriteString(" ")
buf.astFormat(n.Rexpr)
if set(n.Rexpr) {
buf.WriteString(" ")
buf.astFormat(n.Rexpr)
}
}
}
16 changes: 14 additions & 2 deletions internal/sql/ast/a_expr_kind.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,20 @@ package ast
type A_Expr_Kind uint

const (
A_Expr_Kind_IN A_Expr_Kind = 7
A_Expr_Kind_LIKE A_Expr_Kind = 8
A_Expr_Kind_OP A_Expr_Kind = 1
A_Expr_Kind_OP_ANY A_Expr_Kind = 2
A_Expr_Kind_OP_ALL A_Expr_Kind = 3
A_Expr_Kind_DISTINCT A_Expr_Kind = 4
A_Expr_Kind_NOT_DISTINCT A_Expr_Kind = 5
A_Expr_Kind_NULLIF A_Expr_Kind = 6
A_Expr_Kind_IN A_Expr_Kind = 7
A_Expr_Kind_LIKE A_Expr_Kind = 8
A_Expr_Kind_ILIKE A_Expr_Kind = 9
A_Expr_Kind_SIMILAR A_Expr_Kind = 10
A_Expr_Kind_BETWEEN A_Expr_Kind = 11
A_Expr_Kind_NOT_BETWEEN A_Expr_Kind = 12
A_Expr_Kind_BETWEEN_SYM A_Expr_Kind = 13
A_Expr_Kind_NOT_BETWEEN_SYM A_Expr_Kind = 14
)

func (n *A_Expr_Kind) Pos() int {
Expand Down
Loading
Loading