-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
SessionIndexing
This recipe presents a generalized way to "index" objects in memory as they are placed into Sessions, so that later they can be retrieved based on particular criteria. The use case for this could be to assist in writing before_flush() event handlers, where particular subsets of objects in a Session need to be inspected, and a SQL round trip is specifically not wanted, typically due to performance concerns.
The technique is actually pretty simplistic, and does not account for the case where the objects are mutated in the Session, such that the object would be indexed differently. To handle that, attribute-on-change events would also need to be intercepted, resulting in a re-indexing of a particular target index.
In contrast to this recipe, it is of course vastly simpler just to use the Session normally, emitting a query against the database whose results are then correlated against what's already in the Session's identity map; the use case here is specifically one of avoiding those round trips.
import weakref
import collections
from sqlalchemy import event
from sqlalchemy.orm import Session
from sqlalchemy.orm import mapper
class Index(object):
    """An in-memory 'index' of objects in sessions.

    Listens for objects being attached to (or loaded into) sessions and
    indexes them according to a series of user-defined "indexing"
    functions registered with :meth:`indexed`.

    An index is queried through attribute access using the index name::

        index.<name>(session, value)  # -> set of matching instances

    Objects are indexed once, at attach/load time; later mutation of an
    object does not move it to a different index key.

    """

    def __init__(self):
        # dictionary of (class ->
        #                   dictionary of (name of index -> how to index)
        #               )
        self._index_fns = weakref.WeakKeyDictionary()

        # dictionary of (session object ->
        #                   dictionary of
        #                       ((indexname, value) -> set of instances)
        #               )
        self._by_session = weakref.WeakKeyDictionary()

        @event.listens_for(mapper, "load")
        def object_loaded(instance, ctx):
            self._index_object(ctx.session, instance)

        @event.listens_for(Session, "after_attach")
        def index_object(session, instance):
            self._index_object(session, instance)

    def _index_object(self, session, instance):
        """Record ``instance`` under every index that applies to its class.

        :param session: the session the instance now belongs to; used as
         a weak key for the per-session index.
        :param instance: the object to be indexed.

        """
        # get (or lazily create) the dictionary of sets for this session
        by_session = self._by_session.get(session)
        if by_session is None:
            by_session = self._by_session[session] = \
                collections.defaultdict(weakref.WeakSet)

        # find all the indexes for this object's class,
        # and superclasses too.
        typ = type(instance)
        for cls in typ.__mro__:
            if cls not in self._index_fns:
                continue

            # all the "index" functions registered against this class
            for name, rec in self._index_fns[cls].items():
                # an index registered with include_subclasses=False applies
                # only when the instance is exactly of the registered class;
                # compare the instance's concrete type, not the MRO entry
                # (which is by construction always the registered class).
                if rec['include_subclasses'] or typ is cls:
                    # call the indexing function, build a key
                    key = name, rec['fn'](instance)
                    by_session[key].add(instance)

    def indexed(self, cls, name, include_subclasses=True):
        """Log a function as indexing a certain class.

        Used as a decorator; the decorated function receives an instance
        and returns the value under which that instance is indexed.

        :param cls: the class whose instances are to be indexed.
        :param name: name of the index; also the attribute name used to
         look the index up on this object.
        :param include_subclasses: when False, only instances whose type
         is exactly ``cls`` are indexed; subclasses are skipped.
        :return: a decorator which registers the function and returns it
         unchanged.

        """
        byclass = self._index_fns.setdefault(cls, {})

        def decorate(fn):
            # the class is deliberately not stored in the record: a strong
            # reference from the value back to the weak key would keep the
            # WeakKeyDictionary entry alive indefinitely.
            byclass[name] = {
                "fn": fn,
                "include_subclasses": include_subclasses
            }
            return fn
        return decorate

    def __getattr__(self, name):
        """Return an index-lookup function.

        The returned callable accepts ``(session, value)`` and returns the
        set of objects indexed under ``(name, value)`` that are still
        present in the session's identity map or its ``new`` collection.

        :raises AttributeError: for names beginning with an underscore, so
         that protocol probes such as ``__deepcopy__`` or ``__setstate__``
         see a missing attribute rather than an index-lookup function.

        """
        if name.startswith("_"):
            raise AttributeError(
                "%r object has no attribute %r"
                % (type(self).__name__, name))

        def go(sess, value):
            by_session = self._by_session.get(sess)
            if by_session is None:
                return set()

            # use .get() so that a miss does not insert an empty WeakSet
            # into the defaultdict for every value ever looked up
            members = by_session.get((name, value))
            if not members:
                return set()

            # objects that were expunged, or left behind by a closed
            # session, are still in the index; filter down to what the
            # session currently holds
            present = set(sess.identity_map.values()).union(sess.new)
            return set(members).intersection(present)
        return go
# module-level Index instance; the demonstration below registers its
# indexing functions on it and then queries it per session
indexes = Index()

if __name__ == '__main__':
    # demonstration
    from sqlalchemy import Column, String, Integer
    from sqlalchemy.orm import Session
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'

        id = Column(Integer, primary_key=True)
        name = Column(String)

    class Address(Base):
        __tablename__ = 'address'

        id = Column(Integer, primary_key=True)
        name = Column(String)

    # register one by-name index for each mapped class
    @indexes.indexed(User, "user_byname")
    def index_user_byname(obj):
        return obj.name

    @indexes.indexed(Address, "address_byname")
    def index_address_byname(obj):
        return obj.name

    # three users for each of the names 'a' through 'e'
    a1, a2, a3 = (User(name='a') for _ in range(3))
    b1, b2, b3 = (User(name='b') for _ in range(3))
    c1, c2, c3 = (User(name='c') for _ in range(3))
    d1, d2, d3 = (User(name='d') for _ in range(3))
    e1, e2, e3 = (User(name='e') for _ in range(3))

    # one address for each of the names 'a' through 'c'
    ad_a = Address(name='a')
    ad_b = Address(name='b')
    ad_c = Address(name='c')

    # spread the objects over three independent sessions
    s1 = Session()
    s2 = Session()
    s3 = Session()

    s1.add_all([a1, b1, b2, d2, e3, ad_c])
    s2.add_all([a2, c2, e1, e2, ad_a])
    s3.add_all([b3, c1, d1, d3, ad_b])

    # each lookup sees only the objects of the session it is given
    assert indexes.user_byname(s1, "b") == {b1, b2}
    assert indexes.user_byname(s2, "e") == {e1, e2}
    assert indexes.user_byname(s2, "c") == {c2}
    assert indexes.user_byname(s3, "b") == {b3}

    assert indexes.address_byname(s3, "b") == {ad_b}
    assert indexes.address_byname(s3, "c") == set()
    assert indexes.address_byname(s1, "c") == {ad_c}

    # an expunged object no longer shows up in that session's results
    s2.expunge(e2)
    assert indexes.user_byname(s2, "e") == {e1}

    # once the session is closed, nothing is returned for it
    s2.close()
    assert indexes.user_byname(s2, "e") == set()