forked from coleifer/peewee
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gfk.py
122 lines (103 loc) · 4.33 KB
/
gfk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Provide a "Generic ForeignKey", similar to Django. A "GFK" is composed of two
columns: an object ID and an object type identifier (the referenced model's
table name, ``_meta.db_table``). The object types are collected in a global
registry (``all_models``), so all you need to do is subclass ``gfk.Model`` and
your model will be added to the registry.

Example::

    class Tag(Model):
        tag = CharField()
        object_type = CharField(null=True)
        object_id = IntegerField(null=True)
        object = GFKField('object_type', 'object_id')

    class Blog(Model):
        tags = ReverseGFK(Tag, 'object_type', 'object_id')

    class Photo(Model):
        tags = ReverseGFK(Tag, 'object_type', 'object_id')

    tag.object -> a blog or photo
    blog.tags  -> select query of tags for the ``blog`` instance
    Blog.tags  -> select query of all tags for Blog instances
"""
from peewee import *
from peewee import BaseModel as _BaseModel
from peewee import FieldDescriptor
from peewee import Model as _Model
from peewee import SelectQuery
from peewee import UpdateQuery
from peewee import with_metaclass
# Registry of every class created through the gfk ``BaseModel`` metaclass.
all_models = set()
# Memo of table name -> model class, filled lazily by ``get_model``.
table_cache = dict()
class BaseModel(_BaseModel):
    """Metaclass that records every class it builds in ``all_models``."""

    def __new__(cls, name, bases, attrs):
        new_class = super(BaseModel, cls).__new__(cls, name, bases, attrs)
        # Registration lets ``get_model`` map a stored table name back to
        # this class later on.
        all_models.add(new_class)
        return new_class
class Model(with_metaclass(BaseModel, _Model)):
    """Model base class whose subclasses are registered in ``all_models``.

    Subclass this (instead of peewee's ``Model``) so that GFK lookups can
    resolve the subclass from its table name.
    """
def get_model(tbl_name):
    """Return the registered model class whose table is named *tbl_name*.

    Hits are memoized in ``table_cache``. When no registered model uses
    that table name, ``None`` is returned and nothing is cached.
    """
    if tbl_name in table_cache:
        return table_cache[tbl_name]
    for candidate in all_models:
        if candidate._meta.db_table == tbl_name:
            table_cache[tbl_name] = candidate
            return candidate
    return None
class GFKField(object):
    """Descriptor resolving a (type, id) column pair to the referenced object.

    ``model_type_field`` names the column that stores the target model's
    table name. ``model_id_field`` names the column that stores the target's
    id. The resolved object is memoized in the instance's ``_obj_cache``.
    """

    def __init__(self, model_type_field='object_type',
                 model_id_field='object_id'):
        self.model_type_field = model_type_field
        self.model_id_field = model_id_field
        # Key under which the resolved object is cached on each instance,
        # e.g. "object_type.object_id".
        self.att_name = '.'.join((self.model_type_field, self.model_id_field))

    def get_obj(self, instance):
        """Fetch the object referenced by *instance*, or ``None`` if unset.

        Raises:
            AttributeError: the stored table name matches no registered model.

        Any error raised by ``query.get()`` (for example when the referenced
        row no longer exists) propagates unchanged.
        """
        data = instance._data
        tbl_name = data.get(self.model_type_field)
        obj_id = data.get(self.model_id_field)
        # An id of 0 is a legitimate key; only a missing id means "unset".
        if not tbl_name or obj_id is None:
            return None
        model_class = get_model(tbl_name)
        if not model_class:
            raise AttributeError('Model for table "%s" not found in GFK '
                                 'lookup.' % tbl_name)
        query = model_class.select().where(
            model_class._meta.primary_key == obj_id)
        return query.get()

    def __get__(self, instance, instance_type=None):
        if instance is None:
            # Accessed on the class itself: expose the descriptor.
            return self
        cache = instance._obj_cache
        if self.att_name not in cache:
            rel_obj = self.get_obj(instance)
            # Only cache real objects, so an unset GFK is re-checked on the
            # next access.
            if rel_obj is not None:
                cache[self.att_name] = rel_obj
        return cache.get(self.att_name)

    def __set__(self, instance, value):
        if value is None:
            # Clearing the reference: null both columns and drop the cached
            # object instead of failing on ``None._meta``.
            instance._obj_cache.pop(self.att_name, None)
            instance._data[self.model_type_field] = None
            instance._data[self.model_id_field] = None
            return
        instance._obj_cache[self.att_name] = value
        instance._data[self.model_type_field] = value._meta.db_table
        instance._data[self.model_id_field] = value.get_id()
class ReverseGFK(object):
    """Descriptor exposing rows of *model* whose GFK points at an object.

    Read on an instance, it returns a select query of the rows pointing at
    that instance. Read on the class, it returns the rows pointing at any row
    of that class's table. Assigning a select query over *model*, or an
    iterable of *model* instances, re-points those rows at the instance.
    """

    def __init__(self, model, model_type_field='object_type',
                 model_id_field='object_id'):
        self.model_class = model
        # Resolve the column names to the model's field objects so they can
        # be used in query expressions and as update keys.
        self.model_type_field = model._meta.fields[model_type_field]
        self.model_id_field = model._meta.fields[model_id_field]

    def __get__(self, instance, instance_type=None):
        if instance is not None:
            return self.model_class.select().where(
                (self.model_type_field == instance._meta.db_table) &
                (self.model_id_field == instance.get_id())
            )
        return self.model_class.select().where(
            self.model_type_field == instance_type._meta.db_table
        )

    def __set__(self, instance, value):
        mtv = instance._meta.db_table
        miv = instance.get_id()
        if (isinstance(value, SelectQuery) and
                value.model_class == self.model_class):
            # Bulk re-point every row matched by the query's WHERE clause.
            UpdateQuery(self.model_class, {
                self.model_type_field: mtv,
                self.model_id_field: miv,
            }).where(value._where).execute()
            return
        try:
            # Materialize first: a one-shot iterator (e.g. a generator) would
            # otherwise be exhausted by the type check below, leaving nothing
            # for the save loop.
            objs = list(value)
        except TypeError:
            objs = None
        if objs is None or not all(
                isinstance(obj, self.model_class) for obj in objs):
            # ``(value,)`` keeps tuple values from breaking %-formatting.
            raise ValueError(
                'ReverseGFK field unable to handle "%s"' % (value,))
        for obj in objs:
            setattr(obj, self.model_type_field.name, mtv)
            setattr(obj, self.model_id_field.name, miv)
            obj.save()