Retry known idempotent SELECT queries on connection-related exceptions
This commit makes two types of queries retryable by opting into our `allow_retry` flag:
1) SELECT queries we construct by walking the Arel tree via `#to_sql_and_binds`. We use a
new `retryable` attribute on collector classes, which defaults to true for most node types,
but will be set to false for non-idempotent node types (functions, SQL literals, etc). The
`retryable` value is returned from `#to_sql_and_binds` and used by `#select_all` and
passed down the call stack, eventually reaching the adapter's `#internal_exec_query` method.

Internally-generated SQL literals are marked as retryable via a new `retryable` attribute on
`Arel::Nodes::SqlLiteral`.

2) `#find` and `#find_by` queries with known attributes. We set `allow_retry: true` in `#cached_find_by`,
and pass this down to `#find_by_sql` and `#_query_by_sql`.

These changes ensure that queries we know are safe to retry can be retried automatically.
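
The flow described above can be sketched in a few lines of plain Ruby. This is only an illustration under stated assumptions, not the code from this commit: `SketchCollector`, `SketchLiteral`, `compile_sketch` and `run_with_retry` are hypothetical stand-ins, the real visitor changes are not shown in the portion of the diff below, and `IOError` stands in for a connection-related exception.

```ruby
# Hypothetical, simplified stand-ins; these are NOT the real Arel classes.
class SketchCollector
  attr_accessor :retryable
  attr_reader :sql

  def initialize
    @sql = +""
    @retryable = true # optimistic default, as in #to_sql_and_binds
  end

  def <<(str)
    @sql << str
    self
  end
end

# A SQL literal that carries its own retryable flag.
SketchLiteral = Struct.new(:text, :retryable)

# Walks a flat list of "nodes" and returns [sql, allow_retry].
def compile_sketch(nodes)
  collector = SketchCollector.new
  nodes.each do |node|
    case node
    when SketchLiteral
      # One literal that is not marked retryable makes the whole query non-retryable.
      collector.retryable &&= node.retryable
      collector << node.text
    else
      collector << node.to_s
    end
  end
  [collector.sql.freeze, collector.retryable]
end

# Retries the block on IOError only when allow_retry is true (assumed policy).
def run_with_retry(allow_retry:, attempts: 2)
  tries = 0
  begin
    tries += 1
    yield tries
  rescue IOError
    retry if allow_retry && tries < attempts
    raise
  end
end

sql, allow_retry = compile_sketch(["SELECT ", SketchLiteral.new("*", true), " FROM posts"])
run_with_retry(allow_retry: allow_retry) { |attempt| "#{sql} (attempt #{attempt})" }
```

The point of the sketch is that the decision is made once, while the SQL is being built, and then travels with the query as a plain boolean; the retry loop itself never has to inspect the SQL.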
adrianna-chang-shopify committed Mar 20, 2024
1 parent 6b5f058 commit 7327f20
Showing 31 changed files with 221 additions and 52 deletions.
9 changes: 9 additions & 0 deletions activerecord/CHANGELOG.md
@@ -1,3 +1,12 @@
* Retry known idempotent SELECT queries on connection-related exceptions

SELECT queries we construct by walking the Arel tree and / or with known model attributes
are idempotent and can safely be retried in the case of a connection error. Previously,
adapters such as `TrilogyAdapter` would raise `ActiveRecord::ConnectionFailed: Trilogy::EOFError`
when encountering a connection error mid-request.

*Adrianna Chang*

* Add dirties option to uncached

This adds a `dirties` option to `ActiveRecord::Base.uncached` and
@@ -93,7 +93,7 @@ def strict_loading?
def append_constraints(connection, join, constraints)
if join.is_a?(Arel::Nodes::StringJoin)
join_string = Arel::Nodes::And.new(constraints.unshift join.left)
join.left = Arel.sql(connection.visitor.compile(join_string))
join.left = Arel.sql(connection.visitor.compile(join_string), retryable: true)
else
right = join.right
right.expr = Arel::Nodes::And.new(constraints.unshift right.expr)
@@ -14,7 +14,7 @@ def to_sql(arel_or_sql_string, binds = [])
sql
end

def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil) # :nodoc:
def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil, allow_retry = false) # :nodoc:
# Arel::TreeManager -> Arel::Node
if arel_or_sql_string.respond_to?(:ast)
arel_or_sql_string = arel_or_sql_string.ast
@@ -27,6 +27,7 @@ def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil) # :nodoc:
end

collector = collector()
collector.retryable = true

if prepared_statements
collector.preparable = true
@@ -41,10 +42,11 @@ def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil) # :nodoc:
else
sql = visitor.compile(arel_or_sql_string, collector)
end
[sql.freeze, binds, preparable]
allow_retry = collector.retryable
[sql.freeze, binds, preparable, allow_retry]
else
arel_or_sql_string = arel_or_sql_string.dup.freeze unless arel_or_sql_string.frozen?
[arel_or_sql_string, binds, preparable]
[arel_or_sql_string, binds, preparable, allow_retry]
end
end
private :to_sql_and_binds
@@ -64,11 +66,15 @@ def cacheable_query(klass, arel) # :nodoc:
end

# Returns an ActiveRecord::Result instance.
def select_all(arel, name = nil, binds = [], preparable: nil, async: false)
def select_all(arel, name = nil, binds = [], preparable: nil, async: false, allow_retry: false)
arel = arel_from_relation(arel)
sql, binds, preparable = to_sql_and_binds(arel, binds, preparable)
sql, binds, preparable, allow_retry = to_sql_and_binds(arel, binds, preparable, allow_retry)

select(sql, name, binds, prepare: prepared_statements && preparable, async: async && FutureResult::SelectAll)
select(sql, name, binds,
prepare: prepared_statements && preparable,
async: async && FutureResult::SelectAll,
allow_retry: allow_retry
)
rescue ::RangeError
ActiveRecord::Result.empty(async: async)
end
@@ -495,7 +501,7 @@ def with_yaml_fallback(value) # :nodoc:
end

# This is a safe default, even if not high precision on all databases
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

# Returns an Arel SQL literal for the CURRENT_TIMESTAMP for usage with
@@ -507,7 +513,7 @@ def high_precision_current_timestamp
HIGH_PRECISION_CURRENT_TIMESTAMP
end

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
raise NotImplementedError
end

@@ -606,7 +612,7 @@ def combine_multi_statements(total_sql)
end

# Returns an ActiveRecord::Result instance.
def select(sql, name = nil, binds = [], prepare: false, async: false)
def select(sql, name = nil, binds = [], prepare: false, async: false, allow_retry: false)
if async && async_enabled?
if current_transaction.joinable?
raise AsynchronousQueryInsideTransactionError, "Asynchronous queries are not allowed inside transactions"
@@ -627,7 +633,7 @@ def select(sql, name = nil, binds = [], prepare: false, async: false)
return future_result
end

result = internal_exec_query(sql, name, binds, prepare: prepare)
result = internal_exec_query(sql, name, binds, prepare: prepare, allow_retry: allow_retry)
if async
FutureResult.wrap(result)
else
@@ -204,19 +204,19 @@ def clear_query_cache
pool.clear_query_cache
end

def select_all(arel, name = nil, binds = [], preparable: nil, async: false) # :nodoc:
def select_all(arel, name = nil, binds = [], preparable: nil, async: false, allow_retry: false) # :nodoc:
arel = arel_from_relation(arel)

# If arel is locked this is a SELECT ... FOR UPDATE or somesuch.
# Such queries should not be cached.
if @query_cache&.enabled? && !(arel.respond_to?(:locked) && arel.locked)
sql, binds, preparable = to_sql_and_binds(arel, binds, preparable)
sql, binds, preparable, allow_retry = to_sql_and_binds(arel, binds, preparable)

if async
result = lookup_sql_cache(sql, name, binds) || super(sql, name, binds, preparable: preparable, async: async)
result = lookup_sql_cache(sql, name, binds) || super(sql, name, binds, preparable: preparable, async: async, allow_retry: allow_retry)
FutureResult.wrap(result)
else
cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable, async: async) }
cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable, async: async, allow_retry: allow_retry) }
end
else
super
@@ -229,12 +229,12 @@ def disable_referential_integrity # :nodoc:
# Mysql2Adapter doesn't have to free a result after using it, but we use this method
# to write stuff in an abstract way without concerning ourselves about whether it
# needs to be explicitly freed or not.
def execute_and_free(sql, name = nil, async: false) # :nodoc:
def execute_and_free(sql, name = nil, async: false, allow_retry: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)

mark_transaction_written_if_write(sql)
yield raw_execute(sql, name, async: async)
yield raw_execute(sql, name, async: async, allow_retry: allow_retry)
end

def begin_db_transaction # :nodoc:
@@ -11,7 +11,7 @@ module DatabaseStatements

# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_current-timestamp
# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-type-syntax.html
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

def write_query?(sql) # :nodoc:
@@ -18,9 +18,9 @@ def select_all(*, **) # :nodoc:
result
end

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
if without_prepared_statement?(binds)
execute_and_free(sql, name, async: async) do |result|
execute_and_free(sql, name, async: async, allow_retry: allow_retry) do |result|
if result
build_result(columns: result.fields, rows: result.to_a)
else
@@ -124,7 +124,7 @@ def exec_restart_db_transaction # :nodoc:
end

# From https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-CURRENT
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

def high_precision_current_timestamp
@@ -881,7 +881,7 @@ def exec_no_cache(sql, name, binds, async:, allow_retry:, materialize_transactio

type_casted_binds = type_casted_binds(binds)
log(sql, name, binds, type_casted_binds, async: async) do |notification_payload|
with_raw_connection(allow_retry: false, materialize_transactions: materialize_transactions) do |conn|
with_raw_connection(allow_retry: allow_retry, materialize_transactions: materialize_transactions) do |conn|
result = conn.exec_params(sql, type_casted_binds)
verified!
notification_payload[:row_count] = result.count
@@ -895,7 +895,7 @@ def exec_cache(sql, name, binds, async:, allow_retry:, materialize_transactions:

update_typemap_for_default_timezone

with_raw_connection(allow_retry: false, materialize_transactions: materialize_transactions) do |conn|
with_raw_connection(allow_retry: allow_retry, materialize_transactions: materialize_transactions) do |conn|
stmt_key = prepare_statement(sql, binds, conn)
type_casted_binds = type_casted_binds(binds)

@@ -21,7 +21,7 @@ def explain(arel, binds = [], _options = [])
SQLite3::ExplainPrettyPrinter.new.pp(result)
end

def internal_exec_query(sql, name = nil, binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = nil, binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)

@@ -106,7 +106,7 @@ def exec_rollback_db_transaction # :nodoc:

# https://stackoverflow.com/questions/17574784
# https://www.sqlite.org/lang_datefunc.html
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

def high_precision_current_timestamp
@@ -12,12 +12,12 @@ def select_all(*, **) # :nodoc:
result
end

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

result = raw_execute(sql, name, async: async)
result = raw_execute(sql, name, async: async, allow_retry: allow_retry)
ActiveRecord::Result.new(result.fields, result.to_a)
end

2 changes: 1 addition & 1 deletion activerecord/lib/active_record/core.rb
@@ -431,7 +431,7 @@ def cached_find_by(keys, values)
}

begin
statement.execute(values.flatten, lease_connection).first
statement.execute(values.flatten, lease_connection, allow_retry: true).first
rescue TypeError
raise ActiveRecord::StatementInvalid
end
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/internal_metadata.rb
@@ -153,7 +153,7 @@ def update_entry(connection, key, new_value)

def select_entry(connection, key)
sm = Arel::SelectManager.new(arel_table)
sm.project(Arel::Nodes::SqlLiteral.new("*"))
sm.project(Arel::Nodes::SqlLiteral.new("*", retryable: true))
sm.where(arel_table[primary_key].eq(Arel::Nodes::BindParam.new(key)))
sm.order(arel_table[primary_key].asc)
sm.limit = 1
8 changes: 4 additions & 4 deletions activerecord/lib/active_record/querying.rb
@@ -47,8 +47,8 @@ module Querying
#
# Note that building your own SQL query string from user input may expose your application to
# injection attacks (https://guides.rubyonrails.org/security.html#sql-injection).
def find_by_sql(sql, binds = [], preparable: nil, &block)
_load_from_sql(_query_by_sql(sql, binds, preparable: preparable), &block)
def find_by_sql(sql, binds = [], preparable: nil, allow_retry: false, &block)
_load_from_sql(_query_by_sql(sql, binds, preparable: preparable, allow_retry: allow_retry), &block)
end

# Same as <tt>#find_by_sql</tt> but perform the query asynchronously and returns an ActiveRecord::Promise.
@@ -58,8 +58,8 @@ def async_find_by_sql(sql, binds = [], preparable: nil, &block)
end
end

def _query_by_sql(sql, binds = [], preparable: nil, async: false) # :nodoc:
lease_connection.select_all(sanitize_sql(sql), "#{name} Load", binds, preparable: preparable, async: async)
def _query_by_sql(sql, binds = [], preparable: nil, async: false, allow_retry: false) # :nodoc:
lease_connection.select_all(sanitize_sql(sql), "#{name} Load", binds, preparable: preparable, async: async, allow_retry: allow_retry)
end

def _load_from_sql(result_set, &block) # :nodoc:
4 changes: 2 additions & 2 deletions activerecord/lib/active_record/relation/calculations.rb
@@ -446,7 +446,7 @@ def aggregate_column(column_name)
return column_name if Arel::Expressions === column_name

arel_column(column_name.to_s) do |name|
Arel.sql(column_name == :all ? "*" : name)
column_name == :all ? Arel.sql("*", retryable: true) : Arel.sql(name)
end
end

@@ -643,7 +643,7 @@ def build_count_subquery(relation, column_name, distinct)
relation.select_values = [ aggregate_column(column_name).as(column_alias) ]
end

subquery_alias = Arel.sql("subquery_for_count")
subquery_alias = Arel.sql("subquery_for_count", retryable: true)
select_value = operation_over_aggregate_column(column_alias, "count", false)

relation.build_subquery(subquery_alias, select_value)
4 changes: 2 additions & 2 deletions activerecord/lib/active_record/relation/predicate_builder.rb
@@ -28,9 +28,9 @@ def build_from_hash(attributes, &block)
def self.references(attributes)
attributes.each_with_object([]) do |(key, value), result|
if value.is_a?(Hash)
result << Arel.sql(key)
result << Arel.sql(key, retryable: true)
elsif (idx = key.rindex("."))
result << Arel.sql(key[0, idx])
result << Arel.sql(key[0, idx], retryable: true)
end
end
end
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/relation/query_methods.rb
@@ -2013,7 +2013,7 @@ def order_column(field)
if attr_name == "count" && !group_values.empty?
table[attr_name]
else
Arel.sql(adapter_class.quote_table_name(attr_name))
Arel.sql(adapter_class.quote_table_name(attr_name), retryable: true)
end
end
end
6 changes: 3 additions & 3 deletions activerecord/lib/active_record/statement_cache.rb
@@ -62,7 +62,7 @@ def sql_for(binds, connection)
end

class PartialQueryCollector
attr_accessor :preparable
attr_accessor :preparable, :retryable

def initialize
@parts = []
@@ -142,12 +142,12 @@ def initialize(query_builder, bind_map, klass)
@klass = klass
end

def execute(params, connection, &block)
def execute(params, connection, allow_retry: false, &block)
bind_values = bind_map.bind params

sql = query_builder.sql_for bind_values, connection

klass.find_by_sql(sql, bind_values, preparable: true, &block)
klass.find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block)
rescue ::RangeError
[]
end
10 changes: 7 additions & 3 deletions activerecord/lib/arel.rb
@@ -45,16 +45,20 @@ module Arel
# that this behavior only applies when bind value parameters are
# supplied in the call; without them, the placeholder tokens have no
# special meaning, and will be passed through to the query as-is.
def self.sql(sql_string, *positional_binds, **named_binds)
#
# The +:retryable+ option can be used to mark the SQL as safe to retry.
# Use this option only if the SQL is idempotent, as it could be executed
# more than once.
def self.sql(sql_string, *positional_binds, retryable: false, **named_binds)
if positional_binds.empty? && named_binds.empty?
Arel::Nodes::SqlLiteral.new sql_string
Arel::Nodes::SqlLiteral.new(sql_string, retryable: retryable)
else
Arel::Nodes::BoundSqlLiteral.new sql_string, positional_binds, named_binds
end
end

def self.star # :nodoc:
sql "*"
sql("*", retryable: true)
end

def self.arel_node?(value) # :nodoc:
2 changes: 1 addition & 1 deletion activerecord/lib/arel/alias_predication.rb
@@ -3,7 +3,7 @@
module Arel # :nodoc: all
module AliasPredication
def as(other)
Nodes::As.new self, Nodes::SqlLiteral.new(other)
Nodes::As.new self, Nodes::SqlLiteral.new(other, retryable: true)
end
end
end
2 changes: 2 additions & 0 deletions activerecord/lib/arel/collectors/bind.rb
@@ -3,6 +3,8 @@
module Arel # :nodoc: all
module Collectors
class Bind
attr_accessor :retryable

def initialize
@binds = []
end
7 changes: 7 additions & 0 deletions activerecord/lib/arel/collectors/composite.rb
@@ -4,12 +4,19 @@ module Arel # :nodoc: all
module Collectors
class Composite
attr_accessor :preparable
attr_reader :retryable

def initialize(left, right)
@left = left
@right = right
end

def retryable=(retryable)
left.retryable = retryable
right.retryable = retryable
@retryable = retryable
end

def <<(str)
left << str
right << str

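The `retryable=` writer added to the composite collector can be exercised in isolation. The following is a minimal sketch, assuming hypothetical stand-in classes (`LeafSketch`, `CompositeSketch`) rather than the real `Arel::Collectors` classes; it only shows that setting the flag on the composite fans out to both wrapped collectors and is also remembered on the composite itself.

```ruby
# Hypothetical stand-in for a leaf collector that exposes a retryable flag.
class LeafSketch
  attr_accessor :retryable
end

# Hypothetical stand-in mirroring the writer shown in the composite.rb hunk.
class CompositeSketch
  attr_reader :retryable, :left, :right

  def initialize(left, right)
    @left = left
    @right = right
  end

  # Setting the flag on the composite updates both children and the composite.
  def retryable=(value)
    left.retryable = value
    right.retryable = value
    @retryable = value
  end
end

left = LeafSketch.new
right = LeafSketch.new
composite = CompositeSketch.new(left, right)
composite.retryable = true
```

Unlike this sketch, which exposes `left` and `right` publicly for inspection, the real class may keep them private; that detail is not visible in the hunk.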