Permalink
Browse files

Use hash[:Post][1] style identity maps for each table.

  • Loading branch information...
1 parent 4f3b8e1 commit 301dd3d5143077f95bebd434ca2ad1c80c1b5866 @miloops miloops committed Oct 14, 2010
@@ -253,9 +253,9 @@ def method_missing(method, *args)
def load_target
return nil unless defined?(@loaded)
- if !loaded? and (!@owner.persisted? || foreign_key_present)
- if IdentityMap.enabled?
- @target = IdentityMap.get(@reflection.class_name, @owner[@reflection.association_foreign_key])
+ if !loaded? and (@owner.persisted? || foreign_key_present)
+ if IdentityMap.enabled? && defined?(@reflection.klass)
+ @target = IdentityMap.get(@reflection.klass, @owner[@reflection.association_foreign_key])
end
@target ||= find_target
end
@@ -792,6 +792,10 @@ def ===(object)
object.is_a?(self)
end
+ def symbolized_base_class
+ @symbolized_base_class ||= base_class.to_s.to_sym
+ end
+
# Returns the base AR subclass that this class descends from. If A
# extends AR::Base, A.base_class will return A. If B descends from A
# through some arbitrarily deep hierarchy, B.base_class will return A.
@@ -887,7 +891,10 @@ def instantiate(record)
record_id = sti_class.primary_key && record[sti_class.primary_key]
if ActiveRecord::IdentityMap.enabled? && record_id
- if instance = identity_map.get(sti_class.name, record_id)
+ if (column = sti_class.columns_hash[sti_class.primary_key]) && column.number?
+ record_id = record_id.to_i
+ end
+ if instance = identity_map.get(sti_class, record_id)
instance.reinit_with('attributes' => record)
else
instance = sti_class.allocate.init_with('attributes' => record)
@@ -23,7 +23,7 @@ class << self
attr_accessor :enabled
def current
- repositories[current_repository_name] ||= Weakling::WeakHash.new
+ repositories[current_repository_name] ||= Hash.new { |h,k| h[k] = Weakling::WeakHash.new }
end
def with_repository(name = :default)
@@ -43,16 +43,20 @@ def without
self.enabled = old
end
- def get(class_name, primary_key)
- current[[class_name, primary_key.to_s]]
+ def get(klass, primary_key)
+ if obj = current[klass.symbolized_base_class][primary_key]
+ return obj if obj.id == primary_key && klass == obj.class
+ end
+
+ nil
end
def add(record)
- current[[record.class.name, record.id.to_s]] = record
+ current[record.class.symbolized_base_class][record.id] = record
end
def remove(record)
- current.delete([record.class.name, record.id.to_s])
+ current[record.class.symbolized_base_class].delete(record.id)
end
def clear
@@ -53,6 +53,13 @@ def test_find_by_id
)
end
+ def test_find_by_string_and_numeric_id
+ assert_same(
+ Client.find_by_id("3"),
+ Client.find_by_id(3)
+ )
+ end
+
def test_find_by_pkey
assert_same(
Subscriber.find_by_nick('swistak'),

0 comments on commit 301dd3d

Please sign in to comment.