Skip to content

Commit

Permalink
Relation#where build BoundSqlLiteral rather than eagerly interpolate
Browse files Browse the repository at this point in the history
Ref: #50793

To make not caching connection checkout viable, we need to reduced
the amount of places where we need a connection.

Once big source of this is query/relation building, where in many
cases it eagerly quote and interpolation bound values in SQL fragments.

Doing this requires an active connection because both MySQL and Postgres
may quote values differently based on the connection settings.

Instead of eagerly doing all this, we can instead just insert these
as bound values in the Arel AST. For adapters with prepared statements
this is better anyway as it will avoid leaking statements, and for those
that don't support it, it will simply delay the quoting to just
before the query is executed.

For now now only a subset of the API is migrated over, namely the
`where("title = ?", something)` form.

There is also `where("title = %s", something)` that I'm afraid won't
be able to be fixed and probably should be deprecated.

As well as `where("title = :title")` which I think is doable in a
followup.
  • Loading branch information
byroot committed Feb 20, 2024
1 parent 05eb098 commit 1a144e3
Show file tree
Hide file tree
Showing 15 changed files with 109 additions and 56 deletions.
Expand Up @@ -141,6 +141,11 @@ def type_cast(value) # :nodoc:
encode_array(value)
when Range
encode_range(value)
when Rational
value.to_f
when ActiveSupport::Duration
warn_quote_duration_deprecated
value.to_i
else
super
end
Expand Down
Expand Up @@ -63,14 +63,17 @@ def quote_default_expression(value, column) # :nodoc:

def type_cast(value) # :nodoc:
case value
when BigDecimal
when BigDecimal, Rational
value.to_f
when String
if value.encoding == Encoding::ASCII_8BIT
super(value.encode(Encoding::UTF_8))
else
super
end
when ActiveSupport::Duration
warn_quote_duration_deprecated
value.to_i
else
super
end
Expand Down
32 changes: 30 additions & 2 deletions activerecord/lib/active_record/relation/query_methods.rb
Expand Up @@ -1512,9 +1512,17 @@ def build_subquery(subquery_alias, select_value) # :nodoc:
def build_where_clause(opts, rest = []) # :nodoc:
opts = sanitize_forbidden_attributes(opts)

if opts.is_a?(Array)
opts, *rest = opts
end

case opts
when String, Array
parts = [klass.sanitize_sql(rest.empty? ? opts : [opts, *rest])]
when String
if opts.include?("?")
parts = [build_bound_sql_literal(opts, rest)]
else
parts = [klass.sanitize_sql(rest.empty? ? opts : [opts, *rest])]
end
when Hash
opts = opts.transform_keys do |key|
if key.is_a?(Array)
Expand Down Expand Up @@ -1550,6 +1558,26 @@ def async
spawn.async!
end

def build_bound_sql_literal(statement, values)
bound_values = values.map do |value|
if ActiveRecord::Relation === value
Arel.sql(value.to_sql)
elsif value.respond_to?(:map) && !value.acts_like?(:string)
values = value.map { |v| v.respond_to?(:id_for_database) ? v.id_for_database : v }
values.empty? ? nil : values
else
value = value.id_for_database if value.respond_to?(:id_for_database)
value
end
end

begin
Arel::Nodes::BoundSqlLiteral.new("(#{statement})", bound_values, nil)
rescue Arel::BindError => error
raise ActiveRecord::PreparedStatementInvalid, error.message
end
end

def lookup_table_klass_from_join_dependencies(table_name)
each_join_dependencies do |join|
return join.base_klass if table_name == join.table_name
Expand Down
12 changes: 8 additions & 4 deletions activerecord/lib/arel/nodes/bound_sql_literal.rb
Expand Up @@ -6,13 +6,17 @@ class BoundSqlLiteral < NodeExpression
attr_reader :sql_with_placeholders, :positional_binds, :named_binds

def initialize(sql_with_placeholders, positional_binds, named_binds)
if !positional_binds.empty? && !named_binds.empty?
raise BindError.new("cannot mix positional and named binds", sql_with_placeholders)
elsif !positional_binds.empty?
has_positional = !(positional_binds.nil? || positional_binds.empty?)
has_named = !(named_binds.nil? || named_binds.empty?)

if has_positional
if has_named
raise BindError.new("cannot mix positional and named binds", sql_with_placeholders)
end
if positional_binds.size != (expected = sql_with_placeholders.count("?"))
raise BindError.new("wrong number of bind variables (#{positional_binds.size} for #{expected})", sql_with_placeholders)
end
elsif !named_binds.empty?
elsif has_named
tokens_in_string = sql_with_placeholders.scan(/:(?<!::)([a-zA-Z]\w*)/).flatten.map(&:to_sym).uniq
tokens_in_hash = named_binds.keys.map(&:to_sym).uniq

Expand Down
3 changes: 2 additions & 1 deletion activerecord/lib/arel/update_manager.rb
Expand Up @@ -16,7 +16,8 @@ def table(table)
end

def set(values)
if String === values
case values
when String, Nodes::BoundSqlLiteral
@ast.values = [values]
else
@ast.values = values.map { |column, value|
Expand Down
Expand Up @@ -49,7 +49,7 @@ def test_create_null_bytes
end

def test_where_with_string_for_string_column_using_bind_parameters
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
end

def test_where_with_integer_for_string_column_using_bind_parameters
Expand Down Expand Up @@ -79,14 +79,16 @@ def test_where_with_duration_for_string_column_using_bind_parameters
end

private
def assert_quoted_as(expected, value)
def assert_quoted_as(expected, value, match: 0)
relation = Post.where("title = ?", value)
assert_equal(
%{SELECT `posts`.* FROM `posts` WHERE (title = #{expected})},
relation.to_sql,
)
assert_nothing_raised do # Make sure SQL is valid
relation.to_a
if match == 0
assert_empty relation.to_a
else
assert_equal match, relation.count
end
end
end
Expand Down
26 changes: 11 additions & 15 deletions activerecord/test/cases/adapters/postgresql/bind_parameter_test.rb
Expand Up @@ -10,50 +10,46 @@ class BindParameterTest < ActiveRecord::PostgreSQLTestCase
fixtures :posts

def test_where_with_string_for_string_column_using_bind_parameters
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
end

def test_where_with_integer_for_string_column_using_bind_parameters
assert_quoted_as "0", 0, valid: false
assert_quoted_as "0", 0
end

def test_where_with_float_for_string_column_using_bind_parameters
assert_quoted_as "0.0", 0.0, valid: false
assert_quoted_as "0.0", 0.0
end

def test_where_with_boolean_for_string_column_using_bind_parameters
assert_quoted_as "FALSE", false, valid: false
assert_quoted_as "FALSE", false
end

def test_where_with_decimal_for_string_column_using_bind_parameters
assert_quoted_as "0.0", BigDecimal(0), valid: false
assert_quoted_as "0.0", BigDecimal(0)
end

def test_where_with_rational_for_string_column_using_bind_parameters
assert_quoted_as "0/1", Rational(0), valid: false
assert_quoted_as "0/1", Rational(0)
end

def test_where_with_duration_for_string_column_using_bind_parameters
assert_deprecated(ActiveRecord.deprecator) do
assert_quoted_as "0", 0.seconds, valid: false
assert_quoted_as "0", 0.seconds
end
end

private
def assert_quoted_as(expected, value, valid: true)
def assert_quoted_as(expected, value, match: 0)
relation = Post.where("title = ?", value)
assert_equal(
%{SELECT "posts".* FROM "posts" WHERE (title = #{expected})},
relation.to_sql,
)
if valid
assert_nothing_raised do # Make sure SQL is valid
relation.to_a
end
if match == 0
assert_empty relation.to_a
else
assert_raises ActiveRecord::StatementInvalid do
relation.to_a
end
assert_equal match, relation.count
end
end
end
Expand Down
10 changes: 6 additions & 4 deletions activerecord/test/cases/adapters/sqlite3/bind_parameter_test.rb
Expand Up @@ -10,7 +10,7 @@ class BindParameterTest < ActiveRecord::SQLite3TestCase
fixtures :posts

def test_where_with_string_for_string_column_using_bind_parameters
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
end

def test_where_with_integer_for_string_column_using_bind_parameters
Expand Down Expand Up @@ -40,14 +40,16 @@ def test_where_with_duration_for_string_column_using_bind_parameters
end

private
def assert_quoted_as(expected, value)
def assert_quoted_as(expected, value, match: 0)
relation = Post.where("title = ?", value)
assert_equal(
%{SELECT "posts".* FROM "posts" WHERE (title = #{expected})},
relation.to_sql,
)
assert_nothing_raised do # Make sure SQL is valid
relation.to_a
if match == 0
assert_empty relation.to_a
else
assert_equal match, relation.count
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion activerecord/test/cases/bind_parameter_test.rb
Expand Up @@ -102,7 +102,7 @@ def test_statement_cache_with_sql_string_literal

topics = Topic.where("topics.id = ?", 1)
assert_equal [1], topics.map(&:id)
assert_not_includes statement_cache, to_sql_key(topics.arel)
assert_includes statement_cache, to_sql_key(topics.arel)
end

def test_too_many_binds
Expand Down
2 changes: 1 addition & 1 deletion activerecord/test/cases/finder_test.rb
Expand Up @@ -200,7 +200,7 @@ def test_exists
assert_equal false, Topic.exists?(9999999999999999999999999999999)
assert_equal false, Topic.exists?(Topic.new.id)

assert_raise(NoMethodError) { Topic.exists?([1, 2]) }
assert_raise(ArgumentError) { Topic.exists?([1, 2]) }
end

def test_exists_with_scope
Expand Down
6 changes: 4 additions & 2 deletions activerecord/test/cases/quoting_test.rb
Expand Up @@ -239,8 +239,10 @@ def test_type_cast_unknown_should_raise_error
assert_raise(TypeError) { @conn.type_cast(obj) }
end

def test_type_cast_duration_should_raise_error
assert_raise(TypeError) { @conn.type_cast(1.hour) }
def test_type_cast_duration_should_raise_deprecation
assert_deprecated(ActiveRecord.deprecator) do
@conn.type_cast(1.hour)
end
end
end

Expand Down
32 changes: 16 additions & 16 deletions activerecord/test/cases/relation/merging_test.rb
Expand Up @@ -212,25 +212,25 @@ def test_merge_doesnt_duplicate_same_clauses

only_david = Author.where("#{author_id} IN (?)", david)

if current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \('1'\)\)\z/) do
assert_equal [david], only_david.merge(only_david)
end

assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \('1'\)\)\z/) do
assert_deprecated(ActiveRecord.deprecator) do
assert_equal [david], only_david.merge(only_david, rewhere: true)
end
matcher = if Author.connection.prepared_statements
if current_adapter?(:PostgreSQLAdapter)
/WHERE \(#{Regexp.escape(author_id)} IN \(\$1\)\)\z/
else
/WHERE \(#{Regexp.escape(author_id)} IN \(\?\)\)\z/
end
elsif current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
/WHERE \(#{Regexp.escape(author_id)} IN \('1'\)\)\z/
else
assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \(1\)\)\z/) do
assert_equal [david], only_david.merge(only_david)
end
/WHERE \(#{Regexp.escape(author_id)} IN \(1\)\)\z/
end

assert_queries_match(matcher) do
assert_equal [david], only_david.merge(only_david)
end

assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \(1\)\)\z/) do
assert_deprecated(ActiveRecord.deprecator) do
assert_equal [david], only_david.merge(only_david, rewhere: true)
end
assert_queries_match(matcher) do
assert_deprecated(ActiveRecord.deprecator) do
assert_equal [david], only_david.merge(only_david, rewhere: true)
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion activerecord/test/cases/relation_test.rb
Expand Up @@ -204,7 +204,7 @@ def self.sanitize_sql(args)

relation = Relation.new(klass)
relation.merge!(where: ["foo = ?", "bar"])
assert_equal Relation::WhereClause.new(["foo = bar"]), relation.where_clause
assert_equal Relation::WhereClause.new([Arel.sql("(foo = ?)", "bar")]), relation.where_clause
end

def test_merging_readonly_false
Expand Down
4 changes: 2 additions & 2 deletions activerecord/test/cases/relations_test.rb
Expand Up @@ -476,9 +476,9 @@ def test_finding_with_complex_order
def test_finding_with_sanitized_order
query = Tag.order([Arel.sql("field(id, ?)"), [1, 3, 2]]).to_sql
if current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
assert_match(/field\(id, '1','3','2'\)/, query)
assert_match(/field\(id, '1',\s*'3',\s*'2'\)/, query)
else
assert_match(/field\(id, 1,3,2\)/, query)
assert_match(/field\(id, 1,\s*3,\s*2\)/, query)
end

query = Tag.order([Arel.sql("field(id, ?)"), []]).to_sql
Expand Down
14 changes: 12 additions & 2 deletions activerecord/test/cases/sanitize_test.rb
Expand Up @@ -91,11 +91,21 @@ def self.search_as_method(term)
}
end

assert_queries_match(/LIKE '20!% !_reduction!_!!'/) do
query = if searchable_post.connection.prepared_statements
if current_adapter?(:PostgreSQLAdapter)
/title LIKE \$1/
else
/title LIKE \?/
end
else
/LIKE '20!% !_reduction!_!!'/
end

assert_queries_match(query) do
searchable_post.search_as_method("20% _reduction_!").to_a
end

assert_queries_match(/LIKE '20!% !_reduction!_!!'/) do
assert_queries_match(query) do
searchable_post.search_as_scope("20% _reduction_!").to_a
end
end
Expand Down

0 comments on commit 1a144e3

Please sign in to comment.