Skip to content

Commit

Permalink
fix(compiler): Support references to columns in joined tables in UPDA…
Browse files Browse the repository at this point in the history
…TE statements (#1289)

* fix(compiler): Support references to columns in joined tables in UPDATE statements
  • Loading branch information
timstudd committed Nov 17, 2021
1 parent 5eb649d commit 466c3e1
Show file tree
Hide file tree
Showing 17 changed files with 297 additions and 15 deletions.
8 changes: 7 additions & 1 deletion internal/compiler/find_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,13 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
if !ok {
continue
}
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: n.Relation})
for _, relation := range n.Relations.Items {
rv, ok := relation.(*ast.RangeVar)
if !ok {
continue
}
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
}
p.seen[ref.Location] = struct{}{}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
})
case *ast.UpdateStmt:
list = &ast.List{
Items: append(n.FromClause.Items, n.Relation),
Items: append(n.FromClause.Items, n.Relations.Items...),
}
default:
return nil, fmt.Errorf("sourceTables: unsupported node type: %T", n)
Expand Down
29 changes: 29 additions & 0 deletions internal/endtoend/testdata/update_join/mysql/db/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions internal/endtoend/testdata/update_join/mysql/db/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 68 additions & 0 deletions internal/endtoend/testdata/update_join/mysql/db/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions internal/endtoend/testdata/update_join/mysql/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
CREATE TABLE primary_table (
id bigint(20) unsigned NOT NULL AUTO_INCREMENT,
user_id bigint(20) unsigned NOT NULL,
PRIMARY KEY (id)
);

CREATE TABLE join_table (
id bigint(20) unsigned NOT NULL AUTO_INCREMENT,
primary_table_id bigint(20) unsigned NOT NULL,
other_table_id bigint(20) unsigned NOT NULL,
is_active tinyint(1) NOT NULL DEFAULT '0',
PRIMARY KEY (id)
);

-- name: UpdateJoin :exec
UPDATE join_table as jt
JOIN primary_table as pt
ON jt.primary_table_id = pt.id
SET jt.is_active = ?
WHERE jt.id = ?
AND pt.user_id = ?;

-- name: UpdateLeftJoin :exec
UPDATE join_table as jt
LEFT JOIN primary_table as pt
ON jt.primary_table_id = pt.id
SET jt.is_active = ?
WHERE jt.id = ?
AND pt.user_id = ?;

-- name: UpdateRightJoin :exec
UPDATE join_table as jt
RIGHT JOIN primary_table as pt
ON jt.primary_table_id = pt.id
SET jt.is_active = ?
WHERE jt.id = ?
AND pt.user_id = ?;
11 changes: 11 additions & 0 deletions internal/endtoend/testdata/update_join/mysql/sqlc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"version": "1",
"packages": [
{
"path": "db",
"engine": "mysql",
"schema": "query.sql",
"queries": "query.sql"
}
]
}
29 changes: 29 additions & 0 deletions internal/endtoend/testdata/update_join/postgresql/db/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions internal/endtoend/testdata/update_join/postgresql/db/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions internal/endtoend/testdata/update_join/postgresql/db/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions internal/endtoend/testdata/update_join/postgresql/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
CREATE TABLE primary_table (
id INT PRIMARY KEY,
user_id INT NOT NULL
);

CREATE TABLE join_table (
id INT PRIMARY KEY,
primary_table_id INT NOT NULL,
other_table_id INT NOT NULL,
is_active BOOLEAN NOT NULL
);

-- name: UpdateJoin :exec
UPDATE join_table
SET is_active = $1
FROM primary_table
WHERE join_table.id = $2
AND primary_table.user_id = $3
AND join_table.primary_table_id = primary_table.id;
11 changes: 11 additions & 0 deletions internal/endtoend/testdata/update_join/postgresql/sqlc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"version": "1",
"packages": [
{
"path": "db",
"engine": "postgresql",
"schema": "query.sql",
"queries": "query.sql"
}
]
}
14 changes: 10 additions & 4 deletions internal/engine/dolphin/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt {
panic("expected one range var")
}

var rangeVar *ast.RangeVar
relations := &ast.List{}
switch rel := rels.Items[0].(type) {

// Special case for joins in updates
Expand All @@ -549,10 +549,16 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt {
if !ok {
panic("expected range var")
}
rangeVar = left
relations.Items = append(relations.Items, left)

right, ok := rel.Rarg.(*ast.RangeVar)
if !ok {
panic("expected range var")
}
relations.Items = append(relations.Items, right)

case *ast.RangeVar:
rangeVar = rel
relations.Items = append(relations.Items, rel)

default:
panic("expected range var")
Expand All @@ -564,7 +570,7 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt {
list.Items = append(list.Items, c.convertAssignment(a))
}
return &ast.UpdateStmt{
Relation: rangeVar,
Relations: relations,
TargetList: list,
WhereClause: c.convert(n.Where),
FromClause: &ast.List{},
Expand Down
14 changes: 9 additions & 5 deletions internal/engine/postgresql/convert.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build !windows
// +build !windows

package postgresql
Expand Down Expand Up @@ -2765,8 +2766,11 @@ func convertUpdateStmt(n *pg.UpdateStmt) *ast.UpdateStmt {
if n == nil {
return nil
}

return &ast.UpdateStmt{
Relation: convertRangeVar(n.Relation),
Relations: &ast.List{
Items: []ast.Node{convertRangeVar(n.Relation)},
},
TargetList: convertSlice(n.TargetList),
WhereClause: convertNode(n.WhereClause),
FromClause: convertSlice(n.FromClause),
Expand All @@ -2780,10 +2784,10 @@ func convertVacuumStmt(n *pg.VacuumStmt) *ast.VacuumStmt {
return nil
}
return &ast.VacuumStmt{
// FIXME: The VacuumStmt node has changed quite a bit
// Options: n.Options
// Relation: convertRangeVar(n.Relation),
// VaCols: convertSlice(n.VaCols),
// FIXME: The VacuumStmt node has changed quite a bit
// Options: n.Options
// Relation: convertRangeVar(n.Relation),
// VaCols: convertSlice(n.VaCols),
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/sql/ast/update_stmt.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ast

type UpdateStmt struct {
Relation *RangeVar
Relations *List
TargetList *List
WhereClause Node
FromClause *List
Expand Down
2 changes: 1 addition & 1 deletion internal/sql/astutils/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.
// pass

case *ast.UpdateStmt:
a.apply(n, "Relation", nil, n.Relation)
a.apply(n, "Relations", nil, n.Relations)
a.apply(n, "TargetList", nil, n.TargetList)
a.apply(n, "WhereClause", nil, n.WhereClause)
a.apply(n, "FromClause", nil, n.FromClause)
Expand Down
4 changes: 2 additions & 2 deletions internal/sql/astutils/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -2008,8 +2008,8 @@ func Walk(f Visitor, node ast.Node) {
// pass

case *ast.UpdateStmt:
if n.Relation != nil {
Walk(f, n.Relation)
if n.Relations != nil {
Walk(f, n.Relations)
}
if n.TargetList != nil {
Walk(f, n.TargetList)
Expand Down

0 comments on commit 466c3e1

Please sign in to comment.