diff --git a/bun.go b/bun.go index 0b1969b2..2a788c4c 100644 --- a/bun.go +++ b/bun.go @@ -15,9 +15,17 @@ type ( NullTime = schema.NullTime BaseModel = schema.BaseModel + Query = schema.Query + BeforeAppendModelHook = schema.BeforeAppendModelHook + + BeforeScanRowHook = schema.BeforeScanRowHook + AfterScanRowHook = schema.AfterScanRowHook + + // DEPRECATED. Use BeforeScanRowHook instead. BeforeScanHook = schema.BeforeScanHook - AfterScanHook = schema.AfterScanHook + // DEPRECATED. Use AfterScanRowHook instead. + AfterScanHook = schema.AfterScanHook ) type BeforeSelectHook interface { diff --git a/hook.go b/hook.go index 7b60d2a2..7cca7ef6 100644 --- a/hook.go +++ b/hook.go @@ -3,7 +3,6 @@ package bun import ( "context" "database/sql" - "reflect" "strings" "sync/atomic" "time" @@ -11,18 +10,11 @@ import ( "github.com/uptrace/bun/schema" ) -type IQuery interface { - schema.QueryAppender - Operation() string - GetModel() Model - GetTableName() string -} - type QueryEvent struct { DB *DB QueryAppender schema.QueryAppender // Deprecated: use IQuery instead - IQuery IQuery + IQuery Query Query string QueryArgs []interface{} Model Model @@ -58,7 +50,7 @@ type QueryHook interface { func (db *DB) beforeQuery( ctx context.Context, - iquery IQuery, + iquery Query, query string, queryArgs []interface{}, model Model, @@ -116,13 +108,3 @@ func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIn db.queryHooks[hookIndex].AfterQuery(ctx, event) } } - -//------------------------------------------------------------------------------ - -func callBeforeScanHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(schema.BeforeScanHook).BeforeScan(ctx) -} - -func callAfterScanHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(schema.AfterScanHook).AfterScan(ctx) -} diff --git a/internal/dbtest/model_hook_test.go b/internal/dbtest/model_hook_test.go index 69dc4437..dd913006 100644 --- a/internal/dbtest/model_hook_test.go +++ b/internal/dbtest/model_hook_test.go @@ -39,17 +39,14 @@ func TestModelHook(t *testing.T) { } func testModelHook(t *testing.T, dbName string, db *bun.DB) { - _, err := db.NewDropTable().Model((*ModelHookTest)(nil)).IfExists().Exec(ctx) - require.NoError(t, err) - - _, err = db.NewCreateTable().Model((*ModelHookTest)(nil)).Exec(ctx) + err := db.ResetModel(ctx, (*ModelHookTest)(nil)) require.NoError(t, err) { hook := &ModelHookTest{ID: 1} _, err := db.NewInsert().Model(hook).Exec(ctx) require.NoError(t, err) - require.Equal(t, []string{"BeforeInsert", "AfterInsert"}, events.Flush()) + require.Equal(t, []string{"BeforeInsert", "BeforeAppendModel", "AfterInsert"}, events.Flush()) } { @@ -58,13 +55,14 @@ func testModelHook(t *testing.T, dbName string, db *bun.DB) { require.NoError(t, err) require.Equal(t, []string{ "BeforeSelect", + "BeforeAppendModel", "BeforeScan", "AfterScan", "AfterSelect", }, events.Flush()) } - { + t.Run("selectEmptySlice", func(t *testing.T) { hooks := make([]ModelHookTest, 0) err := db.NewSelect().Model(&hooks).Scan(ctx) require.NoError(t, err) @@ -74,20 +72,20 @@ func testModelHook(t *testing.T, dbName string, db *bun.DB) { "AfterScan", "AfterSelect", }, events.Flush()) - } + }) { hook := &ModelHookTest{ID: 1} _, err := db.NewUpdate().Model(hook).Where("id = 1").Exec(ctx) require.NoError(t, err) - require.Equal(t, []string{"BeforeUpdate", "AfterUpdate"}, events.Flush()) + require.Equal(t, []string{"BeforeUpdate", "BeforeAppendModel", "AfterUpdate"}, events.Flush()) } { hook := &ModelHookTest{ID: 1} _, err := db.NewDelete().Model(hook).Where("id = 1").Exec(ctx) require.NoError(t, err) - require.Equal(t, []string{"BeforeDelete", "AfterDelete"}, events.Flush()) + require.Equal(t, []string{"BeforeDelete", "BeforeAppendModel", "AfterDelete"}, events.Flush()) } { @@ -95,6 +93,30 @@ func testModelHook(t *testing.T, dbName string, db *bun.DB) { require.NoError(t, err) require.Equal(t, []string{"BeforeDelete", "AfterDelete"}, events.Flush()) } + + t.Run("insertSlice", func(t *testing.T) { + hooks := []ModelHookTest{{ID: 1}, {ID: 2}} + _, err := db.NewInsert().Model(&hooks).Exec(ctx) + require.NoError(t, err) + require.Equal(t, []string{ + "BeforeInsert", + "BeforeAppendModel", + "BeforeAppendModel", + "AfterInsert", + }, events.Flush()) + }) + + t.Run("insertSliceOfPtr", func(t *testing.T) { + hooks := []*ModelHookTest{{ID: 3}, {ID: 4}} + _, err := db.NewInsert().Model(&hooks).Exec(ctx) + require.NoError(t, err) + require.Equal(t, []string{ + "BeforeInsert", + "BeforeAppendModel", + "BeforeAppendModel", + "AfterInsert", + }, events.Flush()) + }) } type ModelHookTest struct { @@ -102,6 +124,13 @@ type ModelHookTest struct { Value string } +var _ bun.BeforeAppendModelHook = (*ModelHookTest)(nil) + +func (t *ModelHookTest) BeforeAppendModel(query bun.Query) error { + events.Add("BeforeAppendModel") + return nil +} + var _ bun.BeforeScanHook = (*ModelHookTest)(nil) func (t *ModelHookTest) BeforeScan(c context.Context) error { @@ -182,7 +211,7 @@ func (t *ModelHookTest) AfterDelete(ctx context.Context, query *bun.DeleteQuery) func assertQueryModel(query interface{ GetModel() bun.Model }) { switch value := query.GetModel().Value(); value.(type) { - case *ModelHookTest, *[]ModelHookTest: + case *ModelHookTest, *[]ModelHookTest, *[]*ModelHookTest: // ok default: panic(fmt.Errorf("unexpected: %T", value)) diff --git a/model.go b/model.go index 71a3a1e6..88ec4899 100644 --- a/model.go +++ b/model.go @@ -15,10 +15,7 @@ var errNilModel = errors.New("bun: Model(nil)") var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() -type Model interface { - ScanRows(ctx context.Context, rows *sql.Rows) (int, error) - Value() interface{} -} +type Model = schema.Model type rowScanner interface { ScanRow(ctx context.Context, rows *sql.Rows) error @@ -27,8 +24,9 @@ type rowScanner interface { type TableModel interface { Model - schema.BeforeScanHook - schema.AfterScanHook + schema.BeforeAppendModelHook + schema.BeforeScanRowHook + schema.AfterScanRowHook ScanColumn(column string, src interface{}) error Table() *schema.Table diff --git a/model_table_slice.go b/model_table_slice.go index b312b663..2d207567 100644 --- a/model_table_slice.go +++ b/model_table_slice.go @@ -91,10 +91,31 @@ func (m *sliceTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, er return n, nil } +var _ schema.BeforeAppendModelHook = (*sliceTableModel)(nil) + +func (m *sliceTableModel) BeforeAppendModel(query Query) error { + if !m.table.HasBeforeAppendModelHook() { + return nil + } + + sliceLen := m.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := m.slice.Index(i) + if !m.sliceOfPtr { + strct = strct.Addr() + } + err := strct.Interface().(schema.BeforeAppendModelHook).BeforeAppendModel(query) + if err != nil { + return err + } + } + return nil +} + // Inherit these hooks from structTableModel. var ( - _ schema.BeforeScanHook = (*sliceTableModel)(nil) - _ schema.AfterScanHook = (*sliceTableModel)(nil) + _ schema.BeforeScanRowHook = (*sliceTableModel)(nil) + _ schema.AfterScanRowHook = (*sliceTableModel)(nil) ) func (m *sliceTableModel) updateSoftDeleteField(tm time.Time) error { diff --git a/model_table_struct.go b/model_table_struct.go index fba17f42..0e74160f 100644 --- a/model_table_struct.go +++ b/model_table_struct.go @@ -100,38 +100,65 @@ func (m *structTableModel) mountJoins() { } } -var _ schema.BeforeScanHook = (*structTableModel)(nil) +var _ schema.BeforeAppendModelHook = (*structTableModel)(nil) -func (m *structTableModel) BeforeScan(ctx context.Context) error { - if !m.table.HasBeforeScanHook() { +func (m *structTableModel) BeforeAppendModel(query Query) error { + if !m.table.HasBeforeAppendModelHook() || !m.strct.IsValid() { return nil } - return callBeforeScanHook(ctx, m.strct.Addr()) + return m.strct.Addr().Interface().(schema.BeforeAppendModelHook).BeforeAppendModel(query) } -var _ schema.AfterScanHook = (*structTableModel)(nil) +var _ schema.BeforeScanRowHook = (*structTableModel)(nil) -func (m *structTableModel) AfterScan(ctx context.Context) error { - if !m.table.HasAfterScanHook() || !m.structInited { +func (m *structTableModel) BeforeScanRow(ctx context.Context) error { + if m.table.HasBeforeScanRowHook() { + return m.strct.Addr().Interface().(schema.BeforeScanRowHook).BeforeScanRow(ctx) + } + if m.table.HasBeforeScanHook() { + return m.strct.Addr().Interface().(schema.BeforeScanHook).BeforeScan(ctx) + } + return nil +} + +var _ schema.AfterScanRowHook = (*structTableModel)(nil) + +func (m *structTableModel) AfterScanRow(ctx context.Context) error { + if !m.structInited { return nil } - var firstErr error + if m.table.HasAfterScanRowHook() { + firstErr := m.strct.Addr().Interface().(schema.AfterScanRowHook).AfterScanRow(ctx) + + for _, j := range m.joins { + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } - if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { - firstErr = err + return firstErr } - for _, j := range m.joins { - switch j.Relation.Type { - case schema.HasOneRelation, schema.BelongsToRelation: - if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { - firstErr = err + if m.table.HasAfterScanHook() { + firstErr := m.strct.Addr().Interface().(schema.AfterScanHook).AfterScan(ctx) + + for _, j := range m.joins { + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil { + firstErr = err + } } } + + return firstErr } - return firstErr + return nil } func (m *structTableModel) getJoin(name string) *relationJoin { @@ -257,7 +284,7 @@ func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { } func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error { - if err := m.BeforeScan(ctx); err != nil { + if err := m.BeforeScanRow(ctx); err != nil { return err } @@ -266,7 +293,7 @@ func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []i return err } - if err := m.AfterScan(ctx); err != nil { + if err := m.AfterScanRow(ctx); err != nil { return err } diff --git a/query_base.go b/query_base.go index 199c7df1..4f10cb5f 100644 --- a/query_base.go +++ b/query_base.go @@ -165,6 +165,13 @@ func (q *baseQuery) getModel(dest []interface{}) (Model, error) { return newModel(q.db, dest) } +func (q *baseQuery) beforeAppendModel(query Query) error { + if q.tableModel != nil { + return q.tableModel.BeforeAppendModel(query) + } + return nil +} + //------------------------------------------------------------------------------ func (q *baseQuery) checkSoftDelete() error { @@ -462,7 +469,7 @@ func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) { func (q *baseQuery) scan( ctx context.Context, - iquery IQuery, + iquery Query, query string, model Model, hasDest bool, @@ -494,7 +501,7 @@ func (q *baseQuery) scan( func (q *baseQuery) exec( ctx context.Context, - iquery IQuery, + iquery Query, query string, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model) diff --git a/query_delete.go b/query_delete.go index ef42e644..0221696a 100644 --- a/query_delete.go +++ b/query_delete.go @@ -134,9 +134,13 @@ func (q *DeleteQuery) Operation() string { } func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := q.beforeAppendModel(q); err != nil { + return nil, err + } if q.err != nil { return nil, q.err } + fmter = formatterWithModel(fmter, q) if q.isSoftDelete() { diff --git a/query_insert.go b/query_insert.go index 79d9a430..2f4eda03 100644 --- a/query_insert.go +++ b/query_insert.go @@ -155,6 +155,10 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e if q.err != nil { return nil, q.err } + if err := q.beforeAppendModel(q); err != nil { + return nil, err + } + fmter = formatterWithModel(fmter, q) b, err = q.appendWith(fmter, b) diff --git a/query_select.go b/query_select.go index 59470ff5..3bfe5e4a 100644 --- a/query_select.go +++ b/query_select.go @@ -364,15 +364,18 @@ func (q *SelectQuery) Operation() string { } func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if err := q.beforeAppendModel(q); err != nil { + return nil, err + } return q.appendQuery(fmter, b, false) } func (q *SelectQuery) appendQuery( fmter schema.Formatter, b []byte, count bool, ) (_ []byte, err error) { - if q.err != nil { - return nil, q.err - } fmter = formatterWithModel(fmter, q) cteCount := count && (len(q.group) > 0 || q.distinctOn != nil) @@ -854,6 +857,12 @@ type countQuery struct { } func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + // if err := q.beforeAppendModel(q); err != nil { + // return nil, err + // } return q.appendQuery(fmter, b, true) } @@ -864,6 +873,13 @@ type existsQuery struct { } func (q existsQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + // if err := q.beforeAppendModel(q); err != nil { + // return nil, err + // } + b = append(b, "SELECT EXISTS ("...) b, err = q.appendQuery(fmter, b, false) diff --git a/query_update.go b/query_update.go index ae9c0f44..a6065968 100644 --- a/query_update.go +++ b/query_update.go @@ -170,9 +170,13 @@ func (q *UpdateQuery) Operation() string { } func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := q.beforeAppendModel(q); err != nil { + return nil, err + } if q.err != nil { return nil, q.err } + fmter = formatterWithModel(fmter, q) b, err = q.appendWith(fmter, b) diff --git a/schema/hook.go b/schema/hook.go index 5391981d..def4075b 100644 --- a/schema/hook.go +++ b/schema/hook.go @@ -2,9 +2,32 @@ package schema import ( "context" + "database/sql" "reflect" ) +type Model interface { + ScanRows(ctx context.Context, rows *sql.Rows) (int, error) + Value() interface{} +} + +type Query interface { + QueryAppender + Operation() string + GetModel() Model + GetTableName() string +} + +//------------------------------------------------------------------------------ + +type BeforeAppendModelHook interface { + BeforeAppendModel(query Query) error +} + +var beforeAppendModelHookType = reflect.TypeOf((*BeforeAppendModelHook)(nil)).Elem() + +//------------------------------------------------------------------------------ + type BeforeScanHook interface { BeforeScan(context.Context) error } @@ -18,3 +41,19 @@ type AfterScanHook interface { } var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() + +//------------------------------------------------------------------------------ + +type BeforeScanRowHook interface { + BeforeScanRow(context.Context) error +} + +var beforeScanRowHookType = reflect.TypeOf((*BeforeScanRowHook)(nil)).Elem() + +//------------------------------------------------------------------------------ + +type AfterScanRowHook interface { + AfterScanRow(context.Context) error +} + +var afterScanRowHookType = reflect.TypeOf((*AfterScanRowHook)(nil)).Elem() diff --git a/schema/table.go b/schema/table.go index 827c12f6..d512b567 100644 --- a/schema/table.go +++ b/schema/table.go @@ -15,8 +15,11 @@ import ( ) const ( - beforeScanHookFlag internal.Flag = 1 << iota + beforeAppendModelHookFlag internal.Flag = 1 << iota + beforeScanHookFlag afterScanHookFlag + beforeScanRowHookFlag + afterScanRowHookFlag ) var ( @@ -84,8 +87,13 @@ func newTable(dialect Dialect, typ reflect.Type) *Table { typ reflect.Type flag internal.Flag }{ + {beforeAppendModelHookType, beforeAppendModelHookFlag}, + {beforeScanHookType, beforeScanHookFlag}, {afterScanHookType, afterScanHookFlag}, + + {beforeScanRowHookType, beforeScanRowHookFlag}, + {afterScanRowHookType, afterScanRowHookFlag}, } typ = reflect.PtrTo(t.Type) @@ -95,6 +103,22 @@ func newTable(dialect Dialect, typ reflect.Type) *Table { } } + // Deprecated. + deprecatedHooks := []struct { + typ reflect.Type + flag internal.Flag + msg string + }{ + {beforeScanHookType, beforeScanHookFlag, "rename BeforeScan hook to BeforeScanRow"}, + {afterScanHookType, afterScanHookFlag, "rename AfterScan hook to AfterScanRow"}, + } + for _, hook := range deprecatedHooks { + if typ.Implements(hook.typ) { + internal.Deprecated.Printf("%s: %s", t.TypeName, hook.msg) + t.flags = t.flags.Set(hook.flag) + } + } + return t } @@ -777,9 +801,18 @@ func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { //------------------------------------------------------------------------------ -func (t *Table) Dialect() Dialect { return t.dialect } +func (t *Table) Dialect() Dialect { return t.dialect } + +func (t *Table) HasBeforeAppendModelHook() bool { return t.flags.Has(beforeAppendModelHookFlag) } + +// DEPRECATED. Use HasBeforeScanRowHook. func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } -func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } + +// DEPRECATED. Use HasAfterScanRowHook. +func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } + +func (t *Table) HasBeforeScanRowHook() bool { return t.flags.Has(beforeScanRowHookFlag) } +func (t *Table) HasAfterScanRowHook() bool { return t.flags.Has(afterScanRowHookFlag) } //------------------------------------------------------------------------------