From e244ef702a8982bd24a3cd352bbc6c9cb1fa226f Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 15:28:46 -0800 Subject: [PATCH] feat(mysql): improve AST formatting and add DELETE JOIN support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR continues the MySQL AST formatting work with several improvements: **New AST Nodes:** - `VariableExpr` - MySQL user variables (@var), distinct from sqlc @param - `IntervalExpr` - MySQL INTERVAL expressions - `OnDuplicateKeyUpdate` - MySQL ON DUPLICATE KEY UPDATE clause - `ParenExpr` - Explicit parentheses for expression grouping **DELETE with JOIN Support:** - Extended DeleteStmt with Targets and FromClause fields - Multi-table DELETE now properly formats: DELETE t1.*, t2.* FROM t1 JOIN t2... - Updated compiler/output_columns.go to handle new structure **Bug Fixes:** - MySQL @variable now preserved as-is (not treated as sqlc named parameter) - Column type lengths only output for types where meaningful (varchar, char) - Fixed sqlc.arg() handling in ON DUPLICATE KEY UPDATE clause **Documentation:** - Added CLAUDE.md files documenting AST, astutils, named, rewrite packages - Added CLAUDE.md for dolphin engine with conversion patterns 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/codegen/golang/mysql_type.go | 6 +- internal/compiler/output_columns.go | 9 +- internal/endtoend/fmt_test.go | 54 ++++- internal/engine/dolphin/CLAUDE.md | 224 ++++++++++++++++++ internal/engine/dolphin/convert.go | 240 +++++++++++++++++--- internal/engine/dolphin/format.go | 36 +++ internal/engine/dolphin/stdlib.go | 26 +++ internal/engine/postgresql/reserved.go | 17 +- internal/sql/ast/CLAUDE.md | 115 ++++++++++ internal/sql/ast/between_expr.go | 15 ++ internal/sql/ast/bool_expr.go | 35 ++- internal/sql/ast/delete_stmt.go | 22 +- internal/sql/ast/func_call.go | 8 + internal/sql/ast/in.go | 27 +++ internal/sql/ast/insert_stmt.go | 20 +- internal/sql/ast/interval_expr.go | 22 ++ internal/sql/ast/on_duplicate_key_update.go | 35 +++ internal/sql/ast/param_ref.go | 4 +- internal/sql/ast/paren_expr.go | 20 ++ internal/sql/ast/print.go | 19 ++ internal/sql/ast/range_var.go | 4 +- internal/sql/ast/sub_link.go | 20 +- internal/sql/ast/type_cast.go | 11 +- internal/sql/ast/variable_expr.go | 20 ++ internal/sql/astutils/CLAUDE.md | 117 ++++++++++ internal/sql/astutils/rewrite.go | 15 ++ internal/sql/astutils/walk.go | 27 +++ internal/sql/format/format.go | 8 + internal/sql/named/CLAUDE.md | 94 ++++++++ internal/sql/rewrite/CLAUDE.md | 104 +++++++++ 30 files changed, 1304 insertions(+), 70 deletions(-) create mode 100644 internal/engine/dolphin/CLAUDE.md create mode 100644 internal/engine/dolphin/format.go create mode 100644 internal/sql/ast/CLAUDE.md create mode 100644 internal/sql/ast/interval_expr.go create mode 100644 internal/sql/ast/on_duplicate_key_update.go create mode 100644 internal/sql/ast/paren_expr.go create mode 100644 internal/sql/ast/variable_expr.go create mode 100644 internal/sql/astutils/CLAUDE.md create mode 100644 internal/sql/named/CLAUDE.md create mode 100644 internal/sql/rewrite/CLAUDE.md diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index b8e8aa43c7..252e291f58 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -64,7 +64,11 @@ func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.C } return "sql.NullInt32" - case "bigint": + case "bigint", "bigint unsigned", "bigint signed": + // "bigint unsigned" and "bigint signed" are MySQL CAST types + // Note: We use int64 for CAST AS UNSIGNED to match original behavior, + // even though uint64 would be more semantically correct. + // The Unsigned flag on columns (from table schema) still uses uint64. if notNull { if unsigned { return "uint64" diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index b0a15e6ac4..dbd486359a 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -482,7 +482,14 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro list := &ast.List{} switch n := node.(type) { case *ast.DeleteStmt: - list = n.Relations + if n.Relations != nil { + list = n.Relations + } else if n.FromClause != nil { + // Multi-table DELETE: walk FromClause to find tables + var tv tableVisitor + astutils.Walk(&tv, n.FromClause) + list = &tv.list + } case *ast.InsertStmt: list = &ast.List{ Items: []ast.Node{n.Relation}, diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 35b475ca4f..db4aaee747 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "fmt" + "io" "os" "path/filepath" "strings" @@ -10,10 +11,22 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/engine/dolphin" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/format" ) +// sqlParser is an interface for SQL parsers +type sqlParser interface { + Parse(r io.Reader) ([]ast.Statement, error) +} + +// sqlFormatter is an interface for formatters +type sqlFormatter interface { + format.Formatter +} + func TestFormat(t *testing.T) { t.Parallel() for _, tc := range FindTests(t, "testdata", "base") { @@ -36,9 +49,38 @@ func TestFormat(t *testing.T) { return } - // For now, only test PostgreSQL since that's the only engine with Format support engine := conf.SQL[0].Engine - if engine != config.EnginePostgreSQL { + + // Select the appropriate parser and fingerprint function based on engine + var parse sqlParser + var formatter sqlFormatter + var fingerprint func(string) (string, error) + + switch engine { + case config.EnginePostgreSQL: + pgParser := postgresql.NewParser() + parse = pgParser + formatter = pgParser + fingerprint = postgresql.Fingerprint + case config.EngineMySQL: + mysqlParser := dolphin.NewParser() + parse = mysqlParser + formatter = mysqlParser + // For MySQL, we use a "round-trip" fingerprint: parse the SQL, format it, + // and return the formatted string. This tests that our formatting produces + // valid SQL that parses to the same AST structure. + fingerprint = func(sql string) (string, error) { + stmts, err := mysqlParser.Parse(strings.NewReader(sql)) + if err != nil { + return "", err + } + if len(stmts) == 0 { + return "", nil + } + return ast.Format(stmts[0].Raw, mysqlParser), nil + } + default: + // Skip unsupported engines return } @@ -68,8 +110,6 @@ func TestFormat(t *testing.T) { return } - parse := postgresql.NewParser() - for _, queryFile := range queryFiles { if _, err := os.Stat(queryFile); os.IsNotExist(err) { continue @@ -99,7 +139,7 @@ func TestFormat(t *testing.T) { } query := strings.TrimSpace(string(contents[start : start+length])) - expected, err := postgresql.Fingerprint(query) + expected, err := fingerprint(query) if err != nil { t.Fatal(err) } @@ -109,8 +149,8 @@ func TestFormat(t *testing.T) { debug.Dump(r, err) } - out := ast.Format(stmt.Raw, parse) - actual, err := postgresql.Fingerprint(out) + out := ast.Format(stmt.Raw, formatter) + actual, err := fingerprint(out) if err != nil { t.Error(err) } diff --git a/internal/engine/dolphin/CLAUDE.md b/internal/engine/dolphin/CLAUDE.md new file mode 100644 index 0000000000..20142fafaa --- /dev/null +++ b/internal/engine/dolphin/CLAUDE.md @@ -0,0 +1,224 @@ +# Dolphin Engine (MySQL) - Claude Code Guide + +The dolphin engine handles MySQL parsing and AST conversion using the TiDB parser. + +## Architecture + +### Parser Flow +``` +SQL String → TiDB Parser → TiDB AST → sqlc AST → Analysis/Codegen +``` + +### Key Files +- `convert.go` - Converts TiDB AST nodes to sqlc AST nodes +- `format.go` - MySQL-specific formatting (identifiers, types, parameters) +- `parse.go` - Entry point for parsing MySQL SQL + +## TiDB Parser + +The TiDB parser (`github.com/pingcap/tidb/pkg/parser`) is used for MySQL parsing: + +```go +import ( + pcast "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" +) +``` + +### Common TiDB Types +- `pcast.SelectStmt`, `pcast.InsertStmt`, etc. - Statement types +- `pcast.ColumnNameExpr` - Column reference +- `pcast.FuncCallExpr` - Function call +- `pcast.BinaryOperationExpr` - Binary expression +- `pcast.VariableExpr` - MySQL user variable (@var) +- `pcast.Join` - JOIN clause with Left, Right, On, Using + +## Conversion Pattern + +Each TiDB node type has a corresponding converter method: + +```go +func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt { + return &ast.SelectStmt{ + FromClause: c.convertTableRefsClause(n.From), + WhereClause: c.convert(n.Where), + // ... + } +} +``` + +The main `convert()` method dispatches to specific converters: +```go +func (c *cc) convert(node pcast.Node) ast.Node { + switch n := node.(type) { + case *pcast.SelectStmt: + return c.convertSelectStmt(n) + case *pcast.InsertStmt: + return c.convertInsertStmt(n) + // ... + } +} +``` + +## Key Conversions + +### Column References +```go +func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef { + var items []ast.Node + if schema := n.Name.Schema.String(); schema != "" { + items = append(items, NewIdentifier(schema)) + } + if table := n.Name.Table.String(); table != "" { + items = append(items, NewIdentifier(table)) + } + items = append(items, NewIdentifier(n.Name.Name.String())) + return &ast.ColumnRef{Fields: &ast.List{Items: items}} +} +``` + +### JOINs +```go +func (c *cc) convertJoin(n *pcast.Join) *ast.List { + if n.Right != nil && n.Left != nil { + return &ast.List{ + Items: []ast.Node{&ast.JoinExpr{ + Jointype: ast.JoinType(n.Tp), + Larg: c.convert(n.Left), + Rarg: c.convert(n.Right), + Quals: c.convert(n.On), + UsingClause: convertUsing(n.Using), + }}, + } + } + // No join - just return tables + // ... +} +``` + +### MySQL User Variables +MySQL user variables (`@var`) are different from sqlc's `@param` syntax: +```go +func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { + // Use VariableExpr to preserve as-is (NOT A_Expr which would be treated as sqlc param) + return &ast.VariableExpr{ + Name: n.Name, + Location: n.OriginTextPosition(), + } +} +``` + +### Type Casts (CAST AS) +```go +func (c *cc) convertFuncCastExpr(n *pcast.FuncCastExpr) ast.Node { + typeName := types.TypeStr(n.Tp.GetType()) + // Handle UNSIGNED/SIGNED specially + if typeName == "bigint" { + if mysql.HasUnsignedFlag(n.Tp.GetFlag()) { + typeName = "bigint unsigned" + } else { + typeName = "bigint signed" + } + } + return &ast.TypeCast{ + Arg: c.convert(n.Expr), + TypeName: &ast.TypeName{Name: typeName}, + } +} +``` + +### Column Definitions +```go +func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef { + typeName := &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())} + + // Only add Typmods for types where length is meaningful + tp := def.Tp.GetType() + flen := def.Tp.GetFlen() + switch tp { + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + if flen >= 0 { + typeName.Typmods = &ast.List{ + Items: []ast.Node{&ast.Integer{Ival: int64(flen)}}, + } + } + // Don't add for DATETIME, TIMESTAMP - internal flen is not user-specified + } + // ... +} +``` + +### Multi-Table DELETE +MySQL supports `DELETE t1, t2 FROM t1 JOIN t2 ...`: +```go +func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt { + if n.IsMultiTable && n.Tables != nil { + // Convert targets (t1.*, t2.*) + targets := &ast.List{} + for _, table := range n.Tables.Tables { + // Build ColumnRef for each target + } + stmt.Targets = targets + + // Preserve JOINs in FromClause + stmt.FromClause = c.convertTableRefsClause(n.TableRefs).Items[0] + } else { + // Single-table DELETE + stmt.Relations = c.convertTableRefsClause(n.TableRefs) + } +} +``` + +## MySQL-Specific Formatting + +### format.go +```go +func (p *Parser) TypeName(ns, name string) string { + switch name { + case "bigint unsigned": + return "UNSIGNED" + case "bigint signed": + return "SIGNED" + } + return name +} + +func (p *Parser) Param(n int) string { + return "?" // MySQL uses ? for all parameters +} +``` + +## Common Issues and Solutions + +### Issue: Panic in Walk/Apply +**Cause**: New AST node type not handled in `astutils/walk.go` or `astutils/rewrite.go` +**Solution**: Add case for the node type in both files + +### Issue: sqlc.arg() not converted in ON DUPLICATE KEY UPDATE +**Cause**: `InsertStmt` case in `rewrite.go` didn't traverse `OnDuplicateKeyUpdate` +**Solution**: Add `a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate)` + +### Issue: MySQL @variable being treated as parameter +**Cause**: Converting `VariableExpr` to `A_Expr` with `@` operator +**Solution**: Use `ast.VariableExpr` instead, which is not detected by `named.IsParamSign()` + +### Issue: Type length appearing incorrectly (e.g., datetime(39)) +**Cause**: Using internal `flen` for all types +**Solution**: Only populate `Typmods` for types where length is user-specified (varchar, char, etc.) + +## Testing + +### TestFormat +Tests that SQL can be: +1. Parsed +2. Formatted back to SQL +3. Re-parsed +4. Re-formatted to match + +### TestReplay +Tests the full sqlc pipeline: +1. Parse schema and queries +2. Analyze +3. Generate code +4. Compare with expected output diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 33b89ae8f4..1f68358ce4 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -2,6 +2,7 @@ package dolphin import ( "log" + "strconv" "strings" pcast "github.com/pingcap/tidb/pkg/parser/ast" @@ -187,8 +188,14 @@ func opToName(o opcode.Op) string { func (c *cc) convertBinaryOperationExpr(n *pcast.BinaryOperationExpr) ast.Node { if n.Op == opcode.LogicAnd || n.Op == opcode.LogicOr { + var boolop ast.BoolExprType + if n.Op == opcode.LogicAnd { + boolop = ast.BoolExprTypeAnd + } else { + boolop = ast.BoolExprTypeOr + } return &ast.BoolExpr{ - // TODO: Set op + Boolop: boolop, Args: &ast.List{ Items: []ast.Node{ c.convert(n.L), @@ -249,9 +256,36 @@ func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef { } } } + + // Build TypeName with modifiers for proper formatting + typeName := &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())} + + // Add type modifiers (e.g., length for varchar(255), char(32)) + // Only for types where length is meaningful and user-specified + tp := def.Tp.GetType() + flen := def.Tp.GetFlen() + needsLength := false + switch tp { + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + // VARCHAR(n), CHAR(n) - always need length + needsLength = flen >= 0 + case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + // BLOB types - only if user specified length (VARBINARY(n), BINARY(n)) + // Default blob types don't need length + needsLength = false + } + + if needsLength { + typeName.Typmods = &ast.List{ + Items: []ast.Node{ + &ast.Integer{Ival: int64(flen)}, + }, + } + } + columnDef := ast.ColumnDef{ Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, + TypeName: typeName, IsNotNull: isNotNull(def), IsUnsigned: isUnsigned(def), Comment: comment, @@ -294,22 +328,54 @@ func (c *cc) convertColumnNames(cols []*pcast.ColumnName) *ast.List { } func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt { - rels := c.convertTableRefsClause(n.TableRefs) - if len(rels.Items) != 1 { - panic("expected one range var") - } - relations := &ast.List{} - convertToRangeVarList(rels, relations) - stmt := &ast.DeleteStmt{ - Relations: relations, WhereClause: c.convert(n.Where), ReturningList: &ast.List{}, WithClause: c.convertWithClause(n.With), } + if n.Limit != nil { stmt.LimitCount = c.convert(n.Limit.Count) } + + // Handle multi-table DELETE (DELETE t1, t2 FROM t1 JOIN t2 ...) + if n.IsMultiTable && n.Tables != nil && len(n.Tables.Tables) > 0 { + // Convert delete targets (e.g., jt.*, pt.*) + targets := &ast.List{} + for _, table := range n.Tables.Tables { + // Each table in the delete list is a ColumnRef like "jt.*" or "pt.*" + items := []ast.Node{} + if table.Schema.String() != "" { + items = append(items, NewIdentifier(table.Schema.String())) + } + items = append(items, NewIdentifier(table.Name.String())) + items = append(items, &ast.A_Star{}) + targets.Items = append(targets.Items, &ast.ColumnRef{ + Fields: &ast.List{Items: items}, + }) + } + stmt.Targets = targets + + // Convert FROM clause preserving JOINs + if n.TableRefs != nil { + fromList := c.convertTableRefsClause(n.TableRefs) + if len(fromList.Items) == 1 { + stmt.FromClause = fromList.Items[0] + } else { + stmt.FromClause = fromList + } + } + } else { + // Single-table DELETE + rels := c.convertTableRefsClause(n.TableRefs) + if len(rels.Items) != 1 { + panic("expected one range var") + } + relations := &ast.List{} + convertToRangeVarList(rels, relations) + stmt.Relations = relations + } + return stmt } @@ -333,9 +399,11 @@ func (c *cc) convertRenameTableStmt(n *pcast.RenameTableStmt) ast.Node { } func (c *cc) convertExistsSubqueryExpr(n *pcast.ExistsSubqueryExpr) *ast.SubLink { - sublink := &ast.SubLink{} - if ss, ok := c.convert(n.Sel).(*ast.SelectStmt); ok { - sublink.Subselect = ss + sublink := &ast.SubLink{ + SubLinkType: ast.EXISTS_SUBLINK, + } + if n.Sel != nil { + sublink.Subselect = c.convert(n.Sel) } return sublink } @@ -359,6 +427,33 @@ func (c *cc) convertFuncCallExpr(n *pcast.FuncCallExpr) ast.Node { } items = append(items, NewIdentifier(name)) + // Handle DATE_ADD/DATE_SUB specially to construct INTERVAL expressions + // These functions have args: [date, interval_value, TimeUnitExpr] + if (name == "date_add" || name == "date_sub") && len(n.Args) == 3 { + if timeUnit, ok := n.Args[2].(*pcast.TimeUnitExpr); ok { + args := &ast.List{ + Items: []ast.Node{ + c.convert(n.Args[0]), + &ast.IntervalExpr{ + Value: c.convert(n.Args[1]), + Unit: timeUnit.Unit.String(), + }, + }, + } + return &ast.FuncCall{ + Args: args, + Func: &ast.FuncName{ + Schema: schema, + Name: name, + }, + Funcname: &ast.List{ + Items: items, + }, + Location: n.OriginTextPosition(), + } + } + } + args := &ast.List{} for _, arg := range n.Args { args.Items = append(args.Items, c.convert(arg)) @@ -415,7 +510,7 @@ func (c *cc) convertInsertStmt(n *pcast.InsertStmt) *ast.InsertStmt { for _, a := range n.OnDuplicate { targetList.Items = append(targetList.Items, c.convertAssignment(a)) } - insert.OnConflictClause = &ast.OnConflictClause{ + insert.OnDuplicateKeyUpdate = &ast.OnDuplicateKeyUpdate{ TargetList: targetList, Location: n.OriginTextPosition(), } @@ -492,7 +587,11 @@ func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt { } func (c *cc) convertSubqueryExpr(n *pcast.SubqueryExpr) ast.Node { - return c.convert(n.Query) + // Wrap subquery in SubLink to ensure parentheses are added + return &ast.SubLink{ + SubLinkType: ast.EXPR_SUBLINK, + Subselect: c.convert(n.Query), + } } func (c *cc) convertTableRefsClause(n *pcast.TableRefsClause) *ast.List { @@ -514,9 +613,17 @@ func (c *cc) convertCommonTableExpression(n *pcast.CommonTableExpression) *ast.C columns.Items = append(columns.Items, NewIdentifier(col.String())) } + // CTE Query is wrapped in SubqueryExpr by TiDB parser. + // We need to unwrap it to get the SelectStmt directly, + // otherwise it would be double-wrapped with parentheses. + var cteQuery ast.Node + if n.Query != nil { + cteQuery = c.convert(n.Query.Query) + } + return &ast.CommonTableExpr{ Ctename: &name, - Ctequery: c.convert(n.Query), + Ctequery: cteQuery, Ctecolnames: columns, } } @@ -596,7 +703,7 @@ func (c *cc) convertValueExpr(n *driver.ValueExpr) *ast.A_Const { mysql.TypeNewDecimal: return &ast.A_Const{ Val: &ast.Float{ - // TODO: Extract the value from n.TexprNode + Str: strconv.FormatFloat(n.Datum.GetFloat64(), 'f', -1, 64), }, Location: n.OriginTextPosition(), } @@ -643,7 +750,21 @@ func (c *cc) convertAggregateFuncExpr(n *pcast.AggregateFuncExpr) *ast.FuncCall Args: &ast.List{}, AggOrder: &ast.List{}, } - for _, a := range n.Args { + + // GROUP_CONCAT has special handling: + // TiDB always adds the separator as the last argument + // We need to extract it and use SEPARATOR syntax + args := n.Args + var separator string + if name == "group_concat" && len(args) >= 2 { + // The last arg is always the separator + if value, ok := args[len(args)-1].(*driver.ValueExpr); ok { + separator = value.GetString() + args = args[:len(args)-1] + } + } + + for _, a := range args { if value, ok := a.(*driver.ValueExpr); ok { if value.GetInt64() == int64(1) { fn.AggStar = true @@ -655,6 +776,12 @@ func (c *cc) convertAggregateFuncExpr(n *pcast.AggregateFuncExpr) *ast.FuncCall if n.Distinct { fn.AggDistinct = true } + + // Store separator for GROUP_CONCAT (only if non-default) + if name == "group_concat" && separator != "" && separator != "," { + fn.Separator = &separator + } + return fn } @@ -871,9 +998,21 @@ func (c *cc) convertFrameClause(n *pcast.FrameClause) ast.Node { } func (c *cc) convertFuncCastExpr(n *pcast.FuncCastExpr) ast.Node { + typeName := types.TypeStr(n.Tp.GetType()) + + // MySQL CAST AS UNSIGNED/SIGNED uses bigint internally. + // We need to preserve the signed/unsigned info for formatting. + if typeName == "bigint" { + if mysql.HasUnsignedFlag(n.Tp.GetFlag()) { + typeName = "bigint unsigned" + } else { + typeName = "bigint signed" + } + } + return &ast.TypeCast{ Arg: c.convert(n.Expr), - TypeName: &ast.TypeName{Name: types.TypeStr(n.Tp.GetType())}, + TypeName: &ast.TypeName{Name: typeName}, } } @@ -949,12 +1088,24 @@ func (c *cc) convertJoin(n *pcast.Join) *ast.List { joinType++ } + // Convert USING clause + var usingClause *ast.List + if len(n.Using) > 0 { + items := make([]ast.Node, len(n.Using)) + for i, col := range n.Using { + items[i] = &ast.String{Str: col.Name.O} + } + usingClause = &ast.List{Items: items} + } + return &ast.List{ Items: []ast.Node{&ast.JoinExpr{ - Jointype: joinType, - Larg: c.convert(n.Left), - Rarg: c.convert(n.Right), - Quals: c.convert(n.On), + Jointype: joinType, + IsNatural: n.NaturalJoin, + Larg: c.convert(n.Left), + Rarg: c.convert(n.Right), + UsingClause: usingClause, + Quals: c.convert(n.On), }}, } } @@ -1049,7 +1200,16 @@ func (c *cc) convertParenthesesExpr(n *pcast.ParenthesesExpr) ast.Node { if n == nil { return nil } - return c.convert(n.Expr) + inner := c.convert(n.Expr) + // Only wrap in ParenExpr for SELECT statements (needed for UNION with parenthesized subqueries) + // For other expressions, the BoolExpr already adds parentheses + if _, ok := inner.(*ast.SelectStmt); ok { + return &ast.ParenExpr{ + Expr: inner, + Location: n.OriginTextPosition(), + } + } + return inner } func (c *cc) convertPartitionByClause(n *pcast.PartitionByClause) ast.Node { @@ -1100,7 +1260,7 @@ func (c *cc) convertPatternRegexpExpr(n *pcast.PatternRegexpExpr) ast.Node { } func (c *cc) convertPositionExpr(n *pcast.PositionExpr) ast.Node { - return todo(n) + return &ast.Integer{Ival: int64(n.N)} } func (c *cc) convertPrepareStmt(n *pcast.PrepareStmt) ast.Node { @@ -1205,7 +1365,28 @@ func (c *cc) convertSetOprSelectList(n *pcast.SetOprSelectList) ast.Node { case *pcast.SelectStmt: selectStmts[i] = c.convertSelectStmt(node) case *pcast.SetOprSelectList: - selectStmts[i] = c.convertSetOprSelectList(node).(*ast.SelectStmt) + // If this is a single-select SetOprSelectList (e.g., from parenthesized SELECT), + // extract the inner select instead of building a UNION tree + if len(node.Selects) == 1 { + if innerSelect, ok := node.Selects[0].(*pcast.SelectStmt); ok { + selectStmts[i] = c.convertSelectStmt(innerSelect) + } else { + selectStmts[i] = c.convertSetOprSelectList(node).(*ast.SelectStmt) + } + } else { + selectStmts[i] = c.convertSetOprSelectList(node).(*ast.SelectStmt) + } + default: + // Handle other node types like ParenthesesExpr wrapping a SELECT + converted := c.convert(node) + if ss, ok := converted.(*ast.SelectStmt); ok { + selectStmts[i] = ss + } else if pe, ok := converted.(*ast.ParenExpr); ok { + // Unwrap ParenExpr to get the inner SelectStmt + if inner, ok := pe.Expr.(*ast.SelectStmt); ok { + selectStmts[i] = inner + } + } } } @@ -1396,7 +1577,12 @@ func (c *cc) convertVariableAssignment(n *pcast.VariableAssignment) ast.Node { } func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { - return todo(n) + // MySQL @variable references are user-defined variables, NOT sqlc named parameters. + // Use VariableExpr to preserve them as-is in the output. + return &ast.VariableExpr{ + Name: n.Name, + Location: n.OriginTextPosition(), + } } func (c *cc) convertWhenClause(n *pcast.WhenClause) ast.Node { diff --git a/internal/engine/dolphin/format.go b/internal/engine/dolphin/format.go new file mode 100644 index 0000000000..458ae02363 --- /dev/null +++ b/internal/engine/dolphin/format.go @@ -0,0 +1,36 @@ +package dolphin + +// QuoteIdent returns a quoted identifier if it needs quoting. +// MySQL uses backticks for quoting identifiers. +func (p *Parser) QuoteIdent(s string) string { + // For now, don't quote - MySQL is less strict about quoting + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +// Handles MySQL-specific type name mappings for formatting. +func (p *Parser) TypeName(ns, name string) string { + if ns != "" { + return ns + "." + name + } + // Map internal type names to MySQL CAST-compatible names for formatting + switch name { + case "bigint unsigned": + return "UNSIGNED" + case "bigint signed": + return "SIGNED" + } + return name +} + +// Param returns the parameter placeholder for the given number. +// MySQL uses ? for all parameters (positional). +func (p *Parser) Param(n int) string { + return "?" +} + +// Cast returns a type cast expression. +// MySQL uses CAST(expr AS type) syntax. +func (p *Parser) Cast(arg, typeName string) string { + return "CAST(" + arg + " AS " + typeName + ")" +} diff --git a/internal/engine/dolphin/stdlib.go b/internal/engine/dolphin/stdlib.go index 41469ca49d..46ce500eb5 100644 --- a/internal/engine/dolphin/stdlib.go +++ b/internal/engine/dolphin/stdlib.go @@ -636,6 +636,19 @@ func defaultSchema(name string) *catalog.Schema { }, ReturnType: &ast.TypeName{Name: "date"}, }, + { + // DATE_ADD with INTERVAL expression (2 args) + Name: "DATE_ADD", + Args: []*catalog.Argument{ + { + Type: &ast.TypeName{Name: "date"}, + }, + { + Type: &ast.TypeName{Name: "interval"}, + }, + }, + ReturnType: &ast.TypeName{Name: "date"}, + }, { Name: "DATE_ADD_INTERVAL", Args: []*catalog.Argument{ @@ -675,6 +688,19 @@ func defaultSchema(name string) *catalog.Schema { }, ReturnType: &ast.TypeName{Name: "date"}, }, + { + // DATE_SUB with INTERVAL expression (2 args) + Name: "DATE_SUB", + Args: []*catalog.Argument{ + { + Type: &ast.TypeName{Name: "date"}, + }, + { + Type: &ast.TypeName{Name: "interval"}, + }, + }, + ReturnType: &ast.TypeName{Name: "date"}, + }, { Name: "DATE_SUB_INTERVAL", Args: []*catalog.Argument{ diff --git a/internal/engine/postgresql/reserved.go b/internal/engine/postgresql/reserved.go index 0be5c54b8d..9254bfdb82 100644 --- a/internal/engine/postgresql/reserved.go +++ b/internal/engine/postgresql/reserved.go @@ -1,6 +1,9 @@ package postgresql -import "strings" +import ( + "fmt" + "strings" +) // hasMixedCase returns true if the string has any uppercase letters // (identifiers with mixed case need quoting in PostgreSQL) @@ -55,6 +58,18 @@ func (p *Parser) TypeName(ns, name string) string { return name } +// Param returns the parameter placeholder for the given number. +// PostgreSQL uses $1, $2, etc. +func (p *Parser) Param(n int) string { + return fmt.Sprintf("$%d", n) +} + +// Cast returns a type cast expression. +// PostgreSQL uses expr::type syntax. +func (p *Parser) Cast(arg, typeName string) string { + return arg + "::" + typeName +} + // https://www.postgresql.org/docs/current/sql-keywords-appendix.html func (p *Parser) IsReservedKeyword(s string) bool { switch strings.ToLower(s) { diff --git a/internal/sql/ast/CLAUDE.md b/internal/sql/ast/CLAUDE.md new file mode 100644 index 0000000000..c55f1340ee --- /dev/null +++ b/internal/sql/ast/CLAUDE.md @@ -0,0 +1,115 @@ +# AST Package - Claude Code Guide + +This package defines the Abstract Syntax Tree (AST) nodes used by sqlc to represent SQL statements across all supported databases (PostgreSQL, MySQL, SQLite). + +## Key Concepts + +### Node Interface +All AST nodes implement the `Node` interface with: +- `Pos() int` - returns the source position +- `Format(buf *TrackedBuffer)` - formats the node back to SQL + +### TrackedBuffer +The `TrackedBuffer` type (`pg_query.go`) handles SQL formatting with dialect-specific behavior: +- `astFormat(node Node)` - formats any AST node +- `join(list *List, sep string)` - joins list items with separator +- `WriteString(s string)` - writes raw SQL +- `QuoteIdent(name string)` - quotes identifiers (dialect-specific) +- `TypeName(ns, name string)` - formats type names (dialect-specific) + +### Formatter Interface +Dialect-specific formatting is handled via the `Formatter` interface: +```go +type Formatter interface { + QuoteIdent(string) string + TypeName(ns, name string) string + Param(int) string // $1 for PostgreSQL, ? for MySQL + Cast(string) string +} +``` + +## Adding New AST Nodes + +When adding a new AST node type: + +1. **Create the node file** (e.g., `variable_expr.go`): +```go +package ast + +type VariableExpr struct { + Name string + Location int +} + +func (n *VariableExpr) Pos() int { + return n.Location +} + +func (n *VariableExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("@") + buf.WriteString(n.Name) +} +``` + +2. **Add to `astutils/walk.go`** - Add a case in the Walk function: +```go +case *ast.VariableExpr: + // Leaf node - no children to traverse +``` + +3. **Add to `astutils/rewrite.go`** - Add a case in the Apply function: +```go +case *ast.VariableExpr: + // Leaf node - no children to traverse +``` + +4. **Update the parser/converter** - In the relevant engine (e.g., `dolphin/convert.go` for MySQL) + +## Helper Functions for Format Methods + +- `set(node Node) bool` - returns true if node is non-nil and not an empty List +- `items(list *List) bool` - returns true if list has items +- `todo(node) Node` - placeholder for unimplemented conversions (returns nil) + +## Common Node Types + +### Statements +- `SelectStmt` - SELECT queries with FromClause, WhereClause, etc. +- `InsertStmt` - INSERT with Relation, Cols, SelectStmt, OnConflictClause +- `UpdateStmt` - UPDATE with Relations, TargetList, WhereClause +- `DeleteStmt` - DELETE with Relations, FromClause (for JOINs), Targets + +### Expressions +- `A_Expr` - General expression with operator (e.g., `a + b`, `@param`) +- `ColumnRef` - Column reference with Fields list +- `FuncCall` - Function call with Func, Args, aggregation options +- `TypeCast` - Type cast with Arg and TypeName +- `ParenExpr` - Parenthesized expression +- `VariableExpr` - MySQL user variable (e.g., `@user_id`) + +### Table References +- `RangeVar` - Table reference with schema, name, alias +- `JoinExpr` - JOIN with Larg, Rarg, Jointype, Quals/UsingClause + +## MySQL-Specific Nodes + +- `VariableExpr` - User variables (`@var`), distinct from sqlc's `@param` syntax +- `IntervalExpr` - INTERVAL expressions +- `OnDuplicateKeyUpdate` - MySQL's ON DUPLICATE KEY UPDATE clause +- `ParenExpr` - Explicit parentheses (TiDB parser wraps expressions) + +## Important Distinctions + +### MySQL @variable vs sqlc @param +- MySQL user variables (`@user_id`) use `VariableExpr` - preserved as-is in output +- sqlc named parameters (`@param`) use `A_Expr` with `@` operator - replaced with `?` +- The `named.IsParamSign()` function checks for `A_Expr` with `@` operator + +### Type Modifiers +- `TypeName.Typmods` holds type modifiers like `varchar(255)` +- For MySQL, only populate Typmods for types where length is user-specified: + - VARCHAR, CHAR, VARBINARY, BINARY - need length + - DATETIME, TIMESTAMP, DATE - internal flen should NOT be output diff --git a/internal/sql/ast/between_expr.go b/internal/sql/ast/between_expr.go index 0811caee31..aa18e6b82a 100644 --- a/internal/sql/ast/between_expr.go +++ b/internal/sql/ast/between_expr.go @@ -15,3 +15,18 @@ type BetweenExpr struct { func (n *BetweenExpr) Pos() int { return n.Location } + +func (n *BetweenExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Expr) + if n.Not { + buf.WriteString(" NOT BETWEEN ") + } else { + buf.WriteString(" BETWEEN ") + } + buf.astFormat(n.Left) + buf.WriteString(" AND ") + buf.astFormat(n.Right) +} diff --git a/internal/sql/ast/bool_expr.go b/internal/sql/ast/bool_expr.go index 6d15276a05..9bbddfd7dc 100644 --- a/internal/sql/ast/bool_expr.go +++ b/internal/sql/ast/bool_expr.go @@ -15,17 +15,30 @@ func (n *BoolExpr) Format(buf *TrackedBuffer) { if n == nil { return } - buf.WriteString("(") - if items(n.Args) { - switch n.Boolop { - case BoolExprTypeAnd: - buf.join(n.Args, " AND ") - case BoolExprTypeOr: - buf.join(n.Args, " OR ") - case BoolExprTypeNot: - buf.WriteString(" NOT ") - buf.astFormat(n.Args) + switch n.Boolop { + case BoolExprTypeIsNull: + if items(n.Args) && len(n.Args.Items) > 0 { + buf.astFormat(n.Args.Items[0]) } + buf.WriteString(" IS NULL") + case BoolExprTypeIsNotNull: + if items(n.Args) && len(n.Args.Items) > 0 { + buf.astFormat(n.Args.Items[0]) + } + buf.WriteString(" IS NOT NULL") + default: + buf.WriteString("(") + if items(n.Args) { + switch n.Boolop { + case BoolExprTypeAnd: + buf.join(n.Args, " AND ") + case BoolExprTypeOr: + buf.join(n.Args, " OR ") + case BoolExprTypeNot: + buf.WriteString(" NOT ") + buf.astFormat(n.Args) + } + } + buf.WriteString(")") } - buf.WriteString(")") } diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index 45c2621095..828274978e 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -7,6 +7,9 @@ type DeleteStmt struct { LimitCount Node ReturningList *List WithClause *WithClause + // MySQL multi-table DELETE support + Targets *List // Tables to delete from (e.g., jt.*, pt.*) + FromClause Node // FROM clause with JOINs (Node to support JoinExpr) } func (n *DeleteStmt) Pos() int { @@ -23,9 +26,22 @@ func (n *DeleteStmt) Format(buf *TrackedBuffer) { buf.WriteString(" ") } - buf.WriteString("DELETE FROM ") - if items(n.Relations) { - buf.astFormat(n.Relations) + buf.WriteString("DELETE ") + + // MySQL multi-table DELETE: DELETE t1.*, t2.* FROM t1 JOIN t2 ... + if items(n.Targets) { + buf.join(n.Targets, ", ") + buf.WriteString(" FROM ") + if set(n.FromClause) { + buf.astFormat(n.FromClause) + } else if items(n.Relations) { + buf.astFormat(n.Relations) + } + } else { + buf.WriteString("FROM ") + if items(n.Relations) { + buf.astFormat(n.Relations) + } } if items(n.UsingClause) { diff --git a/internal/sql/ast/func_call.go b/internal/sql/ast/func_call.go index 3b7dcc5400..5f4857a679 100644 --- a/internal/sql/ast/func_call.go +++ b/internal/sql/ast/func_call.go @@ -11,6 +11,7 @@ type FuncCall struct { AggDistinct bool FuncVariadic bool Over *WindowDef + Separator *string // MySQL GROUP_CONCAT SEPARATOR Location int } @@ -37,6 +38,13 @@ func (n *FuncCall) Format(buf *TrackedBuffer) { buf.WriteString(" ORDER BY ") buf.join(n.AggOrder, ", ") } + // SEPARATOR for GROUP_CONCAT (MySQL) + if n.Separator != nil { + buf.WriteString(" SEPARATOR ") + buf.WriteString("'") + buf.WriteString(*n.Separator) + buf.WriteString("'") + } buf.WriteString(")") // WITHIN GROUP clause for ordered-set aggregates if items(n.AggOrder) && n.AggWithinGroup { diff --git a/internal/sql/ast/in.go b/internal/sql/ast/in.go index e11b2086a1..68bd038ad3 100644 --- a/internal/sql/ast/in.go +++ b/internal/sql/ast/in.go @@ -17,3 +17,30 @@ type In struct { func (n *In) Pos() int { return n.Location } + +// Format formats the In expression. +func (n *In) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Expr) + if n.Not { + buf.WriteString(" NOT IN ") + } else { + buf.WriteString(" IN ") + } + if n.Sel != nil { + buf.WriteString("(") + buf.astFormat(n.Sel) + buf.WriteString(")") + } else if len(n.List) > 0 { + buf.WriteString("(") + for i, item := range n.List { + if i > 0 { + buf.WriteString(", ") + } + buf.astFormat(item) + } + buf.WriteString(")") + } +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index cbf480b187..f287df4ae7 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -1,13 +1,14 @@ package ast type InsertStmt struct { - Relation *RangeVar - Cols *List - SelectStmt Node - OnConflictClause *OnConflictClause - ReturningList *List - WithClause *WithClause - Override OverridingKind + Relation *RangeVar + Cols *List + SelectStmt Node + OnConflictClause *OnConflictClause + OnDuplicateKeyUpdate *OnDuplicateKeyUpdate // MySQL-specific + ReturningList *List + WithClause *WithClause + Override OverridingKind } func (n *InsertStmt) Pos() int { @@ -44,6 +45,11 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.OnConflictClause) } + if n.OnDuplicateKeyUpdate != nil { + buf.WriteString(" ") + buf.astFormat(n.OnDuplicateKeyUpdate) + } + if items(n.ReturningList) { buf.WriteString(" RETURNING ") buf.astFormat(n.ReturningList) diff --git a/internal/sql/ast/interval_expr.go b/internal/sql/ast/interval_expr.go new file mode 100644 index 0000000000..0572dc6d70 --- /dev/null +++ b/internal/sql/ast/interval_expr.go @@ -0,0 +1,22 @@ +package ast + +// IntervalExpr represents a MySQL INTERVAL expression like "INTERVAL 1 DAY" +type IntervalExpr struct { + Value Node + Unit string + Location int +} + +func (n *IntervalExpr) Pos() int { + return n.Location +} + +func (n *IntervalExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("INTERVAL ") + buf.astFormat(n.Value) + buf.WriteString(" ") + buf.WriteString(n.Unit) +} diff --git a/internal/sql/ast/on_duplicate_key_update.go b/internal/sql/ast/on_duplicate_key_update.go new file mode 100644 index 0000000000..ad5b7672d1 --- /dev/null +++ b/internal/sql/ast/on_duplicate_key_update.go @@ -0,0 +1,35 @@ +package ast + +// OnDuplicateKeyUpdate represents MySQL's ON DUPLICATE KEY UPDATE clause +type OnDuplicateKeyUpdate struct { + // TargetList contains the assignments (column = value pairs) + TargetList *List + Location int +} + +func (n *OnDuplicateKeyUpdate) Pos() int { + return n.Location +} + +func (n *OnDuplicateKeyUpdate) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("ON DUPLICATE KEY UPDATE ") + if n.TargetList != nil { + for i, item := range n.TargetList.Items { + if i > 0 { + buf.WriteString(", ") + } + if rt, ok := item.(*ResTarget); ok { + if rt.Name != nil { + buf.WriteString(*rt.Name) + } + buf.WriteString(" = ") + buf.astFormat(rt.Val) + } else { + buf.astFormat(item) + } + } + } +} diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index 8bd724993d..0558f78bdf 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -1,7 +1,5 @@ package ast -import "fmt" - type ParamRef struct { Number int Location int @@ -16,5 +14,5 @@ func (n *ParamRef) Format(buf *TrackedBuffer) { if n == nil { return } - fmt.Fprintf(buf, "$%d", n.Number) + buf.WriteString(buf.Param(n.Number)) } diff --git a/internal/sql/ast/paren_expr.go b/internal/sql/ast/paren_expr.go new file mode 100644 index 0000000000..ee57ac55d7 --- /dev/null +++ b/internal/sql/ast/paren_expr.go @@ -0,0 +1,20 @@ +package ast + +// ParenExpr represents a parenthesized expression +type ParenExpr struct { + Expr Node + Location int +} + +func (n *ParenExpr) Pos() int { + return n.Location +} + +func (n *ParenExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("(") + buf.astFormat(n.Expr) + buf.WriteString(")") +} diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go index 8db19ba7d1..c4390a15c5 100644 --- a/internal/sql/ast/print.go +++ b/internal/sql/ast/print.go @@ -1,6 +1,7 @@ package ast import ( + "fmt" "strings" "github.com/sqlc-dev/sqlc/internal/debug" @@ -46,6 +47,24 @@ func (t *TrackedBuffer) TypeName(ns, name string) string { return name } +// Param returns the parameter placeholder for the given number. +// If no formatter is set, it returns PostgreSQL-style $n. +func (t *TrackedBuffer) Param(n int) string { + if t.formatter != nil { + return t.formatter.Param(n) + } + return fmt.Sprintf("$%d", n) +} + +// Cast returns a type cast expression. +// If no formatter is set, it returns PostgreSQL-style expr::type. +func (t *TrackedBuffer) Cast(arg, typeName string) string { + if t.formatter != nil { + return t.formatter.Cast(arg, typeName) + } + return arg + "::" + typeName +} + func (t *TrackedBuffer) astFormat(n Node) { if ft, ok := n.(nodeFormatter); ok { ft.Format(t) diff --git a/internal/sql/ast/range_var.go b/internal/sql/ast/range_var.go index b7fb316ee9..5fd6db535f 100644 --- a/internal/sql/ast/range_var.go +++ b/internal/sql/ast/range_var.go @@ -18,7 +18,7 @@ func (n *RangeVar) Format(buf *TrackedBuffer) { if n == nil { return } - if n.Schemaname != nil { + if n.Schemaname != nil && *n.Schemaname != "" { buf.WriteString(buf.QuoteIdent(*n.Schemaname)) buf.WriteString(".") } @@ -26,7 +26,7 @@ func (n *RangeVar) Format(buf *TrackedBuffer) { buf.WriteString(buf.QuoteIdent(*n.Relname)) } if n.Alias != nil { - buf.WriteString(" ") + buf.WriteString(" AS ") buf.astFormat(n.Alias) } } diff --git a/internal/sql/ast/sub_link.go b/internal/sql/ast/sub_link.go index 9463f98c54..369b41ed86 100644 --- a/internal/sql/ast/sub_link.go +++ b/internal/sql/ast/sub_link.go @@ -31,14 +31,26 @@ func (n *SubLink) Format(buf *TrackedBuffer) { if n == nil { return } - buf.astFormat(n.Testexpr) + // Format the test expression if present (for IN subqueries etc.) + hasTestExpr := n.Testexpr != nil + if hasTestExpr { + buf.astFormat(n.Testexpr) + } switch n.SubLinkType { case EXISTS_SUBLINK: - buf.WriteString(" EXISTS (") + buf.WriteString("EXISTS (") case ANY_SUBLINK: - buf.WriteString(" IN (") + if hasTestExpr { + buf.WriteString(" IN (") + } else { + buf.WriteString("IN (") + } default: - buf.WriteString(" (") + if hasTestExpr { + buf.WriteString(" (") + } else { + buf.WriteString("(") + } } buf.astFormat(n.Subselect) buf.WriteString(")") diff --git a/internal/sql/ast/type_cast.go b/internal/sql/ast/type_cast.go index 0b549eb4b1..163d145dbc 100644 --- a/internal/sql/ast/type_cast.go +++ b/internal/sql/ast/type_cast.go @@ -14,7 +14,12 @@ func (n *TypeCast) Format(buf *TrackedBuffer) { if n == nil { return } - buf.astFormat(n.Arg) - buf.WriteString("::") - buf.astFormat(n.TypeName) + // Format the arg and type to strings first + argBuf := NewTrackedBuffer(buf.formatter) + argBuf.astFormat(n.Arg) + + typeBuf := NewTrackedBuffer(buf.formatter) + typeBuf.astFormat(n.TypeName) + + buf.WriteString(buf.Cast(argBuf.String(), typeBuf.String())) } diff --git a/internal/sql/ast/variable_expr.go b/internal/sql/ast/variable_expr.go new file mode 100644 index 0000000000..63afdf3d99 --- /dev/null +++ b/internal/sql/ast/variable_expr.go @@ -0,0 +1,20 @@ +package ast + +// VariableExpr represents a MySQL user variable (e.g., @user_id) +// This is distinct from sqlc's @param named parameter syntax. +type VariableExpr struct { + Name string + Location int +} + +func (n *VariableExpr) Pos() int { + return n.Location +} + +func (n *VariableExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("@") + buf.WriteString(n.Name) +} diff --git a/internal/sql/astutils/CLAUDE.md b/internal/sql/astutils/CLAUDE.md new file mode 100644 index 0000000000..b7903542c5 --- /dev/null +++ b/internal/sql/astutils/CLAUDE.md @@ -0,0 +1,117 @@ +# AST Utilities Package - Claude Code Guide + +This package provides utilities for traversing and transforming AST nodes. + +## Key Functions + +### Walk +`Walk(f Visitor, node ast.Node)` traverses the AST depth-first, calling `f.Visit()` on each node. + +```go +type Visitor interface { + Visit(node ast.Node) Visitor +} +``` + +**Important**: When adding new AST node types, you MUST add a case to the switch statement in `walk.go`, otherwise you'll get a panic: +``` +panic: walk: unexpected node type *ast.YourNewType +``` + +### Apply (Rewrite) +`Apply(root ast.Node, pre, post ApplyFunc) ast.Node` traverses and optionally transforms the AST. + +```go +type ApplyFunc func(*Cursor) bool +``` + +The `Cursor` provides: +- `Node()` - current node +- `Parent()` - parent node +- `Name()` - field name in parent +- `Index()` - index if in a list +- `Replace(node)` - replace current node + +**Important**: When adding new AST node types, you MUST add a case to the switch statement in `rewrite.go`, otherwise you'll get a panic: +``` +panic: Apply: unexpected node type *ast.YourNewType +``` + +### Search +`Search(root ast.Node, fn func(ast.Node) bool) *ast.List` finds all nodes matching a predicate. + +### Join +`Join(list *ast.List, sep string) string` joins string nodes with a separator. + +## Adding Support for New AST Nodes + +When you create a new AST node type, you must update BOTH `walk.go` and `rewrite.go`: + +### In walk.go +Add a case that walks all child nodes: +```go +case *ast.YourNewType: + if n.ChildField != nil { + Walk(f, n.ChildField) + } + if n.ChildList != nil { + Walk(f, n.ChildList) + } +``` + +For leaf nodes with no children: +```go +case *ast.YourNewType: + // Leaf node - no children to traverse +``` + +### In rewrite.go +Add a case that applies to all child nodes: +```go +case *ast.YourNewType: + a.apply(n, "ChildField", nil, n.ChildField) + a.apply(n, "ChildList", nil, n.ChildList) +``` + +For leaf nodes: +```go +case *ast.YourNewType: + // Leaf node - no children to traverse +``` + +## Common Patterns + +### Finding All Tables in a Statement +```go +var tv tableVisitor +astutils.Walk(&tv, stmt.FromClause) +// tv.list now contains all RangeVar nodes +``` + +### Replacing Named Parameters +The `rewrite/parameters.go` uses Apply to replace `sqlc.arg()` calls with `ParamRef`: +```go +astutils.Apply(root, func(cr *astutils.Cursor) bool { + if named.IsParamFunc(cr.Node()) { + cr.Replace(&ast.ParamRef{Number: nextParam()}) + } + return true +}, nil) +``` + +## Node Types That Must Be Handled + +All node types in `internal/sql/ast/` must have cases in both walk.go and rewrite.go. Key MySQL-specific nodes: +- `IntervalExpr` - INTERVAL expressions +- `OnDuplicateKeyUpdate` - MySQL ON DUPLICATE KEY UPDATE +- `ParenExpr` - Parenthesized expressions +- `VariableExpr` - MySQL user variables (@var) + +## Debugging Tips + +If you see a panic like: +``` +panic: walk: unexpected node type *ast.SomeType +``` + +Check that `SomeType` has a case in both `walk.go` and `rewrite.go`. diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index 93c5be3cfb..fc7996b5f5 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -687,6 +687,8 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "WhereClause", nil, n.WhereClause) a.apply(n, "ReturningList", nil, n.ReturningList) a.apply(n, "WithClause", nil, n.WithClause) + a.apply(n, "Targets", nil, n.Targets) + a.apply(n, "FromClause", nil, n.FromClause) case *ast.DiscardStmt: // pass @@ -812,12 +814,16 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Cols", nil, n.Cols) a.apply(n, "SelectStmt", nil, n.SelectStmt) a.apply(n, "OnConflictClause", nil, n.OnConflictClause) + a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate) a.apply(n, "ReturningList", nil, n.ReturningList) a.apply(n, "WithClause", nil, n.WithClause) case *ast.Integer: // pass + case *ast.IntervalExpr: + a.apply(n, "Value", nil, n.Value) + case *ast.IntoClause: a.apply(n, "Rel", nil, n.Rel) a.apply(n, "ColNames", nil, n.ColNames) @@ -883,6 +889,9 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "OnConflictWhere", nil, n.OnConflictWhere) a.apply(n, "ExclRelTlist", nil, n.ExclRelTlist) + case *ast.OnDuplicateKeyUpdate: + a.apply(n, "TargetList", nil, n.TargetList) + case *ast.OpExpr: a.apply(n, "Xpr", nil, n.Xpr) a.apply(n, "Args", nil, n.Args) @@ -902,6 +911,12 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. case *ast.ParamRef: // pass + case *ast.ParenExpr: + a.apply(n, "Expr", nil, n.Expr) + + case *ast.VariableExpr: + // Leaf node - no children to traverse + case *ast.PartitionBoundSpec: a.apply(n, "Listdatums", nil, n.Listdatums) a.apply(n, "Lowerdatums", nil, n.Lowerdatums) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 0943379f03..6d5e80bdc3 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1077,6 +1077,12 @@ func Walk(f Visitor, node ast.Node) { if n.WithClause != nil { Walk(f, n.WithClause) } + if n.Targets != nil { + Walk(f, n.Targets) + } + if n.FromClause != nil { + Walk(f, n.FromClause) + } case *ast.DiscardStmt: // pass @@ -1312,6 +1318,9 @@ func Walk(f Visitor, node ast.Node) { if n.OnConflictClause != nil { Walk(f, n.OnConflictClause) } + if n.OnDuplicateKeyUpdate != nil { + Walk(f, n.OnDuplicateKeyUpdate) + } if n.ReturningList != nil { Walk(f, n.ReturningList) } @@ -1336,6 +1345,11 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.ViewQuery) } + case *ast.IntervalExpr: + if n.Value != nil { + Walk(f, n.Value) + } + case *ast.JoinExpr: if n.Larg != nil { Walk(f, n.Larg) @@ -1445,6 +1459,11 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.ExclRelTlist) } + case *ast.OnDuplicateKeyUpdate: + if n.TargetList != nil { + Walk(f, n.TargetList) + } + case *ast.OpExpr: if n.Xpr != nil { Walk(f, n.Xpr) @@ -1470,6 +1489,14 @@ func Walk(f Visitor, node ast.Node) { case *ast.ParamRef: // pass + case *ast.ParenExpr: + if n.Expr != nil { + Walk(f, n.Expr) + } + + case *ast.VariableExpr: + // Leaf node - no children to traverse + case *ast.PartitionBoundSpec: if n.Listdatums != nil { Walk(f, n.Listdatums) diff --git a/internal/sql/format/format.go b/internal/sql/format/format.go index f47587dd0b..922b02b61c 100644 --- a/internal/sql/format/format.go +++ b/internal/sql/format/format.go @@ -9,4 +9,12 @@ type Formatter interface { // TypeName returns the SQL type name for the given namespace and name. // This handles dialect-specific type name mappings (e.g., pg_catalog.int4 -> integer) TypeName(ns, name string) string + + // Param returns the parameter placeholder for the given parameter number. + // PostgreSQL uses $1, $2, etc. MySQL uses ? + Param(n int) string + + // Cast formats a type cast expression. + // PostgreSQL uses expr::type, MySQL uses CAST(expr AS type) + Cast(arg, typeName string) string } diff --git a/internal/sql/named/CLAUDE.md b/internal/sql/named/CLAUDE.md new file mode 100644 index 0000000000..05ba358ee9 --- /dev/null +++ b/internal/sql/named/CLAUDE.md @@ -0,0 +1,94 @@ +# Named Parameters Package - Claude Code Guide + +This package provides utilities for identifying sqlc's named parameter syntax. + +## Named Parameter Styles + +sqlc supports two styles of named parameters: + +### 1. Function-style: `sqlc.arg(name)`, `sqlc.narg(name)`, `sqlc.slice(name)` +Identified by `IsParamFunc()`: +```go +func IsParamFunc(node ast.Node) bool { + call, ok := node.(*ast.FuncCall) + if !ok { + return false + } + return call.Func.Schema == "sqlc" && + (call.Func.Name == "arg" || call.Func.Name == "narg" || call.Func.Name == "slice") +} +``` + +### 2. At-sign style: `@param_name` (PostgreSQL only) +Identified by `IsParamSign()`: +```go +func IsParamSign(node ast.Node) bool { + expr, ok := node.(*ast.A_Expr) + return ok && astutils.Join(expr.Name, ".") == "@" +} +``` + +## Important Distinction: sqlc @param vs MySQL @variable + +**sqlc named parameters** (`@param` in PostgreSQL queries): +- Represented as `A_Expr` with `Kind=A_Expr_Kind_OP` and `Name=["@"]` +- Detected by `IsParamSign()` +- Replaced with positional parameters (`$1`, `$2` for PostgreSQL, `?` for MySQL) + +**MySQL user variables** (`@user_id` in MySQL queries): +- Represented as `VariableExpr` +- NOT detected by `IsParamSign()` (it checks for `A_Expr`, not `VariableExpr`) +- Preserved as-is in the output SQL + +This distinction is critical: +```sql +-- PostgreSQL with sqlc @param syntax: +SELECT * FROM users WHERE id = @user_id +-- Becomes: SELECT * FROM users WHERE id = $1 + +-- MySQL with user variable: +SELECT * FROM users WHERE id != @user_id +-- Stays: SELECT * FROM users WHERE id != @user_id +``` + +## Usage in Parameter Rewriting + +The `rewrite/parameters.go` package uses these functions to find and replace named parameters: + +```go +// Find all named parameters +params := astutils.Search(root, func(node ast.Node) bool { + return named.IsParamFunc(node) || named.IsParamSign(node) +}) + +// Replace with positional parameters +astutils.Apply(root, func(cr *astutils.Cursor) bool { + if named.IsParamFunc(cr.Node()) || named.IsParamSign(cr.Node()) { + cr.Replace(&ast.ParamRef{Number: nextParam()}) + } + return true +}, nil) +``` + +## Converting MySQL @variable Correctly + +When converting TiDB's `VariableExpr` in `dolphin/convert.go`: + +```go +// CORRECT - preserves MySQL user variable as-is +func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { + return &ast.VariableExpr{ + Name: n.Name, + Location: n.OriginTextPosition(), + } +} + +// WRONG - would be treated as sqlc named parameter +func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { + return &ast.A_Expr{ + Kind: ast.A_Expr_Kind_OP, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, + Rexpr: &ast.String{Str: n.Name}, + } +} +``` diff --git a/internal/sql/rewrite/CLAUDE.md b/internal/sql/rewrite/CLAUDE.md new file mode 100644 index 0000000000..dd6459029f --- /dev/null +++ b/internal/sql/rewrite/CLAUDE.md @@ -0,0 +1,104 @@ +# SQL Rewrite Package - Claude Code Guide + +This package handles AST transformations, primarily for parameter handling. + +## Key Functions + +### NamedParameters +`NamedParameters(engine config.Engine, raw *ast.RawStmt, ...) (*ast.RawStmt, map[int]Parameter, error)` + +Finds and replaces named parameters (`sqlc.arg()`, `@param`) with positional parameters. + +The function: +1. Searches for named parameters using `named.IsParamFunc()` and `named.IsParamSign()` +2. Extracts parameter names and types +3. Replaces them with `ast.ParamRef` nodes +4. Returns a map of parameter positions to their metadata + +### Expand +`Expand(raw *ast.RawStmt, expected int) error` + +Expands `sqlc.slice()` parameters into the correct number of positional parameters. + +## How Parameter Rewriting Works + +### Step 1: Find Named Parameters +```go +refs := astutils.Search(raw.Stmt, func(node ast.Node) bool { + return named.IsParamFunc(node) || named.IsParamSign(node) +}) +``` + +### Step 2: Replace with ParamRef +```go +astutils.Apply(raw.Stmt, func(cr *astutils.Cursor) bool { + if named.IsParamFunc(cr.Node()) { + // Extract name from sqlc.arg(name) + call := cr.Node().(*ast.FuncCall) + name := extractName(call.Args) + + cr.Replace(&ast.ParamRef{ + Number: nextParam(), + Location: call.Location, + }) + } + return true +}, nil) +``` + +## Important: AST Node Requirements + +For parameter rewriting to work correctly, the AST must be walkable. This means: + +1. All node types must have cases in `astutils/walk.go` +2. All node types must have cases in `astutils/rewrite.go` +3. New container types (like `OnDuplicateKeyUpdate`) must be traversed + +### Example: OnDuplicateKeyUpdate + +MySQL's `ON DUPLICATE KEY UPDATE` clause can contain `sqlc.arg()`: +```sql +INSERT INTO t (a) VALUES (sqlc.arg(val)) +ON DUPLICATE KEY UPDATE a = sqlc.arg(new_val) +``` + +For the parameter in `ON DUPLICATE KEY UPDATE` to be found and replaced: + +1. `InsertStmt` in `rewrite.go` must traverse `OnDuplicateKeyUpdate`: +```go +case *ast.InsertStmt: + a.apply(n, "Relation", nil, n.Relation) + a.apply(n, "Cols", nil, n.Cols) + a.apply(n, "SelectStmt", nil, n.SelectStmt) + a.apply(n, "OnConflictClause", nil, n.OnConflictClause) + a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate) // Critical! + a.apply(n, "ReturningList", nil, n.ReturningList) + a.apply(n, "WithClause", nil, n.WithClause) +``` + +2. `OnDuplicateKeyUpdate` must have its own case: +```go +case *ast.OnDuplicateKeyUpdate: + a.apply(n, "List", nil, n.List) +``` + +## Debugging Parameter Issues + +If a `sqlc.arg()` isn't being converted to `?`: + +1. Check that the containing node type has a case in `rewrite.go` +2. Check that the case traverses all child fields +3. Add debug logging to see if the node is being visited: +```go +case *ast.YourType: + fmt.Printf("Visiting YourType with fields: %+v\n", n) + a.apply(n, "ChildField", nil, n.ChildField) +``` + +## Parameter Output Format by Engine + +- PostgreSQL: `$1`, `$2`, `$3`, ... +- MySQL: `?`, `?`, `?`, ... +- SQLite: `?`, `?`, `?`, ... + +The format is determined by the `Formatter.Param()` method in each engine.