Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion internal/codegen/golang/mysql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.C
}
return "sql.NullInt32"

case "bigint":
case "bigint", "bigint unsigned", "bigint signed":
// "bigint unsigned" and "bigint signed" are MySQL CAST types
// Note: We use int64 for CAST AS UNSIGNED to match original behavior,
// even though uint64 would be more semantically correct.
// The Unsigned flag on columns (from table schema) still uses uint64.
if notNull {
if unsigned {
return "uint64"
Expand Down
9 changes: 8 additions & 1 deletion internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,14 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro
list := &ast.List{}
switch n := node.(type) {
case *ast.DeleteStmt:
list = n.Relations
if n.Relations != nil {
list = n.Relations
} else if n.FromClause != nil {
// Multi-table DELETE: walk FromClause to find tables
var tv tableVisitor
astutils.Walk(&tv, n.FromClause)
list = &tv.list
}
case *ast.InsertStmt:
list = &ast.List{
Items: []ast.Node{n.Relation},
Expand Down
54 changes: 47 additions & 7 deletions internal/endtoend/fmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,30 @@ package main
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"

"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/debug"
"github.com/sqlc-dev/sqlc/internal/engine/dolphin"
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/format"
)

// sqlParser is the parsing behavior TestFormat needs from an engine: Parse
// reads SQL from r and returns the parsed statements, or an error.
// TestFormat assigns either the PostgreSQL parser or the MySQL (dolphin)
// parser to a variable of this type, depending on the configured engine.
type sqlParser interface {
Parse(r io.Reader) ([]ast.Statement, error)
}

// sqlFormatter is the engine-specific formatter that TestFormat passes to
// ast.Format. It declares no methods of its own; it only embeds
// format.Formatter.
type sqlFormatter interface {
format.Formatter
}

func TestFormat(t *testing.T) {
t.Parallel()
for _, tc := range FindTests(t, "testdata", "base") {
Expand All @@ -36,9 +49,38 @@ func TestFormat(t *testing.T) {
return
}

// For now, only test PostgreSQL since that's the only engine with Format support
engine := conf.SQL[0].Engine
if engine != config.EnginePostgreSQL {

// Select the appropriate parser and fingerprint function based on engine
var parse sqlParser
var formatter sqlFormatter
var fingerprint func(string) (string, error)

switch engine {
case config.EnginePostgreSQL:
pgParser := postgresql.NewParser()
parse = pgParser
formatter = pgParser
fingerprint = postgresql.Fingerprint
case config.EngineMySQL:
mysqlParser := dolphin.NewParser()
parse = mysqlParser
formatter = mysqlParser
// For MySQL, we use a "round-trip" fingerprint: parse the SQL, format it,
// and return the formatted string. This tests that our formatting produces
// valid SQL that parses to the same AST structure.
fingerprint = func(sql string) (string, error) {
stmts, err := mysqlParser.Parse(strings.NewReader(sql))
if err != nil {
return "", err
}
if len(stmts) == 0 {
return "", nil
}
return ast.Format(stmts[0].Raw, mysqlParser), nil
}
default:
// Skip unsupported engines
return
}

Expand Down Expand Up @@ -68,8 +110,6 @@ func TestFormat(t *testing.T) {
return
}

parse := postgresql.NewParser()

for _, queryFile := range queryFiles {
if _, err := os.Stat(queryFile); os.IsNotExist(err) {
continue
Expand Down Expand Up @@ -99,7 +139,7 @@ func TestFormat(t *testing.T) {
}
query := strings.TrimSpace(string(contents[start : start+length]))

expected, err := postgresql.Fingerprint(query)
expected, err := fingerprint(query)
if err != nil {
t.Fatal(err)
}
Expand All @@ -109,8 +149,8 @@ func TestFormat(t *testing.T) {
debug.Dump(r, err)
}

out := ast.Format(stmt.Raw, parse)
actual, err := postgresql.Fingerprint(out)
out := ast.Format(stmt.Raw, formatter)
actual, err := fingerprint(out)
if err != nil {
t.Error(err)
}
Expand Down
224 changes: 224 additions & 0 deletions internal/engine/dolphin/CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# Dolphin Engine (MySQL) - Claude Code Guide

The dolphin engine handles MySQL parsing and AST conversion using the TiDB parser.

## Architecture

### Parser Flow
```
SQL String → TiDB Parser → TiDB AST → sqlc AST → Analysis/Codegen
```

### Key Files
- `convert.go` - Converts TiDB AST nodes to sqlc AST nodes
- `format.go` - MySQL-specific formatting (identifiers, types, parameters)
- `parse.go` - Entry point for parsing MySQL SQL

## TiDB Parser

The TiDB parser (`github.com/pingcap/tidb/pkg/parser`) is used for MySQL parsing:

```go
import (
pcast "github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/types"
)
```

### Common TiDB Types
- `pcast.SelectStmt`, `pcast.InsertStmt`, etc. - Statement types
- `pcast.ColumnNameExpr` - Column reference
- `pcast.FuncCallExpr` - Function call
- `pcast.BinaryOperationExpr` - Binary expression
- `pcast.VariableExpr` - MySQL user variable (@var)
- `pcast.Join` - JOIN clause with Left, Right, On, Using

## Conversion Pattern

Each TiDB node type has a corresponding converter method:

```go
func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt {
return &ast.SelectStmt{
FromClause: c.convertTableRefsClause(n.From),
WhereClause: c.convert(n.Where),
// ...
}
}
```

The main `convert()` method dispatches to specific converters:
```go
func (c *cc) convert(node pcast.Node) ast.Node {
switch n := node.(type) {
case *pcast.SelectStmt:
return c.convertSelectStmt(n)
case *pcast.InsertStmt:
return c.convertInsertStmt(n)
// ...
}
}
```

## Key Conversions

### Column References
```go
func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef {
var items []ast.Node
if schema := n.Name.Schema.String(); schema != "" {
items = append(items, NewIdentifier(schema))
}
if table := n.Name.Table.String(); table != "" {
items = append(items, NewIdentifier(table))
}
items = append(items, NewIdentifier(n.Name.Name.String()))
return &ast.ColumnRef{Fields: &ast.List{Items: items}}
}
```

### JOINs
```go
func (c *cc) convertJoin(n *pcast.Join) *ast.List {
if n.Right != nil && n.Left != nil {
return &ast.List{
Items: []ast.Node{&ast.JoinExpr{
Jointype: ast.JoinType(n.Tp),
Larg: c.convert(n.Left),
Rarg: c.convert(n.Right),
Quals: c.convert(n.On),
UsingClause: convertUsing(n.Using),
}},
}
}
// No join - just return tables
// ...
}
```

### MySQL User Variables
MySQL user variables (`@var`) are different from sqlc's `@param` syntax:
```go
func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node {
// Use VariableExpr to preserve as-is (NOT A_Expr which would be treated as sqlc param)
return &ast.VariableExpr{
Name: n.Name,
Location: n.OriginTextPosition(),
}
}
```

### Type Casts (CAST AS)
```go
func (c *cc) convertFuncCastExpr(n *pcast.FuncCastExpr) ast.Node {
typeName := types.TypeStr(n.Tp.GetType())
// Handle UNSIGNED/SIGNED specially
if typeName == "bigint" {
if mysql.HasUnsignedFlag(n.Tp.GetFlag()) {
typeName = "bigint unsigned"
} else {
typeName = "bigint signed"
}
}
return &ast.TypeCast{
Arg: c.convert(n.Expr),
TypeName: &ast.TypeName{Name: typeName},
}
}
```

### Column Definitions
```go
func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef {
typeName := &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}

// Only add Typmods for types where length is meaningful
tp := def.Tp.GetType()
flen := def.Tp.GetFlen()
switch tp {
case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString:
if flen >= 0 {
typeName.Typmods = &ast.List{
Items: []ast.Node{&ast.Integer{Ival: int64(flen)}},
}
}
// Don't add for DATETIME, TIMESTAMP - internal flen is not user-specified
}
// ...
}
```

### Multi-Table DELETE
MySQL supports `DELETE t1, t2 FROM t1 JOIN t2 ...`:
```go
func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt {
if n.IsMultiTable && n.Tables != nil {
// Convert targets (t1.*, t2.*)
targets := &ast.List{}
for _, table := range n.Tables.Tables {
// Build ColumnRef for each target
}
stmt.Targets = targets

// Preserve JOINs in FromClause
stmt.FromClause = c.convertTableRefsClause(n.TableRefs).Items[0]
} else {
// Single-table DELETE
stmt.Relations = c.convertTableRefsClause(n.TableRefs)
}
}
```

## MySQL-Specific Formatting

### format.go
```go
func (p *Parser) TypeName(ns, name string) string {
switch name {
case "bigint unsigned":
return "UNSIGNED"
case "bigint signed":
return "SIGNED"
}
return name
}

func (p *Parser) Param(n int) string {
return "?" // MySQL uses ? for all parameters
}
```

## Common Issues and Solutions

### Issue: Panic in Walk/Apply
**Cause**: New AST node type not handled in `astutils/walk.go` or `astutils/rewrite.go`
**Solution**: Add case for the node type in both files

### Issue: sqlc.arg() not converted in ON DUPLICATE KEY UPDATE
**Cause**: `InsertStmt` case in `rewrite.go` didn't traverse `OnDuplicateKeyUpdate`
**Solution**: Add `a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate)`

### Issue: MySQL @variable being treated as parameter
**Cause**: Converting `VariableExpr` to `A_Expr` with `@` operator
**Solution**: Use `ast.VariableExpr` instead, which is not detected by `named.IsParamSign()`

### Issue: Type length appearing incorrectly (e.g., datetime(39))
**Cause**: Using internal `flen` for all types
**Solution**: Only populate `Typmods` for types where length is user-specified (varchar, char, etc.)

## Testing

### TestFormat
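Tests that each query survives a format round trip:
1. Parse the query
2. Format the resulting AST back to SQL
3. Fingerprint both the original query and the formatted SQL (PostgreSQL uses `postgresql.Fingerprint`; MySQL re-parses the SQL and re-formats it)
4. Compare the two fingerprints, which should match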

### TestReplay
Tests the full sqlc pipeline:
1. Parse schema and queries
2. Analyze
3. Generate code
4. Compare with expected output
Loading
Loading