diff --git a/update.go b/update.go index 8d658d7..eb2a9c4 100644 --- a/update.go +++ b/update.go @@ -16,6 +16,7 @@ type updateData struct { Prefixes []Sqlizer Table string SetClauses []setClause + From Sqlizer WhereParts []Sqlizer OrderBys []string Limit string @@ -100,6 +101,14 @@ func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) { } sql.WriteString(strings.Join(setSqls, ", ")) + if d.From != nil { + sql.WriteString(" FROM ") + args, err = appendToSql([]Sqlizer{d.From}, sql, "", args) + if err != nil { + return + } + } + if len(d.WhereParts) > 0 { sql.WriteString(" WHERE ") args, err = appendToSql(d.WhereParts, sql, " AND ", args) @@ -233,6 +242,19 @@ func (b UpdateBuilder) SetMap(clauses map[string]interface{}) UpdateBuilder { return b } +// From adds FROM clause to the query +// FROM is valid construct in postgresql only. +func (b UpdateBuilder) From(from string) UpdateBuilder { + return builder.Set(b, "From", newPart(from)).(UpdateBuilder) +} + +// FromSelect sets a subquery into the FROM clause of the query. +func (b UpdateBuilder) FromSelect(from SelectBuilder, alias string) UpdateBuilder { + // Prevent misnumbered parameters in nested selects (#183). + from = from.PlaceholderFormat(Question) + return builder.Set(b, "From", Alias(from, alias)).(UpdateBuilder) +} + // Where adds WHERE expressions to the query. // // See SelectBuilder.Where for more information. diff --git a/update_test.go b/update_test.go index 9951451..a2bee1a 100644 --- a/update_test.go +++ b/update_test.go @@ -82,3 +82,26 @@ func TestUpdateBuilderNoRunner(t *testing.T) { _, err := b.Exec() assert.Equal(t, RunnerNotSet, err) } + +func TestUpdateBuilderFrom(t *testing.T) { + sql, _, err := Update("employees").Set("sales_count", 100).From("accounts").Where("accounts.name = ?", "ACME").ToSql() + assert.NoError(t, err) + assert.Equal(t, "UPDATE employees SET sales_count = ? FROM accounts WHERE accounts.name = ?", sql) +} + +func TestUpdateBuilderFromSelect(t *testing.T) { + sql, _, err := Update("employees"). + Set("sales_count", 100). + FromSelect(Select("id"). + From("accounts"). + Where("accounts.name = ?", "ACME"), "subquery"). + Where("employees.account_id = subquery.id").ToSql() + assert.NoError(t, err) + + expectedSql := + "UPDATE employees " + + "SET sales_count = ? " + + "FROM (SELECT id FROM accounts WHERE accounts.name = ?) AS subquery " + + "WHERE employees.account_id = subquery.id" + assert.Equal(t, expectedSql, sql) +}