Permalink
Browse files

Association validation does not belong in a before_save callback: mov…

…e it into a validation method. Restores the expected model.valid? == model.save. Add tests for cancelling save by returning false from a before_save callback. Remove assumption that before_destroy's return value indicates whether the record was destroyed.

git-svn-id: http://svn-commit.rubyonrails.org/rails/trunk@2434 5ecf4fe2-1ee6-0310-87b1-e25e094e27de
  • Loading branch information...
1 parent c6d8a1f commit e19bd169fac08052e7f0ae4ae4e5ac368629d31f @jeremy jeremy committed Oct 2, 2005
@@ -308,8 +308,8 @@ def has_many(association_id, options = {})
end
add_multiple_associated_save_callbacks(association_name)
- add_association_callbacks(association_name, options)
-
+ add_association_callbacks(association_name, options)
+
collection_accessor_methods(association_name, association_class_name, association_class_primary_key_name, options, HasManyAssociation)
# deprecated api
@@ -686,39 +686,39 @@ def require_association_class(class_name)
end
def add_multiple_associated_save_callbacks(association_name)
- module_eval do
- before_save <<-end_eval
- @new_record_before_save = new_record?
- association = instance_variable_get("@#{association_name}")
- if association.respond_to?(:loaded?)
- if new_record?
- records_to_save = association
- else
- records_to_save = association.select{ |record| record.new_record? }
- end
- records_to_save.all? { |record| record.valid? }
+ method_name = "validate_associated_records_for_#{association_name}".to_sym
+ define_method(method_name) do
+ @new_record_before_save = new_record?
+ association = instance_variable_get("@#{association_name}")
+ if association.respond_to?(:loaded?)
+ if new_record?
+ association
+ else
+ association.select { |record| record.new_record? }
+ end.each do |record|
+ errors.add "#{association_name}" unless record.valid?
end
- end_eval
+ end
end
- module_eval do
- after_callback = <<-end_eval
- association = instance_variable_get("@#{association_name}")
- if association.respond_to?(:loaded?)
- if @new_record_before_save
- records_to_save = association
- else
- records_to_save = association.select{ |record| record.new_record? }
- end
- records_to_save.each{ |record| association.send(:insert_record, record) }
- association.send(:construct_sql) # reconstruct the SQL queries now that we know the owner's id
+ validate method_name
+
+ after_callback = <<-end_eval
+ association = instance_variable_get("@#{association_name}")
+ if association.respond_to?(:loaded?)
+ if @new_record_before_save
+ records_to_save = association
+ else
+ records_to_save = association.select { |record| record.new_record? }
end
- end_eval
+ records_to_save.each { |record| association.send(:insert_record, record) }
+ association.send(:construct_sql) # reconstruct the SQL queries now that we know the owner's id
+ end
+ end_eval
- # Doesn't use after_save as that would save associations added in after_create/after_update twice
- after_create(after_callback)
- after_update(after_callback)
- end
+ # Doesn't use after_save as that would save associations added in after_create/after_update twice
+ after_create(after_callback)
+ after_update(after_callback)
end
def association_constructor_method(constructor, association_name, association_class_name, association_class_primary_key_name, options, association_proxy_class)
@@ -894,16 +894,15 @@ def association_join(reflection)
end
def add_association_callbacks(association_name, options)
- callbacks = %w(before_add after_add before_remove after_remove)
- callbacks.each do |callback_name|
- full_callback_name = "#{callback_name.to_s}_for_#{association_name.to_s}"
- defined_callbacks = options[callback_name.to_sym]
- if options.has_key?(callback_name.to_sym)
- callback_array = defined_callbacks.kind_of?(Array) ? defined_callbacks : [defined_callbacks]
- class_inheritable_reader full_callback_name.to_sym
- write_inheritable_array(full_callback_name.to_sym, callback_array)
- end
- end
+ callbacks = %w(before_add after_add before_remove after_remove)
+ callbacks.each do |callback_name|
+ full_callback_name = "#{callback_name.to_s}_for_#{association_name.to_s}"
+ defined_callbacks = options[callback_name.to_sym]
+ if options.has_key?(callback_name.to_sym)
+ class_inheritable_reader full_callback_name.to_sym
+ write_inheritable_array(full_callback_name.to_sym, [defined_callbacks].flatten)
+ end
+ end
end
def extract_record(schema_abbreviations, table_name, row)
@@ -915,5 +914,6 @@ def extract_record(schema_abbreviations, table_name, row)
return record
end
end
+
end
end
@@ -1206,7 +1206,6 @@ def frozen?
private
def create_or_update
if new_record? then create else update end
- true
end
# Updates the associated record with values matching those of the instant attributes.
@@ -657,7 +657,12 @@ def validation_method(on)
# The validation process on save can be skipped by passing false. The regular Base#save method is
# replaced with this when the validations module is mixed in, which it is by default.
def save_with_validation(perform_validation = true)
- if perform_validation && valid? || !perform_validation then save_without_validation else false end
+ if perform_validation && valid? || !perform_validation
+ save_without_validation
+ true
+ else
+ false
+ end
end
# Attempts to save the record just like Base.save but will raise a RecordInvalid exception instead of returning false
@@ -694,8 +699,7 @@ def valid?
# Returns the Errors object that holds all information about attribute error messages.
def errors
- @errors = Errors.new(self) if @errors.nil?
- @errors
+ @errors ||= Errors.new(self)
end
protected
@@ -425,6 +425,7 @@ def test_invalid_adding
firm = Firm.find(1)
assert !(firm.clients_of_firm << c = Client.new)
assert c.new_record?
+ assert !firm.valid?
assert !firm.save
assert c.new_record?
end
@@ -436,7 +437,7 @@ def test_invalid_adding_before_save
new_firm.clients_of_firm.concat([c = Client.new, Client.new("name" => "Apple")])
assert c.new_record?
assert !c.valid?
- assert new_firm.valid?
+ assert !new_firm.valid?
assert !new_firm.save
assert c.new_record?
assert new_firm.new_record?
@@ -73,20 +73,39 @@ def on_after_save
class ImmutableDeveloper < ActiveRecord::Base
set_table_name 'developers'
- before_destroy :cancel_destroy
-
- private
-
- def cancel_destroy
- return false
+ validates_inclusion_of :salary, :in => 50000..200000
+
+ before_save :cancel
+ before_destroy :cancel
+
+ def cancelled?
+ @cancelled == true
end
+
+ private
+ def cancel
+ @cancelled = true
+ false
+ end
end
class ImmutableMethodDeveloper < ActiveRecord::Base
set_table_name 'developers'
- def before_destroy
- return false
+ validates_inclusion_of :salary, :in => 50000..200000
+
+ def cancelled?
+ @cancelled == true
+ end
+
+ def before_save
+ @cancelled = true
+ false
+ end
+
+ def before_destroy
+ @cancelled = true
+ false
end
end
@@ -301,24 +320,43 @@ def test_delete
[ :after_initialize, :block ],
], david.history
end
-
+
+ def test_before_save_returning_false
+ david = ImmutableDeveloper.find(1)
+ assert david.valid?
+ assert david.save
+ assert david.cancelled?
+
+ david = ImmutableDeveloper.find(1)
+ david.salary = 10_000_000
+ assert !david.valid?
+ assert !david.save
+ assert !david.cancelled?
+
+ david = ImmutableMethodDeveloper.find(1)
+ assert david.valid?
+ assert david.save
+ assert david.cancelled?
+
+ david = ImmutableMethodDeveloper.find(1)
+ david.salary = 10_000_000
+ assert !david.valid?
+ assert !david.save
+ assert !david.cancelled?
+ end
+
def test_before_destroy_returning_false
david = ImmutableDeveloper.find(1)
- devs = ImmutableDeveloper.find(:all).size
- assert !david.destroy
- # cancel_destroy returns false so the destruction should
- # be cancelled
- assert_equal ImmutableDeveloper.find(:all).size, devs
-
+ david.destroy
+ assert david.cancelled?
+ assert_not_nil ImmutableDeveloper.find_by_id(1)
+
david = ImmutableMethodDeveloper.find(1)
- devs = ImmutableMethodDeveloper.find(:all).size
- assert !david.destroy
- # before_destroy returns false so the destruction should
- # be cancelled
- assert_equal ImmutableMethodDeveloper.find(:all).size, devs
+ david.destroy
+ assert david.cancelled?
+ assert_not_nil ImmutableMethodDeveloper.find_by_id(1)
end
-
-
+
def test_zzz_callback_returning_false # must be run last since we modify CallbackDeveloper
david = CallbackDeveloper.find(1)

0 comments on commit e19bd16

Please sign in to comment.