diff --git a/src/model-internals.ts b/src/model-internals.ts index 3d5f7594a30a..b3fbef795adc 100644 --- a/src/model-internals.ts +++ b/src/model-internals.ts @@ -1,7 +1,9 @@ import NodeUtil from 'util'; import { EagerLoadingError } from './errors'; +import type { Transactionable } from './model'; +import type { Sequelize } from './sequelize'; import { isModelStatic } from './utils/model-utils.js'; - +import type { Transaction } from './index'; // TODO: strictly type this file during the TS migration of model.js // The goal of this file is to include the different private methods that are currently present on the Model class. @@ -145,3 +147,12 @@ export function throwInvalidInclude(include: any): never { throw new EagerLoadingError(`Invalid Include received. Include has to be either a Model, an Association, the name of an association, or a plain object compatible with IncludeOptions. Got ${NodeUtil.inspect(include)} instead`); } + +export function setTransactionFromCls(options: Transactionable, sequelize: Sequelize): void { + if (options.transaction === undefined && sequelize.Sequelize._cls) { + const t = sequelize.Sequelize._cls.get('transaction'); + if (t) { + options.transaction = t as Transaction; + } + } +} diff --git a/src/model.js b/src/model.js index 09b0732efdca..777df609910a 100644 --- a/src/model.js +++ b/src/model.js @@ -16,7 +16,7 @@ const sequelizeErrors = require('./errors'); const DataTypes = require('./data-types'); const Hooks = require('./hooks'); const { Op } = require('./operators'); -const { _validateIncludedElements, combineIncludes, throwInvalidInclude } = require('./model-internals'); +const { _validateIncludedElements, combineIncludes, throwInvalidInclude, setTransactionFromCls } = require('./model-internals'); const { noDoubleNestedGroup, scopeRenamedToWithScope, schemaRenamedToWithSchema, noModelDropSchema } = require('./utils/deprecations'); // This list will quickly become dated, but failing to maintain this list just means @@ -1729,6 +1729,9 @@ Specify a different name for either index to resolve this issue.`); tableNames[this.getTableName(options)] = true; options = Utils.cloneDeep(options); + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + _.defaults(options, { hooks: true, model: this }); // set rejectOnEmpty option, defaults to model options @@ -2044,6 +2047,10 @@ Specify a different name for either index to resolve this issue.`); static async count(options) { options = Utils.cloneDeep(options); options = _.defaults(options, { hooks: true }); + + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + options.raw = true; if (options.hooks) { await this.runHooks('beforeCount', options); @@ -2478,6 +2485,9 @@ Specify a different name for either index to resolve this issue.`); ...Utils.cloneDeep(options), }; + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + const createdAtAttr = this._timestampAttributes.createdAt; const updatedAtAttr = this._timestampAttributes.updatedAt; const hasPrimary = this.primaryKeyField in values || this.primaryKeyAttribute in values; @@ -2571,6 +2581,9 @@ Specify a different name for either index to resolve this issue.`); const now = Utils.now(this.sequelize.options.dialect); options = Utils.cloneDeep(options); + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + options.model = this; if (!options.includeValidated) { @@ -2918,6 +2931,9 @@ Specify a different name for either index to resolve this issue.`); static async destroy(options) { options = Utils.cloneDeep(options); + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + this._injectScope(options); if (!options || !(options.where || options.truncate)) { @@ -3006,6 +3022,9 @@ Specify a different name for either index to resolve this issue.`); ...options, }; + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + options.type = QueryTypes.RAW; options.model = this; @@ -3063,6 +3082,9 @@ Specify a different name for either index to resolve this issue.`); static async update(values, options) { options = Utils.cloneDeep(options); + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + this._injectScope(options); this._optionsMustContainWhere(options); @@ -3930,6 +3952,9 @@ Instead of specifying a Model, either: validate: true, }); + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + if (!options.fields) { if (this.isNewRecord) { options.fields = Object.keys(this.constructor.rawAttributes); @@ -4290,6 +4315,9 @@ Instead of specifying a Model, either: ...options, }; + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + // Run before hook if (options.hooks) { await this.constructor.runHooks('beforeDestroy', this, options); @@ -4365,6 +4393,9 @@ Instead of specifying a Model, either: ...options, }; + // Add CLS transaction + setTransactionFromCls(options, this.sequelize); + // Run before hook if (options.hooks) { await this.constructor.runHooks('beforeRestore', this, options); diff --git a/test/integration/cls.test.js b/test/integration/cls.test.js index 93110c24bf8c..b1dd28a2731c 100644 --- a/test/integration/cls.test.js +++ b/test/integration/cls.test.js @@ -141,6 +141,255 @@ if (current.dialect.supports.transactions) { }); }); + // reason for this test: https://github.com/sequelize/sequelize/issues/12973 + describe('Model Hook integration', () => { + + function testHooks({ method, hooks: hookNames, optionPos, execute, getModel }) { + it(`passes the transaction to hooks {${hookNames.join(',')}} when calling ${method}`, async function () { + await this.sequelize.transaction(async transaction => { + const hooks = Object.create(null); + + for (const hookName of hookNames) { + hooks[hookName] = sinon.spy(); + } + + const User = Reflect.apply(getModel, this, []); + + for (const [hookName, spy] of Object.entries(hooks)) { + User[hookName](spy); + } + + await Reflect.apply(execute, this, [User]); + + const spyMatcher = []; + // ignore all arguments until we get to the option bag. + for (let i = 0; i < optionPos; i++) { + spyMatcher.push(sinon.match.any); + } + + // find the transaction in the option bag + spyMatcher.push(sinon.match.has('transaction', transaction)); + + for (const [hookName, spy] of Object.entries(hooks)) { + expect( + spy, + `hook ${hookName} did not receive the transaction from CLS.`, + ).to.have.been.calledWith(...spyMatcher); + } + }); + }); + } + + testHooks({ + method: 'Model.bulkCreate', + hooks: ['beforeBulkCreate', 'beforeCreate', 'afterCreate', 'afterBulkCreate'], + optionPos: 1, + async execute(User) { + await User.bulkCreate([{ name: 'bob' }], { individualHooks: true }); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.findAll', + hooks: ['beforeFind', 'beforeFindAfterExpandIncludeAll', 'beforeFindAfterOptions'], + optionPos: 0, + async execute(User) { + await User.findAll(); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.findAll', + hooks: ['afterFind'], + optionPos: 1, + async execute(User) { + await User.findAll(); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.count', + hooks: ['beforeCount'], + optionPos: 0, + async execute(User) { + await User.count(); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.upsert', + hooks: ['beforeUpsert', 'afterUpsert'], + optionPos: 1, + async execute(User) { + await User.upsert({ + id: 1, + name: 'bob', + }); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.destroy', + hooks: ['beforeBulkDestroy', 'afterBulkDestroy'], + optionPos: 0, + async execute(User) { + await User.destroy({ where: { name: 'bob' } }); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.destroy with individualHooks', + hooks: ['beforeDestroy', 'beforeDestroy'], + optionPos: 1, + async execute(User) { + await User.create({ name: 'bob' }); + await User.destroy({ where: { name: 'bob' }, individualHooks: true }); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model#destroy', + hooks: ['beforeDestroy', 'beforeDestroy'], + optionPos: 1, + async execute(User) { + const user = await User.create({ name: 'bob' }); + await user.destroy(); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.update', + hooks: ['beforeBulkUpdate', 'afterBulkUpdate'], + optionPos: 0, + async execute(User) { + await User.update({ name: 'alice' }, { where: { name: 'bob' } }); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model.update with individualHooks', + hooks: ['beforeUpdate', 'afterUpdate'], + optionPos: 1, + async execute(User) { + await User.create({ name: 'bob' }); + await User.update({ name: 'alice' }, { where: { name: 'bob' }, individualHooks: true }); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model#save (isNewRecord)', + hooks: ['beforeCreate', 'afterCreate'], + optionPos: 1, + async execute(User) { + const user = User.build({ name: 'bob' }); + user.name = 'alice'; + await user.save(); + }, + getModel() { + return this.User; + }, + }); + + testHooks({ + method: 'Model#save (!isNewRecord)', + hooks: ['beforeUpdate', 'afterUpdate'], + optionPos: 1, + async execute(User) { + const user = await User.create({ name: 'bob' }); + user.name = 'alice'; + await user.save(); + }, + getModel() { + return this.User; + }, + }); + + describe('paranoid restore', () => { + beforeEach(async function () { + this.ParanoidUser = this.sequelize.define('ParanoidUser', { + name: DataTypes.STRING, + }, { paranoid: true }); + + await this.ParanoidUser.sync({ force: true }); + }); + + testHooks({ + method: 'Model.restore', + hooks: ['beforeBulkRestore', 'afterBulkRestore'], + optionPos: 0, + async execute() { + const User = this.ParanoidUser; + await User.restore({ where: { name: 'bob' } }); + }, + getModel() { + return this.ParanoidUser; + }, + }); + + testHooks({ + method: 'Model.restore with individualHooks', + hooks: ['beforeRestore', 'afterRestore'], + optionPos: 1, + async execute() { + const User = this.ParanoidUser; + + await User.create({ name: 'bob' }); + await User.destroy({ where: { name: 'bob' } }); + await User.restore({ where: { name: 'bob' }, individualHooks: true }); + }, + getModel() { + return this.ParanoidUser; + }, + }); + + testHooks({ + method: 'Model#restore', + hooks: ['beforeRestore', 'afterRestore'], + optionPos: 1, + async execute() { + const User = this.ParanoidUser; + + const user = await User.create({ name: 'bob' }); + await user.destroy(); + await user.restore(); + }, + getModel() { + return this.ParanoidUser; + }, + }); + }); + }); + it('CLS namespace is stored in Sequelize._cls', function () { expect(Sequelize._cls).to.equal(this.ns); });