diff --git a/activerecord/lib/active_record/associations/join_dependency.rb b/activerecord/lib/active_record/associations/join_dependency.rb index b852656c854c6..871ca1c91db84 100644 --- a/activerecord/lib/active_record/associations/join_dependency.rb +++ b/activerecord/lib/active_record/associations/join_dependency.rb @@ -69,7 +69,7 @@ def graft(*associations) join_assocs = join_associations associations.reject { |association| - join_assocs.detect { |a| association == a } + join_assocs.detect { |a| node_cmp association, a } }.each { |association| join_node = find_parent_node(association.parent) || @join_root type = association.join_type @@ -122,14 +122,19 @@ def instantiate(result_set) private def find_parent_node(parent) - @join_root.find { |join_part| - case parent - when JoinBase - parent.base_klass == join_part.base_klass - else - parent == join_part - end - } + @join_root.find { |join_part| node_cmp parent, join_part } + end + + def node_cmp(parent, join_part) + return unless parent.class == join_part.class + + case parent + when JoinBase + parent.base_klass == join_part.base_klass + else + parent.reflection == join_part.reflection && + node_cmp(parent.parent, join_part.parent) + end end def join_base @@ -182,7 +187,7 @@ def find_or_build_scalar(reflection, parent, join_type) def find_join_association(reflection, parent) join_associations.detect { |j| - j.reflection == reflection && j.parent == parent + j.reflection == reflection && node_cmp(j.parent, parent) } end diff --git a/activerecord/lib/active_record/associations/join_dependency/join_association.rb b/activerecord/lib/active_record/associations/join_dependency/join_association.rb index e2ac892e71e9d..3f9afa8992959 100644 --- a/activerecord/lib/active_record/associations/join_dependency/join_association.rb +++ b/activerecord/lib/active_record/associations/join_dependency/join_association.rb @@ -33,12 +33,6 @@ def initialize(reflection, index, parent, join_type, alias_tracker) def parent_table_name; parent.table_name; end alias :alias_suffix :parent_table_name - def ==(other) - other.class == self.class && - other.reflection == reflection && - other.parent == parent - end - def join_constraints joins = [] tables = @tables.dup diff --git a/activerecord/lib/active_record/associations/join_dependency/join_base.rb b/activerecord/lib/active_record/associations/join_dependency/join_base.rb index c90ae80e4a3ff..d3bc3dd1ad7a0 100644 --- a/activerecord/lib/active_record/associations/join_dependency/join_base.rb +++ b/activerecord/lib/active_record/associations/join_dependency/join_base.rb @@ -8,11 +8,6 @@ def initialize(klass) super(klass, nil) end - def ==(other) - other.class == self.class && - other.base_klass == base_klass - end - def aliased_prefix "t0" end diff --git a/activerecord/lib/active_record/associations/join_dependency/join_part.rb b/activerecord/lib/active_record/associations/join_dependency/join_part.rb index 3e9bdbfbab5e5..b8ce8ff46d6a9 100644 --- a/activerecord/lib/active_record/associations/join_dependency/join_part.rb +++ b/activerecord/lib/active_record/associations/join_dependency/join_part.rb @@ -48,10 +48,6 @@ def aliased_table Arel::Nodes::TableAlias.new table, aliased_table_name end - def ==(other) - raise NotImplementedError - end - # An Arel::Table for the active_record def table raise NotImplementedError