diff --git a/internal/mysql/param.go b/internal/mysql/param.go index 021f074c48..ecb7865b29 100644 --- a/internal/mysql/param.go +++ b/internal/mysql/param.go @@ -44,12 +44,20 @@ func paramsInLimitExpr(limit *sqlparser.Limit, s *Schema, tableAliasMap FromTabl func paramsInWhereExpr(e sqlparser.SQLNode, s *Schema, tableAliasMap FromTables, defaultTable string, settings dinosql.GenerateSettings) ([]*Param, error) { params := []*Param{} + if e == nil { + return params, nil + } else if expr, ok := e.(*sqlparser.Where); ok { + if expr == nil { + return params, nil + } + e = expr.Expr + } switch v := e.(type) { case *sqlparser.Where: if v == nil { return params, nil } - return paramsInWhereExpr(v.Expr, s, tableAliasMap, defaultTable, settings) + return paramsInWhereExpr(v, s, tableAliasMap, defaultTable, settings) case *sqlparser.ComparisonExpr: p, found, err := paramInComparison(v, s, tableAliasMap, defaultTable, settings) if err != nil { diff --git a/internal/mysql/parse.go b/internal/mysql/parse.go index d2d9a4de73..24a7cac5ac 100644 --- a/internal/mysql/parse.go +++ b/internal/mysql/parse.go @@ -272,8 +272,10 @@ func parseUpdate(node *sqlparser.Update, query string, s *Schema, settings dinos params := []*Param{} for _, updateExpr := range node.Exprs { col := updateExpr.Name - newValue, isParam := updateExpr.Expr.(*sqlparser.SQLVal) - if !isParam { + newValue, isValue := updateExpr.Expr.(*sqlparser.SQLVal) + if !isValue { + continue + } else if isParam := newValue.Type == sqlparser.ValArg; !isParam { continue } colDfn, err := s.getColType(col, tableAliasMap, defaultTable) @@ -289,7 +291,7 @@ func parseUpdate(node *sqlparser.Update, query string, s *Schema, settings dinos params = append(params, ¶m) } - whereParams, err := paramsInWhereExpr(node.Where.Expr, s, tableAliasMap, defaultTable, settings) + whereParams, err := paramsInWhereExpr(node.Where, s, tableAliasMap, defaultTable, settings) if err != nil { return nil, fmt.Errorf("failed to parse params from WHERE expression: %w", err) } diff --git a/internal/mysql/parse_test.go b/internal/mysql/parse_test.go index 373293511e..c6bdc013cd 100644 --- a/internal/mysql/parse_test.go +++ b/internal/mysql/parse_test.go @@ -388,6 +388,22 @@ UPDATE users SET first_name = ?, last_name = ? WHERE id > ? AND first_name = ? L SchemaLookup: mockSchema, }, }, + testCase{ + name: "update_without_where", + input: expected{ + query: "/* name: UpdateAllUsers :exec */ update users set first_name = 'Bob'", + schema: mockSchema, + }, + output: &Query{ + SQL: "update users set first_name = 'Bob'", + Columns: nil, + Params: []*Param{}, + Name: "UpdateAllUsers", + Cmd: ":exec", + DefaultTableName: "users", + SchemaLookup: mockSchema, + }, + }, testCase{ name: "update_users", input: expected{