diff --git a/lib/baby_squeel/active_record.rb b/lib/baby_squeel/active_record.rb index 4f71e53..3417c24 100644 --- a/lib/baby_squeel/active_record.rb +++ b/lib/baby_squeel/active_record.rb @@ -13,7 +13,8 @@ def sifter(name, &block) module QueryMethods # Constructs Arel for ActiveRecord::Base#joins using the DSL. def joining(&block) - joins DSL.evaluate(unscoped, &block) + arel, binds = DSL.evaluate_joins(unscoped, &block) + joins(arel).tap { |s| s.bind_values += binds } end # Constructs Arel for ActiveRecord::Base#select using the DSL. diff --git a/lib/baby_squeel/dsl.rb b/lib/baby_squeel/dsl.rb index a0b41f0..e23c766 100644 --- a/lib/baby_squeel/dsl.rb +++ b/lib/baby_squeel/dsl.rb @@ -4,16 +4,32 @@ module BabySqueel class DSL < Table - # Evaluates a block in the context of a new DSL instance. - def self.evaluate(scope, &block) - new(scope).evaluate(&block) - end + class << self + # Evaluates a block and unwraps the nodes + def evaluate(scope, &block) + Nodes.unwrap evaluate!(scope, &block) + end + + # Evaluates a block in the context of a DSL instance + def evaluate!(scope, &block) + new(scope).evaluate(&block) + end + + # Evaluates a block specifically for a join. In this + # case, we'll return an array of Arel join nodes and + # a list of bind parameters. + def evaluate_joins(scope, &block) + dependency = evaluate!(scope, &block)._arel + join_arel = Nodes.unwrap(dependency._arel) + [join_arel, dependency.bind_values] + end - # Evaluates a block in the context of a new DSL instance - # and passes all arguments to the block. - def self.evaluate_sifter(scope, *args, &block) - evaluate scope do |root| - root.instance_exec(*args, &block) + # Evaluates a block in the context of a new DSL instance + # and passes all arguments to the block. + def evaluate_sifter(scope, *args, &block) + evaluate scope do |root| + root.instance_exec(*args, &block) + end end end @@ -39,16 +55,16 @@ def sql(value) # Quotes a string and marks it as SQL def quoted(value) - sql @scope.connection.quote(value) + sql _scope.connection.quote(value) end # Evaluates a DSL block. If arity is given, this method # `yield` itself, rather than `instance_eval`. def evaluate(&block) if block.arity.zero? - Nodes.unwrap instance_eval(&block) + instance_eval(&block) else - Nodes.unwrap yield(self) + yield(self) end end diff --git a/lib/baby_squeel/join_dependency.rb b/lib/baby_squeel/join_dependency.rb index 869c93c..0c4ca2f 100644 --- a/lib/baby_squeel/join_dependency.rb +++ b/lib/baby_squeel/join_dependency.rb @@ -1,7 +1,10 @@ module BabySqueel class JoinDependency - def initialize(scope, associations = []) - @scope = scope + attr_reader :bind_values + + def initialize(table, associations = []) + @table = table + @bind_values = [] @associations = associations end @@ -10,9 +13,13 @@ def initialize(scope, associations = []) # # Each association is built individually so that the correct # Arel join node will be used for each individual association. - def constraints - @associations.each.with_index.inject([]) do |joins, (assoc, i)| - inject @associations[0..i], joins, assoc._join + def _arel + if @table._on + [@table._join.new(@table._table, @table._on)] + else + @associations.each.with_index.inject([]) do |joins, (assoc, i)| + inject @associations[0..i], joins, assoc._join + end end end @@ -25,7 +32,12 @@ def inject(associations, theirs, join_node) end def build(names, join_node) - @scope.joins(names).join_sources.map do |join| + relation = @table._scope.joins(names) + + @bind_values = relation.arel.bind_values + @bind_values += relation.bind_values + + relation.join_sources.map do |join| join_node.new(join.left, join.right) end end diff --git a/lib/baby_squeel/nodes.rb b/lib/baby_squeel/nodes.rb index 82d9ed3..bde91b8 100644 --- a/lib/baby_squeel/nodes.rb +++ b/lib/baby_squeel/nodes.rb @@ -17,7 +17,7 @@ def wrap(arel) # ActiveRecord. def unwrap(node) if node.respond_to? :_arel - node._arel + unwrap node._arel elsif node.is_a? Array node.map { |n| unwrap(n) } else @@ -94,7 +94,7 @@ def in(rel) end def _arel - parent_arel = @parent._arel + parent_arel = @parent._arel._arel if parent_arel && parent_arel.last parent_arel.last.left[@name] diff --git a/lib/baby_squeel/table.rb b/lib/baby_squeel/table.rb index eb88f9f..64e4c15 100644 --- a/lib/baby_squeel/table.rb +++ b/lib/baby_squeel/table.rb @@ -14,10 +14,10 @@ def initialize(model_name, name) end class Table - attr_accessor :_on, :_join, :_table + attr_accessor :_scope, :_on, :_join, :_table def initialize(scope) - @scope = scope + @_scope = scope @_table = scope.arel_table @_join = Arel::Nodes::InnerJoin end @@ -30,15 +30,15 @@ def [](key) # Constructs a new BabySqueel::Association. Raises # an exception if the association is not found. def association(name) - if reflection = @scope.reflect_on_association(name) + if reflection = _scope.reflect_on_association(name) Association.new(self, reflection) else - raise AssociationNotFoundError.new(@scope.model_name, name) + raise AssociationNotFoundError.new(_scope.model_name, name) end end def sift(sifter_name, *args) - Nodes.wrap @scope.public_send("sift_#{sifter_name}", *args) + Nodes.wrap _scope.public_send("sift_#{sifter_name}", *args) end # Alias a table. This is only possible when joining @@ -89,19 +89,15 @@ def on!(node) # 2. Resolve the assocition's join clauses using ActiveRecord. # def _arel(associations = []) - if _on - _join.new(_table, _on) - else - JoinDependency.new(@scope, associations).constraints - end + JoinDependency.new(self, associations) end private def resolve(name) - if @scope.column_names.include?(name.to_s) + if _scope.column_names.include?(name.to_s) self[name] - elsif @scope.reflect_on_association(name) + elsif _scope.reflect_on_association(name) association(name) end end @@ -114,7 +110,7 @@ def method_missing(name, *args, &block) return super if !args.empty? || block_given? resolve(name) || begin - raise NotFoundError.new(@scope.model_name, name) + raise NotFoundError.new(_scope.model_name, name) end end end diff --git a/spec/baby_squeel/active_record/query_methods/joining_spec.rb b/spec/baby_squeel/active_record/query_methods/joining_spec.rb index e2c4e9b..6353696 100644 --- a/spec/baby_squeel/active_record/query_methods/joining_spec.rb +++ b/spec/baby_squeel/active_record/query_methods/joining_spec.rb @@ -86,6 +86,16 @@ EOSQL end + it 'merges bind values' do + relation = Post.joining { ugly_author_comments } + + expect(relation).to produce_sql(<<-EOSQL) + SELECT "posts".* FROM "posts" + INNER JOIN "authors" ON "authors"."id" = "posts"."author_id" AND "authors"."ugly" = 't' + INNER JOIN "comments" ON "comments"."author_id" = "authors"."id" + EOSQL + end + context 'with complex conditions' do it 'inner joins' do relation = Post.joining { diff --git a/spec/support/models.rb b/spec/support/models.rb index 193c72b..eab31b8 100644 --- a/spec/support/models.rb +++ b/spec/support/models.rb @@ -3,10 +3,17 @@ class Author < ActiveRecord::Base has_many :comments end +class UglyAuthor < Author + default_scope { where ugly: true } +end + class Post < ActiveRecord::Base has_many :comments belongs_to :author has_many :author_comments, through: :author, source: :comments + + belongs_to :ugly_author, foreign_key: :author_id + has_many :ugly_author_comments, through: :ugly_author, source: :comments end class Comment < ActiveRecord::Base diff --git a/spec/support/schema.rb b/spec/support/schema.rb index ec7a411..56bdb8c 100644 --- a/spec/support/schema.rb +++ b/spec/support/schema.rb @@ -8,6 +8,7 @@ ActiveRecord::Schema.define do create_table :authors, force: true do |t| t.string :name + t.boolean :ugly t.timestamps null: false end