diff --git a/activerecord/lib/active_record/attribute_methods/serialization.rb b/activerecord/lib/active_record/attribute_methods/serialization.rb index 04328aeb5e23e..c7c61128ea253 100644 --- a/activerecord/lib/active_record/attribute_methods/serialization.rb +++ b/activerecord/lib/active_record/attribute_methods/serialization.rb @@ -133,6 +133,7 @@ def serialize(attr_name, class_name_or_coder = Object, **options) raise ColumnNotSerializableError.new(attr_name, cast_type) end + cast_type = cast_type.subtype if Type::Serialized === cast_type Type::Serialized.new(cast_type, coder) end end diff --git a/activerecord/lib/active_record/attributes.rb b/activerecord/lib/active_record/attributes.rb index 8cc41a25b8fad..04a83d5dc4d57 100644 --- a/activerecord/lib/active_record/attributes.rb +++ b/activerecord/lib/active_record/attributes.rb @@ -208,21 +208,30 @@ module ClassMethods # tracking is performed. The methods +changed?+ and +changed_in_place?+ # will be called from ActiveModel::Dirty. See the documentation for those # methods in ActiveModel::Type::Value for more details. - def attribute(name, cast_type = nil, **options, &decorator) + def attribute(name, cast_type = nil, default: NO_DEFAULT_PROVIDED, **options, &block) name = name.to_s reload_schema_from_cache - prev_cast_type, prev_options, prev_decorator = attributes_to_define_after_schema_loads[name] + case cast_type + when Symbol + type = cast_type + cast_type = -> _ { Type.lookup(type, **options, adapter: Type.adapter_name_from(self)) } + when nil + if (prev_cast_type, prev_default = attributes_to_define_after_schema_loads[name]) + default = prev_default if default == NO_DEFAULT_PROVIDED - unless cast_type && prev_cast_type - cast_type ||= prev_cast_type - options = prev_options || options if options.empty? - decorator ||= prev_decorator + cast_type = if block_given? + -> subtype { yield Proc === prev_cast_type ? prev_cast_type[subtype] : prev_cast_type } + else + prev_cast_type + end + else + cast_type = block || -> subtype { subtype } + end end - self.attributes_to_define_after_schema_loads = attributes_to_define_after_schema_loads.merge( - name => [cast_type, options, decorator] - ) + self.attributes_to_define_after_schema_loads = + attributes_to_define_after_schema_loads.merge(name => [cast_type, default]) end # This is the low level API which sits beneath +attribute+. It only @@ -255,16 +264,9 @@ def define_attribute( def load_schema! # :nodoc: super - attributes_to_define_after_schema_loads.each do |name, (type, options, decorator)| - if type.is_a?(Symbol) - type = ActiveRecord::Type.lookup(type, **options.except(:default), adapter: ActiveRecord::Type.adapter_name_from(self)) - elsif type.nil? - type = type_for_attribute(name) - end - - type = decorator[type] if decorator - - define_attribute(name, type, **options.slice(:default)) + attributes_to_define_after_schema_loads.each do |name, (cast_type, default)| + cast_type = cast_type[type_for_attribute(name)] if Proc === cast_type + define_attribute(name, cast_type, default: default) end end diff --git a/activerecord/lib/active_record/enum.rb b/activerecord/lib/active_record/enum.rb index 2e619b1b5f600..ca9c759f1930f 100644 --- a/activerecord/lib/active_record/enum.rb +++ b/activerecord/lib/active_record/enum.rb @@ -153,8 +153,10 @@ def assert_valid_value(value) end end + attr_reader :subtype + private - attr_reader :name, :mapping, :subtype + attr_reader :name, :mapping end def enum(definitions) @@ -182,6 +184,7 @@ def enum(definitions) attr = attribute_alias?(name) ? attribute_alias(name) : name attribute(attr, **default) do |subtype| + subtype = subtype.subtype if EnumType === subtype EnumType.new(attr, enum_values, subtype) end diff --git a/activerecord/test/cases/serialized_attribute_test.rb b/activerecord/test/cases/serialized_attribute_test.rb index d9f01ca9f3d8a..a750e572ecea9 100644 --- a/activerecord/test/cases/serialized_attribute_test.rb +++ b/activerecord/test/cases/serialized_attribute_test.rb @@ -392,6 +392,68 @@ def test_nil_is_always_persisted_as_null assert_equal [topic], Topic.where(content: nil) end + class EncryptedType < ActiveRecord::Type::Text + include ActiveModel::Type::Helpers::Mutable + + attr_reader :subtype, :encryptor + + def initialize(subtype: ActiveModel::Type::String.new) + super() + + @subtype = subtype + @encryptor = ActiveSupport::MessageEncryptor.new("abcd" * 8) + end + + def serialize(value) + subtype.serialize(value).yield_self do |cleartext| + encryptor.encrypt_and_sign(cleartext) unless cleartext.nil? + end + end + + def deserialize(ciphertext) + encryptor.decrypt_and_verify(ciphertext) + .yield_self { |cleartext| subtype.deserialize(cleartext) } unless ciphertext.nil? + end + + def changed_in_place?(old, new) + if old.nil? + !new.nil? + else + deserialize(old) != new + end + end + end + + def test_decorated_type_with_type_for_attribute + old_registry = ActiveRecord::Type.registry + ActiveRecord::Type.registry = ActiveRecord::Type.registry.dup + ActiveRecord::Type.register :encrypted, EncryptedType + + klass = Class.new(ActiveRecord::Base) do + self.table_name = Topic.table_name + store :content + attribute :content, :encrypted, subtype: type_for_attribute(:content) + end + + topic = klass.create!(content: { trial: true }) + + assert_equal({ "trial" => true }, topic.content) + ensure + ActiveRecord::Type.registry = old_registry + end + + def test_decorated_type_with_decorator_block + klass = Class.new(ActiveRecord::Base) do + self.table_name = Topic.table_name + store :content + attribute(:content) { |subtype| EncryptedType.new(subtype: subtype) } + end + + topic = klass.create!(content: { trial: true }) + + assert_equal({ "trial" => true }, topic.content) + end + def test_mutation_detection_does_not_double_serialize coder = Object.new def coder.dump(value)