Skip to content

Commit

Permalink
Prevent access using thread_mattr_accessor
Browse files Browse the repository at this point in the history
  • Loading branch information
stevecrozz committed Mar 23, 2024
1 parent 8fc596a commit 3c1bce0
Show file tree
Hide file tree
Showing 20 changed files with 74 additions and 108 deletions.
1 change: 1 addition & 0 deletions activerecord/lib/active_record.rb
Expand Up @@ -39,6 +39,7 @@
module ActiveRecord
extend ActiveSupport::Autoload

autoload :AccessPrevention
autoload :Base
autoload :Callbacks
autoload :ConnectionHandling
Expand Down
38 changes: 38 additions & 0 deletions activerecord/lib/active_record/access_prevention.rb
@@ -0,0 +1,38 @@
# frozen_string_literal: true

module ActiveRecord
class PreventedAccessError < ActiveRecordError # :nodoc:
end

# = Active Record Access Prevention
module AccessPrevention
extend ActiveSupport::Concern

thread_mattr_accessor :enabled, instance_accessor: false, default: false

module ClassMethods
# Lets you prevent database access from ActiveRecord for the duration of
# a block.
#
# ==== Examples
# ActiveRecord::Base.while_preventing_access do
# Project.first # raises an exception
# end
#
def while_preventing_access(&block)
previous_enabled = preventing_access?
AccessPrevention.enabled = true
yield
ensure
AccessPrevention.enabled = previous_enabled
end

# Determines whether access is currently being prevented.
#
# Returns the value of +enabled+.
def preventing_access?
AccessPrevention.enabled
end
end
end
end
1 change: 1 addition & 0 deletions activerecord/lib/active_record/base.rb
Expand Up @@ -330,6 +330,7 @@ class Base
include Suppressor
include Normalization
include Marshalling::Methods
include AccessPrevention

self.param_delimiter = "_"
end
Expand Down
Expand Up @@ -69,10 +69,6 @@ def primary_class?
def current_preventing_writes
false
end

def current_preventing_access
false
end
end

def initialize
Expand Down
Expand Up @@ -514,7 +514,8 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa
private
def internal_execute(sql, name = "SCHEMA", allow_retry: false, materialize_transactions: true)
sql = transform_query(sql)
check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)

mark_transaction_written_if_write(sql)

Expand Down
Expand Up @@ -197,12 +197,14 @@ def with_instrumenter(instrumenter, &block) # :nodoc:
end
end

def check_if_query_prevented(sql) # :nodoc:
def check_if_write_query(sql) # :nodoc:
if preventing_writes? && write_query?(sql)
raise ActiveRecord::ReadOnlyError, "Write query attempted while in readonly mode: #{sql}"
end
end

if preventing_access?
def check_if_access_prevented(sql) # :nodoc:
if AccessPrevention.enabled
raise ActiveRecord::PreventedAccessError, "Query attempted while preventing access: #{sql}"
end
end
Expand Down Expand Up @@ -238,15 +240,6 @@ def preventing_writes?
connection_class.current_preventing_writes
end

# Determines whether access is currently being prevented.
#
# Returns the value of +current_preventing_access+.
def preventing_access?
return false if connection_class.nil?

connection_class.current_preventing_access
end

def prepared_statements?
@prepared_statements && !prepared_statements_disabled_cache.include?(object_id)
end
Expand Down
Expand Up @@ -231,7 +231,8 @@ def disable_referential_integrity # :nodoc:
# needs to be explicitly freed or not.
def execute_and_free(sql, name = nil, async: false) # :nodoc:
sql = transform_query(sql)
check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)

mark_transaction_written_if_write(sql)
yield raw_execute(sql, name, async: async)
Expand Down
Expand Up @@ -112,7 +112,8 @@ def raw_execute(sql, name, async: false, allow_retry: false, materialize_transac

def exec_stmt_and_free(sql, name, binds, cache_stmt: false, async: false)
sql = transform_query(sql)
check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)

mark_transaction_written_if_write(sql)

Expand Down
Expand Up @@ -857,6 +857,7 @@ def load_types_queries(initializer, oids)

def execute_and_clear(sql, name, binds, prepare: false, async: false, allow_retry: false, materialize_transactions: true)
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)

if !prepare || without_prepared_statement?(binds)
Expand Down
Expand Up @@ -23,7 +23,8 @@ def explain(arel, binds = [], _options = [])

def internal_exec_query(sql, name = nil, binds = [], prepare: false, async: false) # :nodoc:
sql = transform_query(sql)
check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)

mark_transaction_written_if_write(sql)

Expand Down Expand Up @@ -136,7 +137,8 @@ def execute_batch(statements, name = nil)
statements = statements.map { |sql| transform_query(sql) }
sql = combine_multi_statements(statements)

check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

log(sql, name) do |notification_payload|
Expand Down
Expand Up @@ -14,7 +14,8 @@ def select_all(*, **) # :nodoc:

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
sql = transform_query(sql)
check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

result = raw_execute(sql, name, async: async)
Expand All @@ -23,7 +24,8 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa

def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil) # :nodoc:
sql = transform_query(sql)
check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

sql, _binds = sql_for_insert(sql, pk, binds, returning)
Expand All @@ -32,7 +34,8 @@ def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil)

def exec_delete(sql, name = nil, binds = []) # :nodoc:
sql = transform_query(sql)
check_if_query_prevented(sql)
check_if_access_prevented(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

result = raw_execute(to_sql(sql, binds), name)
Expand Down
28 changes: 9 additions & 19 deletions activerecord/lib/active_record/connection_handling.rb
Expand Up @@ -131,7 +131,7 @@ def connects_to(database: {}, shards: {})
# ActiveRecord::Base.connected_to(role: :reading, shard: :shard_one_replica) do
# Dog.first # finds first Dog record stored on the shard one replica
# end
def connected_to(role: nil, shard: nil, prevent_writes: false, prevent_access: false, &blk)
def connected_to(role: nil, shard: nil, prevent_writes: false, &blk)
if self != Base && !abstract_class
raise NotImplementedError, "calling `connected_to` is only allowed on ActiveRecord::Base or abstract classes."
end
Expand All @@ -144,13 +144,12 @@ def connected_to(role: nil, shard: nil, prevent_writes: false, prevent_access: f
raise ArgumentError, "must provide a `shard` and/or `role`."
end

with_role_and_shard(role, shard, prevent_writes: prevent_writes, prevent_access: prevent_access, &blk)
with_role_and_shard(role, shard, prevent_writes, &blk)
end

# Connects a role and/or shard to the provided connection names. Optionally +prevent_writes+
# can be passed to block writes on a connection. +reading+ will automatically set
# +prevent_writes+ to true. Optionally +prevent_access+ can be passed to block all access on a
# connection.
# +prevent_writes+ to true.
#
# +connected_to_many+ is an alternative to deeply nested +connected_to+ blocks.
#
Expand All @@ -161,7 +160,7 @@ def connected_to(role: nil, shard: nil, prevent_writes: false, prevent_access: f
# Dinner.first # Read from meals replica
# Person.first # Read from primary writer
# end
def connected_to_many(*classes, role:, shard: nil, prevent_writes: false, prevent_access: false)
def connected_to_many(*classes, role:, shard: nil, prevent_writes: false)
classes = classes.flatten

if self != Base || classes.include?(Base)
Expand All @@ -170,7 +169,7 @@ def connected_to_many(*classes, role:, shard: nil, prevent_writes: false, preven

prevent_writes = true if role == ActiveRecord.reading_role

append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, prevent_access: prevent_access, klasses: classes)
append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: classes)
yield
ensure
connected_to_stack.pop
Expand All @@ -183,10 +182,10 @@ def connected_to_many(*classes, role:, shard: nil, prevent_writes: false, preven
#
# It is not recommended to use this method in a request since it
# does not yield to a block like +connected_to+.
def connecting_to(role: default_role, shard: default_shard, prevent_writes: false, prevent_access: false)
def connecting_to(role: default_role, shard: default_shard, prevent_writes: false)
prevent_writes = true if role == ActiveRecord.reading_role

append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, prevent_access: prevent_access, klasses: [self])
append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self])
end

# Prohibit swapping shards while inside of the passed block.
Expand Down Expand Up @@ -223,15 +222,6 @@ def while_preventing_writes(enabled = true, &block)
connected_to(role: current_role, prevent_writes: enabled, &block)
end

# Prevent all database access regardless of role.
#
# In some cases you may want to be sure that no database access occurs in a
# context. +while_preventing_access+ will prevent any database access for
# the duration of the block.
def while_preventing_access(enabled = true, &block)
connected_to(role: current_role, prevent_access: enabled, &block)
end

# Returns true if role is the current connected role and/or
# current connected shard. If no shard is passed, the default will be
# used.
Expand Down Expand Up @@ -353,10 +343,10 @@ def resolve_config_for_connection(config_or_env)
Base.configurations.resolve(config_or_env)
end

def with_role_and_shard(role, shard, prevent_writes:, prevent_access:)
def with_role_and_shard(role, shard, prevent_writes)
prevent_writes = true if role == ActiveRecord.reading_role

append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, prevent_access: prevent_access, klasses: [self])
append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self])
return_value = yield
return_value.load if return_value.is_a? ActiveRecord::Relation
return_value
Expand Down
19 changes: 0 additions & 19 deletions activerecord/lib/active_record/core.rb
Expand Up @@ -189,25 +189,6 @@ def self.current_preventing_writes
false
end

# Returns the symbol representing the current setting for
# preventing access.
#
# ActiveRecord::Base.connected_to(role: :reading, prevent_access: true) do
# ActiveRecord::Base.current_preventing_access #=> true
# end
#
# ActiveRecord::Base.connected_to(role: :writing, prevent_access: false) do
# ActiveRecord::Base.current_preventing_access #=> false
# end
def self.current_preventing_access
connected_to_stack.reverse_each do |hash|
return hash[:prevent_access] if !hash[:prevent_access].nil? && hash[:klasses].include?(Base)
return hash[:prevent_access] if !hash[:prevent_access].nil? && hash[:klasses].include?(connection_class_for_self)
end

false
end

def self.connected_to_stack # :nodoc:
if connected_to_stack = ActiveSupport::IsolatedExecutionState[:active_record_connected_to_stack]
connected_to_stack
Expand Down
Expand Up @@ -17,9 +17,5 @@ def self.primary_class?
def self.current_preventing_writes
false
end

def self.current_preventing_access
false
end
end
end
6 changes: 3 additions & 3 deletions activerecord/test/cases/adapter_prevent_access_test.rb
Expand Up @@ -10,13 +10,13 @@ def setup
end

def test_preventing_access_predicate
assert_not_predicate @connection, :preventing_access?
assert_not ActiveRecord::Base.preventing_access?

ActiveRecord::Base.while_preventing_access do
assert_predicate @connection, :preventing_access?
assert_predicate ActiveRecord::Base, :preventing_access?
end

assert_not_predicate @connection, :preventing_access?
assert_not ActiveRecord::Base.preventing_access?
end

def test_errors_when_query_is_called_while_preventing_access
Expand Down
2 changes: 1 addition & 1 deletion activerecord/test/cases/base_prevent_access_test.rb
Expand Up @@ -43,7 +43,7 @@ class BasePreventWritesTest < ActiveRecord::TestCase

test "current_preventing_access" do
ActiveRecord::Base.while_preventing_access do
assert ActiveRecord::Base.current_preventing_access, "expected connection current_preventing_access to return true"
assert_predicate ActiveRecord::Base, :preventing_access?, "expected connection current_preventing_access to return true"
end
end
end
Expand Down
9 changes: 0 additions & 9 deletions activerecord/test/cases/base_test.rb
Expand Up @@ -1932,15 +1932,6 @@ def test_protected_environments_are_stored_as_an_array_of_string
ActiveRecord::Base.connected_to_stack.pop
end

test "#connecting_to with prevent_access" do
SecondAbstractClass.connecting_to(role: :writing, prevent_access: true)

assert SecondAbstractClass.connected_to?(role: :writing)
assert SecondAbstractClass.current_preventing_access
ensure
ActiveRecord::Base.connected_to_stack.pop
end

test "#connected_to_many cannot be called on anything but ActiveRecord::Base" do
assert_raises NotImplementedError do
SecondAbstractClass.connected_to_many([SecondAbstractClass], role: :writing)
Expand Down
Expand Up @@ -188,7 +188,6 @@ def test_establish_connection_using_two_level_configurations

assert_not_nil pool = @handler.retrieve_connection_pool("development")
assert_not_predicate pool.lease_connection, :preventing_writes?
assert_not_predicate pool.lease_connection, :preventing_access?
assert_equal "test/db/primary.sqlite3", pool.db_config.database
ensure
ActiveRecord::Base.configurations = @prev_configs
Expand All @@ -205,7 +204,6 @@ def test_establish_connection_using_top_level_key_in_two_level_config

assert_not_nil pool = @handler.retrieve_connection_pool("development_readonly")
assert_not_predicate pool.lease_connection, :preventing_writes?
assert_not_predicate pool.lease_connection, :preventing_access?
assert_equal "test/db/readonly.sqlite3", pool.db_config.database
ensure
ActiveRecord::Base.configurations = @prev_configs
Expand All @@ -222,7 +220,6 @@ def test_establish_connection_with_string_owner_name

assert_not_nil pool = @handler.retrieve_connection_pool("custom_connection")
assert_not_predicate pool.lease_connection, :preventing_writes?
assert_not_predicate pool.lease_connection, :preventing_access?
assert_equal "test/db/readonly.sqlite3", pool.db_config.database
ensure
ActiveRecord::Base.configurations = @prev_configs
Expand Down
Expand Up @@ -131,15 +131,13 @@ def test_switching_connections_via_handler
assert ActiveRecord::Base.connected_to?(role: :reading)
assert_not ActiveRecord::Base.connected_to?(role: :writing)
assert_predicate ActiveRecord::Base.lease_connection, :preventing_writes?
assert_not_predicate ActiveRecord::Base.lease_connection, :preventing_access?
end

ActiveRecord::Base.connected_to(role: :writing) do
assert_equal :writing, ActiveRecord::Base.current_role
assert ActiveRecord::Base.connected_to?(role: :writing)
assert_not ActiveRecord::Base.connected_to?(role: :reading)
assert_not_predicate ActiveRecord::Base.lease_connection, :preventing_writes?
assert_not_predicate ActiveRecord::Base.lease_connection, :preventing_access?
end
ensure
ActiveRecord::Base.configurations = @prev_configs
Expand Down

0 comments on commit 3c1bce0

Please sign in to comment.