Skip to content

Commit

Permalink
Added STI support to init and building associations
Browse files Browse the repository at this point in the history
Allows you to do BaseClass.new(:type => "SubClass") as well as
parent.children.build(:type => "SubClass") or parent.build_child
to initialize an STI subclass. Ensures that the class name is a
valid class and that it is in the ancestors of the super class
that the association is expecting.
  • Loading branch information
diminish7 authored and pixeltrix committed Nov 29, 2012
1 parent eba430a commit 89b5b31
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 1 deletion.
9 changes: 9 additions & 0 deletions activerecord/CHANGELOG.md
@@ -1,5 +1,14 @@
## Rails 4.0.0 (unreleased) ##

* Add STI support to init and building associations.
Allows you to do BaseClass.new(:type => "SubClass") as well as
parent.children.build(:type => "SubClass") or parent.build_child
to initialize an STI subclass. Ensures that the class name is a
valid class and that it is in the ancestors of the super class
that the association is expecting.

*Jason Rush*

* Observers was extracted from Active Record as `rails-observers` gem.

*Rafael Mendonça França*
Expand Down
1 change: 1 addition & 0 deletions activerecord/lib/active_record/base.rb
Expand Up @@ -13,6 +13,7 @@
require 'active_support/core_ext/kernel/singleton_class'
require 'active_support/core_ext/module/introspection'
require 'active_support/core_ext/object/duplicable'
require 'active_support/core_ext/class/subclasses'
require 'arel'
require 'active_record/errors'
require 'active_record/log_subscriber'
Expand Down
26 changes: 26 additions & 0 deletions activerecord/lib/active_record/inheritance.rb
Expand Up @@ -9,6 +9,19 @@ module Inheritance
end

module ClassMethods
# Determines if one of the attributes passed in is the inheritance column,
# and if the inheritance column is attr accessible, it initializes an
# instance of the given subclass instead of the base class
def new(*args, &block)
if (attrs = args.first).is_a?(Hash)
if subclass = subclass_from_attrs(attrs)
return subclass.new(*args, &block)
end
end
# Delegate to the original .new
super
end

# True if this isn't a concrete subclass needing a STI type condition.
def descends_from_active_record?
if self == Base
Expand Down Expand Up @@ -145,6 +158,19 @@ def type_condition(table = arel_table)

sti_column.in(sti_names)
end

# Detect the subclass from the inheritance column of attrs. If the inheritance column value
# is not self or a valid subclass, raises ActiveRecord::SubclassNotFound
# If this is a StrongParameters hash, and access to inheritance_column is not permitted,
# this will ignore the inheritance column and return nil
def subclass_from_attrs(attrs)
subclass_name = attrs.with_indifferent_access[inheritance_column]
return nil if subclass_name.blank? || subclass_name == self.name
unless subclass = subclasses.detect { |sub| sub.name == subclass_name }

This comment has been minimized.

Copy link
@ahacking

ahacking May 31, 2013

Any reason this isn't using descendants.detect to allow inheritance hierarchies greater than one level ?

raise ActiveRecord::SubclassNotFound.new("Invalid single-table inheritance type: #{subclass_name} is not a subclass of #{name}")
end
subclass
end
end

private
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/reflection.rb
Expand Up @@ -179,7 +179,7 @@ def initialize(*args)
@collection = [:has_many, :has_and_belongs_to_many].include?(macro)
end

# Returns a new, unsaved instance of the associated class. +options+ will
# Returns a new, unsaved instance of the associated class. +attributes+ will
# be passed to the class's constructor.
def build_association(attributes, &block)
klass.new(attributes, &block)
Expand Down
Expand Up @@ -109,6 +109,34 @@ def test_building_the_belonging_object
assert_equal apple.id, citibank.firm_id
end

def test_building_the_belonging_object_with_implicit_sti_base_class
account = Account.new
company = account.build_firm
assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
end

def test_building_the_belonging_object_with_explicit_sti_base_class
account = Account.new
company = account.build_firm(:type => "Company")
assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
end

def test_building_the_belonging_object_with_sti_subclass
account = Account.new
company = account.build_firm(:type => "Firm")
assert(company.kind_of?(Firm), "Expected #{company.class} to be a Firm")
end

def test_building_the_belonging_object_with_an_invalid_type
account = Account.new
assert_raise(ActiveRecord::SubclassNotFound) { account.build_firm(:type => "InvalidType") }
end

def test_building_the_belonging_object_with_an_unrelated_type
account = Account.new
assert_raise(ActiveRecord::SubclassNotFound) { account.build_firm(:type => "Account") }
end

def test_building_the_belonging_object_with_primary_key
client = Client.create(:name => "Primary key client")
apple = client.build_firm_with_primary_key("name" => "Apple")
Expand Down
28 changes: 28 additions & 0 deletions activerecord/test/cases/associations/has_many_associations_test.rb
Expand Up @@ -144,6 +144,34 @@ def test_create_from_association_with_nil_values_should_work
assert_equal 'defaulty', bulb.name
end

def test_building_the_associated_object_with_implicit_sti_base_class
firm = DependentFirm.new
company = firm.companies.build
assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
end

def test_building_the_associated_object_with_explicit_sti_base_class
firm = DependentFirm.new
company = firm.companies.build(:type => "Company")
assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
end

def test_building_the_associated_object_with_sti_subclass
firm = DependentFirm.new
company = firm.companies.build(:type => "Client")
assert(company.kind_of?(Client), "Expected #{company.class} to be a Client")
end

def test_building_the_associated_object_with_an_invalid_type
firm = DependentFirm.new
assert_raise(ActiveRecord::SubclassNotFound) { firm.companies.build(:type => "Invalid") }
end

def test_building_the_associated_object_with_an_unrelated_type
firm = DependentFirm.new
assert_raise(ActiveRecord::SubclassNotFound) { firm.companies.build(:type => "Account") }
end

def test_association_keys_bypass_attribute_protection
car = Car.create(:name => 'honda')

Expand Down
30 changes: 30 additions & 0 deletions activerecord/test/cases/associations/has_one_associations_test.rb
Expand Up @@ -6,6 +6,8 @@
require 'models/pirate'
require 'models/car'
require 'models/bulb'
require 'models/author'
require 'models/post'

class HasOneAssociationsTest < ActiveRecord::TestCase
self.use_transactional_fixtures = false unless supports_savepoints?
Expand Down Expand Up @@ -212,6 +214,34 @@ def test_build_association_dont_create_transaction
}
end

def test_building_the_associated_object_with_implicit_sti_base_class
firm = DependentFirm.new
company = firm.build_company
assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
end

def test_building_the_associated_object_with_explicit_sti_base_class
firm = DependentFirm.new
company = firm.build_company(:type => "Company")
assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
end

def test_building_the_associated_object_with_sti_subclass
firm = DependentFirm.new
company = firm.build_company(:type => "Client")
assert(company.kind_of?(Client), "Expected #{company.class} to be a Client")
end

def test_building_the_associated_object_with_an_invalid_type
firm = DependentFirm.new
assert_raise(ActiveRecord::SubclassNotFound) { firm.build_company(:type => "Invalid") }
end

def test_building_the_associated_object_with_an_unrelated_type
firm = DependentFirm.new
assert_raise(ActiveRecord::SubclassNotFound) { firm.build_company(:type => "Account") }
end

def test_build_and_create_should_not_happen_within_scope
pirate = pirates(:blackbeard)
scoped_count = pirate.association(:foo_bulb).scope.where_values.count
Expand Down
15 changes: 15 additions & 0 deletions activerecord/test/cases/forbidden_attributes_protection_test.rb
@@ -1,6 +1,7 @@
require 'cases/helper'
require 'active_support/core_ext/hash/indifferent_access'
require 'models/person'
require 'models/company'

class ProtectedParams < ActiveSupport::HashWithIndifferentAccess
attr_accessor :permitted
Expand Down Expand Up @@ -40,6 +41,20 @@ def test_permitted_attributes_can_be_used_for_mass_assignment
assert_equal 'm', person.gender
end

def test_forbidden_attributes_cannot_be_used_for_sti_inheritance_column
params = ProtectedParams.new(type: 'Client')
assert_raises(ActiveModel::ForbiddenAttributesError) do
Company.new(params)
end
end

def test_permitted_attributes_can_be_used_for_sti_inheritance_column
params = ProtectedParams.new(type: 'Client')
params.permit!
person = Company.new(params)
assert_equal person.class, Client
end

def test_regular_hash_should_still_be_used_for_mass_assignment
person = Person.new(first_name: 'Guille', gender: 'm')

Expand Down
23 changes: 23 additions & 0 deletions activerecord/test/cases/inheritance_test.rb
Expand Up @@ -156,6 +156,29 @@ def test_alt_inheritance_save
assert_kind_of Cabbage, savoy
end

def test_inheritance_new_with_default_class
company = Company.new
assert_equal company.class, Company
end

def test_inheritance_new_with_base_class
company = Company.new(:type => 'Company')
assert_equal company.class, Company
end

def test_inheritance_new_with_subclass
firm = Company.new(:type => 'Firm')
assert_equal firm.class, Firm
end

def test_new_with_invalid_type
assert_raise(ActiveRecord::SubclassNotFound) { Company.new(:type => 'InvalidType') }
end

def test_new_with_unrelated_type
assert_raise(ActiveRecord::SubclassNotFound) { Company.new(:type => 'Account') }
end

def test_inheritance_condition
assert_equal 10, Company.count
assert_equal 2, Firm.count
Expand Down
1 change: 1 addition & 0 deletions activerecord/test/models/author.rb
@@ -1,5 +1,6 @@
class Author < ActiveRecord::Base
has_many :posts
has_one :post
has_many :very_special_comments, :through => :posts
has_many :posts_with_comments, -> { includes(:comments) }, :class_name => "Post"
has_many :popular_grouped_posts, -> { includes(:comments).group("type").having("SUM(comments_count) > 1").select("type") }, :class_name => "Post"
Expand Down
1 change: 1 addition & 0 deletions activerecord/test/models/company.rb
Expand Up @@ -111,6 +111,7 @@ def log_after_remove(record)
class DependentFirm < Company
has_one :account, :foreign_key => "firm_id", :dependent => :nullify
has_many :companies, :foreign_key => 'client_of', :dependent => :nullify
has_one :company, :foreign_key => 'client_of', :dependent => :nullify
end

class RestrictedFirm < Company
Expand Down

2 comments on commit 89b5b31

@runlevel5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 this is nice for whom who's doing all sort of workaround in FactoryGirl just to have a factory of an STI class.

@francispotter
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this assume you're using class names in the database rather than overriding them as per https://gist.github.com/sumskyi/1381880?

Please sign in to comment.