diff --git a/benchmarks/bench_test.go b/benchmarks/bench_test.go index 4007771..f8b1951 100644 --- a/benchmarks/bench_test.go +++ b/benchmarks/bench_test.go @@ -118,6 +118,23 @@ func BenchmarkKallaxInsertWithRelationships(b *testing.B) { } } +func BenchmarkKallaxUpdateWithRelationships(b *testing.B) { + db := setupDB(b, openTestDB(b)) + defer teardownDB(b, db) + + store := NewPersonStore(db) + pers := mkPersonWithRels() + if err := store.Insert(pers); err != nil { + b.Fatalf("error inserting: %s", err) + } + + for i := 0; i < b.N; i++ { + if _, err := store.Update(pers); err != nil { + b.Fatalf("error updating: %s", err) + } + } +} + func BenchmarkSQLBoilerInsertWithRelationships(b *testing.B) { db := setupDB(b, openTestDB(b)) defer teardownDB(b, db) @@ -193,6 +210,23 @@ func BenchmarkKallaxInsert(b *testing.B) { } } +func BenchmarkKallaxUpdate(b *testing.B) { + db := setupDB(b, openTestDB(b)) + defer teardownDB(b, db) + + store := NewPersonStore(db) + pers := &Person{Name: "foo"} + if err := store.Insert(pers); err != nil { + b.Fatalf("error inserting: %s", err) + } + + for i := 0; i < b.N; i++ { + if _, err := store.Update(pers); err != nil { + b.Fatalf("error updating: %s", err) + } + } +} + func BenchmarkSQLBoilerInsert(b *testing.B) { db := setupDB(b, openTestDB(b)) defer teardownDB(b, db) diff --git a/store.go b/store.go index d31c16c..05931bd 100644 --- a/store.go +++ b/store.go @@ -227,7 +227,8 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64 cols = schema.Columns() } - columnNames := ColumnNames(cols) + // remove the ID from there + columnNames := ColumnNames(cols)[1:] values, columnNames, err := RecordValues(record, columnNames...) if err != nil { return 0, err @@ -237,18 +238,24 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64 columnNames = append(columnNames, virtualCols...) values = append(values, virtualColValues...) - var clauses = make(map[string]interface{}, len(cols)) + var query bytes.Buffer + query.WriteString("UPDATE ") + query.WriteString(schema.Table()) + query.WriteString(" SET ") for i, col := range columnNames { - clauses[col] = values[i] + if i != 0 { + query.WriteRune(',') + } + query.WriteString(col) + query.WriteRune('=') + query.WriteString(fmt.Sprintf("$%d", i+1)) } + query.WriteString(" WHERE ") + query.WriteString(schema.ID().String()) + query.WriteRune('=') + query.WriteString(fmt.Sprintf("$%d", len(columnNames)+1)) - result, err := s.builder. - Update(schema.Table()). - SetMap(clauses). - Where(squirrel.Eq{ - schema.ID().String(): record.GetID(), - }). - Exec() + result, err := s.proxy.Exec(query.String(), append(values, record.GetID())...) if err != nil { return 0, err }