diff --git a/src/model.js b/src/model.js index 25b5d1776766..2cab12accbc7 100644 --- a/src/model.js +++ b/src/model.js @@ -1759,6 +1759,14 @@ class Model { tableNames[this.getTableName(options)] = true; options = Utils.cloneDeep(options); + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + _.defaults(options, { hooks: true }); // set rejectOnEmpty option, defaults to model options @@ -1965,6 +1973,14 @@ class Model { } options = Utils.cloneDeep(options); + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + if (options.limit === undefined) { const uniqueSingleColumns = _.chain(this.uniqueKeys).values().filter(c => c.fields.length === 1).map('column').value(); @@ -2075,6 +2091,15 @@ class Model { static async count(options) { options = Utils.cloneDeep(options); options = _.defaults(options, { hooks: true }); + + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + options.raw = true; if (options.hooks) { await this.runHooks('beforeCount', options); @@ -2521,6 +2546,14 @@ class Model { ...Utils.cloneDeep(options) }; + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + const createdAtAttr = this._timestampAttributes.createdAt; const updatedAtAttr = this._timestampAttributes.updatedAt; const hasPrimary = this.primaryKeyField in values || this.primaryKeyAttribute in values; @@ -2616,6 +2649,14 @@ class Model { const now = Utils.now(this.sequelize.options.dialect); options = Utils.cloneDeep(options); + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + options.model = this; if (!options.includeValidated) { @@ -2972,6 +3013,14 @@ class Model { static async destroy(options) { options = Utils.cloneDeep(options); + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + this._injectScope(options); if (!options || !(options.where || options.truncate)) { @@ -3062,6 +3111,14 @@ class Model { ...options }; + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + options.type = QueryTypes.RAW; options.model = this; @@ -3127,6 +3184,14 @@ class Model { static async update(values, options) { options = Utils.cloneDeep(options); + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + this._injectScope(options); this._optionsMustContainWhere(options); @@ -3917,6 +3982,15 @@ class Model { } options = Utils.cloneDeep(options); + + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + options = _.defaults(options, { hooks: true, validate: true @@ -4237,6 +4311,15 @@ class Model { if (Array.isArray(options)) options = { fields: options }; options = Utils.cloneDeep(options); + + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + const setOptions = Utils.cloneDeep(options); setOptions.attributes = options.fields; this.set(values, setOptions); @@ -4271,6 +4354,14 @@ class Model { ...options }; + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + // Run before hook if (options.hooks) { await this.constructor.runHooks('beforeDestroy', this, options); @@ -4340,6 +4431,14 @@ class Model { ...options }; + // Add CLS transaction + if (options.transaction === undefined && this.sequelize.constructor._cls) { + const t = this.sequelize.constructor._cls.get('transaction'); + if (t) { + options.transaction = t; + } + } + // Run before hook if (options.hooks) { await this.constructor.runHooks('beforeRestore', this, options); diff --git a/test/integration/cls.test.js b/test/integration/cls.test.js index f84c6f4d59e8..1a2f2b7bce56 100644 --- a/test/integration/cls.test.js +++ b/test/integration/cls.test.js @@ -136,6 +136,254 @@ if (current.dialect.supports.transactions) { }); }); + describe('Model Hook integration', () => { + + function testHooks({ method, hooks: hookNames, optionPos, execute, getModel }) { + it(`passes the transaction to hooks {${hookNames.join(',')}} when calling ${method}`, async function() { + await this.sequelize.transaction(async transaction => { + const hooks = Object.create(null); + + for (const hookName of hookNames) { + hooks[hookName] = sinon.spy(); + } + + const User = Reflect.apply(getModel, this, []); + + for (const [hookName, spy] of Object.entries(hooks)) { + User[hookName](spy); + } + + await Reflect.apply(execute, this, [User]); + + const spyMatcher = []; + // ignore all arguments until we get to the option bag. + for (let i = 0; i < optionPos; i++) { + spyMatcher.push(sinon.match.any); + } + + // find the transaction in the option bag + spyMatcher.push(sinon.match.has('transaction', transaction)); + + for (const [hookName, spy] of Object.entries(hooks)) { + expect( + spy, + `hook ${hookName} did not receive the transaction from CLS.` + ).to.have.been.calledWith(...spyMatcher); + } + }); + }); + } + + testHooks({ + method: 'Model.bulkCreate', + hooks: ['beforeBulkCreate', 'beforeCreate', 'afterCreate', 'afterBulkCreate'], + optionPos: 1, + async execute(User) { + await User.bulkCreate([{ name: 'bob' }], { individualHooks: true }); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.findAll', + hooks: ['beforeFind', 'beforeFindAfterExpandIncludeAll', 'beforeFindAfterOptions'], + optionPos: 0, + async execute(User) { + await User.findAll(); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.findAll', + hooks: ['afterFind'], + optionPos: 1, + async execute(User) { + await User.findAll(); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.count', + hooks: ['beforeCount'], + optionPos: 0, + async execute(User) { + await User.count(); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.upsert', + hooks: ['beforeUpsert', 'afterUpsert'], + optionPos: 1, + async execute(User) { + await User.upsert({ + id: 1, + name: 'bob' + }); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.destroy', + hooks: ['beforeBulkDestroy', 'afterBulkDestroy'], + optionPos: 0, + async execute(User) { + await User.destroy({ where: { name: 'bob' } }); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.destroy with individualHooks', + hooks: ['beforeDestroy', 'beforeDestroy'], + optionPos: 1, + async execute(User) { + await User.create({ name: 'bob' }); + await User.destroy({ where: { name: 'bob' }, individualHooks: true }); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model#destroy', + hooks: ['beforeDestroy', 'beforeDestroy'], + optionPos: 1, + async execute(User) { + const user = await User.create({ name: 'bob' }); + await user.destroy(); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.update', + hooks: ['beforeBulkUpdate', 'afterBulkUpdate'], + optionPos: 0, + async execute(User) { + await User.update({ name: 'alice' }, { where: { name: 'bob' } }); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model.update with individualHooks', + hooks: ['beforeUpdate', 'afterUpdate'], + optionPos: 1, + async execute(User) { + await User.create({ name: 'bob' }); + await User.update({ name: 'alice' }, { where: { name: 'bob' }, individualHooks: true }); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model#save (isNewRecord)', + hooks: ['beforeCreate', 'afterCreate'], + optionPos: 1, + async execute(User) { + const user = User.build({ name: 'bob' }); + user.name = 'alice'; + await user.save(); + }, + getModel() { + return this.User; + } + }); + + testHooks({ + method: 'Model#save (!isNewRecord)', + hooks: ['beforeUpdate', 'afterUpdate'], + optionPos: 1, + async execute(User) { + const user = await User.create({ name: 'bob' }); + user.name = 'alice'; + await user.save(); + }, + getModel() { + return this.User; + } + }); + + describe('paranoid restore', () => { + beforeEach(async function() { + this.ParanoidUser = this.sequelize.define('ParanoidUser', { + name: Sequelize.STRING + }, { paranoid: true }); + + await this.ParanoidUser.sync({ force: true }); + }); + + testHooks({ + method: 'Model.restore', + hooks: ['beforeBulkRestore', 'afterBulkRestore'], + optionPos: 0, + async execute() { + const User = this.ParanoidUser; + await User.restore({ where: { name: 'bob' } }); + }, + getModel() { + return this.ParanoidUser; + } + }); + + testHooks({ + method: 'Model.restore with individualHooks', + hooks: ['beforeRestore', 'afterRestore'], + optionPos: 1, + async execute() { + const User = this.ParanoidUser; + + await User.create({ name: 'bob' }); + await User.destroy({ where: { name: 'bob' } }); + await User.restore({ where: { name: 'bob' }, individualHooks: true }); + }, + getModel() { + return this.ParanoidUser; + } + }); + + testHooks({ + method: 'Model#restore', + hooks: ['beforeRestore', 'afterRestore'], + optionPos: 1, + async execute() { + const User = this.ParanoidUser; + + const user = await User.create({ name: 'bob' }); + await user.destroy(); + await user.restore(); + }, + getModel() { + return this.ParanoidUser; + } + }); + }); + }); + it('CLS namespace is stored in Sequelize._cls', function() { expect(Sequelize._cls).to.equal(this.ns); });