Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eager load composite primary keys models #48490

Merged
merged 2 commits into from Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 10 additions & 8 deletions activerecord/lib/active_record/associations/join_dependency.rb
Expand Up @@ -253,22 +253,24 @@ def construct(ar_parent, parent, row, seen, model_cache, strict_loading_value)
end

if node.primary_key
key = aliases.column_alias(node, node.primary_key)
id = row[key]
keys = Array(node.primary_key).map { |column| aliases.column_alias(node, column) }
ids = keys.map { |key| row[key] }
else
key = aliases.column_alias(node, node.reflection.join_primary_key.to_s)
id = nil # Avoid id-based model caching.
keys = Array(node.reflection.join_primary_key).map { |column| aliases.column_alias(node, column.to_s) }
ids = keys.map { nil } # Avoid id-based model caching.
end

if row[key].nil?
if keys.any? { |key| row[key].nil? }
nil_association = ar_parent.association(node.reflection.name)
nil_association.loaded!
next
end

unless model = seen[ar_parent][node][id]
model = construct_model(ar_parent, node, row, model_cache, id, strict_loading_value)
seen[ar_parent][node][id] = model if id
ids.each do |id|
unless model = seen[ar_parent][node][id]
model = construct_model(ar_parent, node, row, model_cache, id, strict_loading_value)
seen[ar_parent][node][id] = model if id
end
end

construct(model, node, row, seen, model_cache, strict_loading_value)
Expand Down
Expand Up @@ -1355,18 +1355,24 @@ def columns_for_distinct(columns, orders) # :nodoc:
end

def distinct_relation_for_primary_key(relation) # :nodoc:
primary_key_columns = Array(relation.primary_key).map do |column|
visitor.compile(relation.table[column])
end

values = columns_for_distinct(
visitor.compile(relation.table[relation.primary_key]),
primary_key_columns,
relation.order_values
)

limited = relation.reselect(values).distinct!
limited_ids = select_rows(limited.arel, "SQL").map(&:last)
limited_ids = select_rows(limited.arel, "SQL").map do |results|
results.last(Array(relation.primary_key).length) # ignores order values for MySQL and Postgres
end

if limited_ids.empty?
relation.none!
else
relation.where!(relation.primary_key => limited_ids)
relation.where!(**Array(relation.primary_key).zip(limited_ids.transpose).to_h)
end

relation.limit_value = relation.offset_value = nil
Expand Down
19 changes: 19 additions & 0 deletions activerecord/test/cases/associations/eager_test.rb
Expand Up @@ -34,6 +34,7 @@
require "models/matey"
require "models/parrot"
require "models/sharded"
require "models/cpk"

class EagerLoadingTooManyIdsTest < ActiveRecord::TestCase
fixtures :citations
Expand Down Expand Up @@ -1713,6 +1714,24 @@ def test_preloading_has_many_through_with_custom_scope
assert_equal(expected_tag_ids.sort, blog_post.tags.map(&:id).sort)
end

test "preloading belongs_to with cpk" do
order = Cpk::Order.create!(shop_id: 2)
order_agreement = Cpk::OrderAgreement.create!(order: order)
assert_equal order, Cpk::OrderAgreement.eager_load(:order).find_by(id: order_agreement.id).order
end

test "preloading has_many with cpk" do
order = Cpk::Order.create!(shop_id: 2)
order_agreement = Cpk::OrderAgreement.create!(order: order)
assert_equal [order_agreement], Cpk::Order.eager_load(:order_agreements).find_by(id: order.id).order_agreements
end

test "preloading has_one with cpk" do
order = Cpk::Order.create!(shop_id: 2)
book = Cpk::Book.create!(order: order, author_id: 1, number: 3)
assert_equal book, Cpk::Order.eager_load(:book).find_by(id: order.id).book
end

private
def find_all_ordered(klass, include = nil)
klass.order("#{klass.table_name}.#{klass.primary_key}").includes(include).to_a
Expand Down