Skip to content

Commit

Permalink
Add ability to prevent access to a database
Browse files Browse the repository at this point in the history
  • Loading branch information
stevecrozz committed Mar 19, 2024
1 parent 1b226d5 commit c438333
Show file tree
Hide file tree
Showing 22 changed files with 300 additions and 19 deletions.
17 changes: 17 additions & 0 deletions activerecord/CHANGELOG.md
@@ -1,3 +1,20 @@
* Add the ability to prevent access to a database for the duration of a block.

Allows the application to prevent database access. This can be useful to
ensure a routine does not depend on database access.

If `while_preventing_access` is called and there is a database query within
the block, the connection will raise an exception.

One purpose of this is to catch accidental reads.

For example, an application may have a method which is known to be called
in tight loops. If database access from within this method could lead to
unacceptable performance impacts, it may be desirable to prevent database
access within this method.

*Stephen Crosby*

* Add dirties option to uncached

This adds a `dirties` option to `ActiveRecord::Base.uncached` and
Expand Down
Expand Up @@ -69,6 +69,10 @@ def primary_class?
def current_preventing_writes
false
end

def current_preventing_access
false
end
end

def initialize
Expand Down
Expand Up @@ -514,7 +514,7 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa
private
def internal_execute(sql, name = "SCHEMA", allow_retry: false, materialize_transactions: true)
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)

mark_transaction_written_if_write(sql)

Expand Down
Expand Up @@ -197,10 +197,14 @@ def with_instrumenter(instrumenter, &block) # :nodoc:
end
end

def check_if_write_query(sql) # :nodoc:
def check_if_query_prevented(sql) # :nodoc:
if preventing_writes? && write_query?(sql)
raise ActiveRecord::ReadOnlyError, "Write query attempted while in readonly mode: #{sql}"
end

if preventing_access?
raise ActiveRecord::PreventedAccessError, "Query attempted while preventing access: #{sql}"
end
end

def replica?
Expand Down Expand Up @@ -234,6 +238,15 @@ def preventing_writes?
connection_class.current_preventing_writes
end

# Determines whether access is currently being prevented.
#
# Returns the value of +current_preventing_access+.
def preventing_access?
return false if connection_class.nil?

connection_class.current_preventing_access
end

def prepared_statements?
@prepared_statements && !prepared_statements_disabled_cache.include?(object_id)
end
Expand Down
Expand Up @@ -231,7 +231,7 @@ def disable_referential_integrity # :nodoc:
# needs to be explicitly freed or not.
def execute_and_free(sql, name = nil, async: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)

mark_transaction_written_if_write(sql)
yield raw_execute(sql, name, async: async)
Expand Down
Expand Up @@ -112,7 +112,7 @@ def raw_execute(sql, name, async: false, allow_retry: false, materialize_transac

def exec_stmt_and_free(sql, name, binds, cache_stmt: false, async: false)
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)

mark_transaction_written_if_write(sql)

Expand Down
Expand Up @@ -857,7 +857,7 @@ def load_types_queries(initializer, oids)

def execute_and_clear(sql, name, binds, prepare: false, async: false, allow_retry: false, materialize_transactions: true)
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)

if !prepare || without_prepared_statement?(binds)
result = exec_no_cache(sql, name, binds, async: async, allow_retry: allow_retry, materialize_transactions: materialize_transactions)
Expand Down
Expand Up @@ -23,7 +23,7 @@ def explain(arel, binds = [], _options = [])

def internal_exec_query(sql, name = nil, binds = [], prepare: false, async: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)

mark_transaction_written_if_write(sql)

Expand Down Expand Up @@ -136,7 +136,7 @@ def execute_batch(statements, name = nil)
statements = statements.map { |sql| transform_query(sql) }
sql = combine_multi_statements(statements)

check_if_write_query(sql)
check_if_query_prevented(sql)
mark_transaction_written_if_write(sql)

log(sql, name) do |notification_payload|
Expand Down
Expand Up @@ -14,7 +14,7 @@ def select_all(*, **) # :nodoc:

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)
mark_transaction_written_if_write(sql)

result = raw_execute(sql, name, async: async)
Expand All @@ -23,7 +23,7 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa

def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)
mark_transaction_written_if_write(sql)

sql, _binds = sql_for_insert(sql, pk, binds, returning)
Expand All @@ -32,7 +32,7 @@ def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil)

def exec_delete(sql, name = nil, binds = []) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
check_if_query_prevented(sql)
mark_transaction_written_if_write(sql)

result = raw_execute(to_sql(sql, binds), name)
Expand Down
28 changes: 19 additions & 9 deletions activerecord/lib/active_record/connection_handling.rb
Expand Up @@ -131,7 +131,7 @@ def connects_to(database: {}, shards: {})
# ActiveRecord::Base.connected_to(role: :reading, shard: :shard_one_replica) do
# Dog.first # finds first Dog record stored on the shard one replica
# end
def connected_to(role: nil, shard: nil, prevent_writes: false, &blk)
def connected_to(role: nil, shard: nil, prevent_writes: false, prevent_access: false, &blk)
if self != Base && !abstract_class
raise NotImplementedError, "calling `connected_to` is only allowed on ActiveRecord::Base or abstract classes."
end
Expand All @@ -144,12 +144,13 @@ def connected_to(role: nil, shard: nil, prevent_writes: false, &blk)
raise ArgumentError, "must provide a `shard` and/or `role`."
end

with_role_and_shard(role, shard, prevent_writes, &blk)
with_role_and_shard(role, shard, prevent_writes: prevent_writes, prevent_access: prevent_access, &blk)
end

# Connects a role and/or shard to the provided connection names. Optionally +prevent_writes+
# can be passed to block writes on a connection. +reading+ will automatically set
# +prevent_writes+ to true.
# +prevent_writes+ to true. Optionally +prevent_access+ can be passed to block all access on a
# connection.
#
# +connected_to_many+ is an alternative to deeply nested +connected_to+ blocks.
#
Expand All @@ -160,7 +161,7 @@ def connected_to(role: nil, shard: nil, prevent_writes: false, &blk)
# Dinner.first # Read from meals replica
# Person.first # Read from primary writer
# end
def connected_to_many(*classes, role:, shard: nil, prevent_writes: false)
def connected_to_many(*classes, role:, shard: nil, prevent_writes: false, prevent_access: false)
classes = classes.flatten

if self != Base || classes.include?(Base)
Expand All @@ -169,7 +170,7 @@ def connected_to_many(*classes, role:, shard: nil, prevent_writes: false)

prevent_writes = true if role == ActiveRecord.reading_role

append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: classes)
append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, prevent_access: prevent_access, klasses: classes)
yield
ensure
connected_to_stack.pop
Expand All @@ -182,10 +183,10 @@ def connected_to_many(*classes, role:, shard: nil, prevent_writes: false)
#
# It is not recommended to use this method in a request since it
# does not yield to a block like +connected_to+.
def connecting_to(role: default_role, shard: default_shard, prevent_writes: false)
def connecting_to(role: default_role, shard: default_shard, prevent_writes: false, prevent_access: false)
prevent_writes = true if role == ActiveRecord.reading_role

append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self])
append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, prevent_access: prevent_access, klasses: [self])
end

# Prohibit swapping shards while inside of the passed block.
Expand Down Expand Up @@ -222,6 +223,15 @@ def while_preventing_writes(enabled = true, &block)
connected_to(role: current_role, prevent_writes: enabled, &block)
end

# Prevent all database access regardless of role.
#
# In some cases you may want to be sure that no database access occurs in a
# context. +while_preventing_access+ will prevent any database access for
# the duration of the block.
def while_preventing_access(enabled = true, &block)
connected_to(role: current_role, prevent_access: enabled, &block)
end

# Returns true if role is the current connected role and/or
# current connected shard. If no shard is passed, the default will be
# used.
Expand Down Expand Up @@ -343,10 +353,10 @@ def resolve_config_for_connection(config_or_env)
Base.configurations.resolve(config_or_env)
end

def with_role_and_shard(role, shard, prevent_writes)
def with_role_and_shard(role, shard, prevent_writes:, prevent_access:)
prevent_writes = true if role == ActiveRecord.reading_role

append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, klasses: [self])
append_to_connected_to_stack(role: role, shard: shard, prevent_writes: prevent_writes, prevent_access: prevent_access, klasses: [self])
return_value = yield
return_value.load if return_value.is_a? ActiveRecord::Relation
return_value
Expand Down
19 changes: 19 additions & 0 deletions activerecord/lib/active_record/core.rb
Expand Up @@ -189,6 +189,25 @@ def self.current_preventing_writes
false
end

# Returns the symbol representing the current setting for
# preventing access.
#
# ActiveRecord::Base.connected_to(role: :reading, prevent_access: true) do
# ActiveRecord::Base.current_preventing_access #=> true
# end
#
# ActiveRecord::Base.connected_to(role: :writing, prevent_access: false) do
# ActiveRecord::Base.current_preventing_access #=> false
# end
def self.current_preventing_access
connected_to_stack.reverse_each do |hash|
return hash[:prevent_access] if !hash[:prevent_access].nil? && hash[:klasses].include?(Base)
return hash[:prevent_access] if !hash[:prevent_access].nil? && hash[:klasses].include?(connection_class_for_self)
end

false
end

def self.connected_to_stack # :nodoc:
if connected_to_stack = ActiveSupport::IsolatedExecutionState[:active_record_connected_to_stack]
connected_to_stack
Expand Down
4 changes: 4 additions & 0 deletions activerecord/lib/active_record/errors.rb
Expand Up @@ -115,6 +115,10 @@ class ExclusiveConnectionTimeoutError < ConnectionTimeoutError
class ReadOnlyError < ActiveRecordError
end

# Raised when database access is attempted on a connection preventing access.
class PreventedAccessError < ActiveRecordError
end

# Raised when Active Record cannot find a record by given id or set of ids.
class RecordNotFound < ActiveRecordError
attr_reader :model, :primary_key, :id
Expand Down
Expand Up @@ -17,5 +17,9 @@ def self.primary_class?
def self.current_preventing_writes
false
end

def self.current_preventing_access
false
end
end
end
32 changes: 32 additions & 0 deletions activerecord/test/cases/adapter_prevent_access_test.rb
@@ -0,0 +1,32 @@
# frozen_string_literal: true

require "cases/helper"
require "support/connection_helper"

module ActiveRecord
class AdapterPreventAccessTest < ActiveRecord::TestCase
def setup
@connection = ActiveRecord::Base.lease_connection
end

def test_preventing_access_predicate
assert_not_predicate @connection, :preventing_access?

ActiveRecord::Base.while_preventing_access do
assert_predicate @connection, :preventing_access?
end

assert_not_predicate @connection, :preventing_access?
end

def test_errors_when_query_is_called_while_preventing_access
@connection.select_all("SELECT count(*) FROM subscribers")

ActiveRecord::Base.while_preventing_access do
assert_raises(ActiveRecord::PreventedAccessError) do
@connection.select_all("SELECT count(*) FROM subscribers")
end
end
end
end
end
@@ -0,0 +1,22 @@
# frozen_string_literal: true

require "cases/helper"
require "support/ddl_helper"

class AdapterPreventAccessTest < ActiveRecord::AbstractMysqlTestCase
include DdlHelper

def setup
@conn = ActiveRecord::Base.lease_connection
end

def test_error_when_a_query_is_called_while_preventing_access
@conn.execute("INSERT INTO `engines` (`car_id`) VALUES ('138853948594')")

ActiveRecord::Base.while_preventing_access do
assert_raises(ActiveRecord::PreventedAccessError) do
@conn.execute("SELECT `engines`.* FROM `engines` WHERE `engines`.`car_id` = '138853948594'")
end
end
end
end
@@ -0,0 +1,28 @@
# frozen_string_literal: true

require "cases/helper"
require "support/ddl_helper"
require "support/connection_helper"

module ActiveRecord
module ConnectionAdapters
class PostgreSQLAdapterPreventWritesTest < ActiveRecord::PostgreSQLTestCase
include DdlHelper
include ConnectionHelper

def setup
@connection = ActiveRecord::Base.lease_connection
end

def test_error_when_a_query_is_called_while_preventing_access
@conn.execute("INSERT INTO `engines` (`car_id`) VALUES ('138853948594')")

ActiveRecord::Base.while_preventing_access do
assert_raises(ActiveRecord::PreventedAccessError) do
@conn.execute("SELECT `engines`.* FROM `engines` WHERE `engines`.`car_id` = '138853948594'")
end
end
end
end
end
end

0 comments on commit c438333

Please sign in to comment.