diff --git a/activerecord/CHANGELOG.md b/activerecord/CHANGELOG.md index a948b9babf6c1..ebd780e51f2b8 100644 --- a/activerecord/CHANGELOG.md +++ b/activerecord/CHANGELOG.md @@ -1,3 +1,7 @@ +* Add `ActiveRecord::Base.prohibit_shard_swapping` to prevent attempts to change the shard within a block. + + *John Crepezzi*, *Eileen M. Uchitelle* + * Filter unchanged attributes with default function from insert query when `partial_inserts` is disabled. *Akshay Birajdar*, *Jacopo Beschi* diff --git a/activerecord/lib/active_record/connection_handling.rb b/activerecord/lib/active_record/connection_handling.rb index 0c1e5da116f97..1083c56639456 100644 --- a/activerecord/lib/active_record/connection_handling.rb +++ b/activerecord/lib/active_record/connection_handling.rb @@ -182,7 +182,7 @@ def connected_to_many(*classes, role:, shard: nil, prevent_writes: false) prevent_writes = true if role == ActiveRecord.reading_role - connected_to_stack << { role: role, shard: shard, prevent_writes: prevent_writes, klasses: classes } + append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: classes) yield ensure connected_to_stack.pop @@ -202,7 +202,25 @@ def connecting_to(role: default_role, shard: default_shard, prevent_writes: fals prevent_writes = true if role == ActiveRecord.reading_role - self.connected_to_stack << { role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self] } + append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self]) + end + + # Prohibit swapping shards while inside of the passed block. + # + # In some cases you may want to be able to swap shards but not allow a + # nested call to connected_to or connected_to_many to swap again. This + # is useful in cases you're using sharding to provide per-request + # database isolation. + def prohibit_shard_swapping + Thread.current.thread_variable_set(:prohibit_shard_swapping, true) + yield + ensure + Thread.current.thread_variable_set(:prohibit_shard_swapping, false) + end + + # Determine whether or not shard swapping is currently prohibited + def shard_swapping_prohibited? + Thread.current.thread_variable_get(:prohibit_shard_swapping) end # Prevent writing to the database regardless of role. @@ -357,12 +375,12 @@ def with_role_and_shard(role, shard, prevent_writes) if ActiveRecord.legacy_connection_handling with_handler(role.to_sym) do connection_handler.while_preventing_writes(prevent_writes) do - self.connected_to_stack << { shard: shard, klasses: [self] } + append_to_connected_to_stack(shard: shard, klasses: [self]) yield end end else - self.connected_to_stack << { role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self] } + append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self]) return_value = yield return_value.load if return_value.is_a? ActiveRecord::Relation return_value @@ -371,6 +389,14 @@ def with_role_and_shard(role, shard, prevent_writes) self.connected_to_stack.pop end + def append_to_connected_to_stack(entry) + if shard_swapping_prohibited? && entry[:shard].present? + raise ArgumentError, "cannot swap `shard` while shard swapping is prohibited." + end + + connected_to_stack << entry + end + def swap_connection_handler(handler, &blk) # :nodoc: old_handler, ActiveRecord::Base.connection_handler = ActiveRecord::Base.connection_handler, handler return_value = yield diff --git a/activerecord/test/cases/connection_adapters/connection_handlers_sharding_db_test.rb b/activerecord/test/cases/connection_adapters/connection_handlers_sharding_db_test.rb index 965b198751f74..c678fb986901f 100644 --- a/activerecord/test/cases/connection_adapters/connection_handlers_sharding_db_test.rb +++ b/activerecord/test/cases/connection_adapters/connection_handlers_sharding_db_test.rb @@ -251,6 +251,59 @@ def test_calling_connected_to_on_a_non_existent_shard_raises assert_equal "No connection pool for 'ActiveRecord::Base' found for the 'foo' shard.", error.message end + + def test_cannot_swap_shards_while_prohibited + previous_env, ENV["RAILS_ENV"] = ENV["RAILS_ENV"], "default_env" + + config = { + "default_env" => { + "primary" => { "adapter" => "sqlite3", "database" => "test/db/primary.sqlite3" }, + "primary_shard_one" => { "adapter" => "sqlite3", "database" => "test/db/primary_shard_one.sqlite3" } + } + } + + @prev_configs, ActiveRecord::Base.configurations = ActiveRecord::Base.configurations, config + + ActiveRecord::Base.connects_to(shards: { + default: { writing: :primary }, + shard_one: { writing: :primary_shard_one } + }) + + assert_raises(ArgumentError) do + ActiveRecord::Base.prohibit_shard_swapping do + ActiveRecord::Base.connected_to(role: :reading, shard: :default) do + end + end + end + ensure + ActiveRecord::Base.configurations = @prev_configs + ActiveRecord::Base.establish_connection(:arunit) + ENV["RAILS_ENV"] = previous_env + end + + def test_can_swap_roles_while_shard_swapping_is_prohibited + previous_env, ENV["RAILS_ENV"] = ENV["RAILS_ENV"], "default_env" + + config = { + "default_env" => { + "primary" => { "adapter" => "sqlite3", "database" => "test/db/primary.sqlite3" }, + "primary_replica" => { "adapter" => "sqlite3", "database" => "test/db/primary.sqlite3", "replica" => true } + } + } + + @prev_configs, ActiveRecord::Base.configurations = ActiveRecord::Base.configurations, config + + ActiveRecord::Base.connects_to(shards: { default: { writing: :primary, reading: :primary_replica } }) + + ActiveRecord::Base.prohibit_shard_swapping do # no exception + ActiveRecord::Base.connected_to(role: :reading) do + end + end + ensure + ActiveRecord::Base.configurations = @prev_configs + ActiveRecord::Base.establish_connection(:arunit) + ENV["RAILS_ENV"] = previous_env + end end class SecondaryBase < ActiveRecord::Base