Skip to content

Commit

Permalink
Normalize code
Browse files Browse the repository at this point in the history
  • Loading branch information
uralbash committed Sep 29, 2017
1 parent d677039 commit c213176
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 56 deletions.
192 changes: 144 additions & 48 deletions sqlalchemy_mptt/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,31 @@
from sqlalchemy.orm.base import NO_VALUE


def _insert_subtree(table, connection, node_size,
node_pos_left, node_pos_right,
parent_pos_left, parent_pos_right, subtree,
parent_tree_id, parent_level, node_level, left_sibling,
table_pk):
def _insert_subtree(
table,
connection,
node_size,
node_pos_left,
node_pos_right,
parent_pos_left,
parent_pos_right,
subtree,
parent_tree_id,
parent_level,
node_level,
left_sibling,
table_pk
):
# step 1: rebuild inserted subtree
delta_lft = left_sibling['lft'] + 1
if not left_sibling['is_parent']:
delta_lft = left_sibling['rgt'] + 1
delta_rgt = delta_lft + node_size - 1

connection.execute(
table.update(table_pk.in_(subtree))
.values(
table.update(
table_pk.in_(subtree)
).values(
lft=table.c.lft - node_pos_left + delta_lft,
rgt=table.c.rgt - node_pos_right + delta_rgt,
level=table.c.level - node_level + parent_level + 1,
Expand All @@ -43,14 +54,20 @@ def _insert_subtree(table, connection, node_size,
# step 2: update key of right side
connection.execute(
table.update(
and_(table.c.rgt > delta_lft - 1,
table_pk.notin_(subtree),
table.c.tree_id == parent_tree_id)
and_(
table.c.rgt > delta_lft - 1,
table_pk.notin_(subtree),
table.c.tree_id == parent_tree_id
)
).values(
rgt=table.c.rgt + node_size,
lft=case(
[(table.c.lft > left_sibling['lft'],
table.c.lft + node_size)],
[
(
table.c.lft > left_sibling['lft'],
table.c.lft + node_size
)
],
else_=table.c.lft
)
)
Expand All @@ -76,15 +93,28 @@ def mptt_before_insert(mapper, connection, instance):
instance.right = 2
instance.level = instance.get_default_level()
tree_id = connection.scalar(
select([func.max(table.c.tree_id) + 1])) or 1
select(
[
func.max(table.c.tree_id) + 1
]
)
) or 1
instance.tree_id = tree_id
else:
(parent_pos_left,
parent_pos_right,
parent_tree_id,
parent_level) = connection.execute(
select([table.c.lft, table.c.rgt, table.c.tree_id, table.c.level]).
where(table_pk == instance.parent_id)
select(
[
table.c.lft,
table.c.rgt,
table.c.tree_id,
table.c.level
]
).where(
table_pk == instance.parent_id
)
).fetchone()

# Update key of right side
Expand All @@ -94,13 +124,21 @@ def mptt_before_insert(mapper, connection, instance):
table.c.tree_id == parent_tree_id)
).values(
lft=case(
[(table.c.lft > parent_pos_right,
table.c.lft + 2)],
[
(
table.c.lft > parent_pos_right,
table.c.lft + 2
)
],
else_=table.c.lft
),
rgt=case(
[(table.c.rgt >= parent_pos_right,
table.c.rgt + 2)],
[
(
table.c.rgt >= parent_pos_right,
table.c.rgt + 2
)
],
else_=table.c.rgt
)
)
Expand All @@ -119,14 +157,23 @@ def mptt_before_delete(mapper, connection, instance, delete=True):
db_pk = instance.get_pk_column()
table_pk = getattr(table.c, db_pk.name)
lft, rgt = connection.execute(
select([table.c.lft, table.c.rgt]).where(table_pk == pk)
select(
[
table.c.lft,
table.c.rgt
]
).where(
table_pk == pk
)
).fetchone()
delta = rgt - lft + 1

if delete:
mapper.base_mapper.confirm_deleted_rows = False
connection.execute(
table.delete(table_pk == pk)
table.delete(
table_pk == pk
)
)

if instance.parent_id or not delete:
Expand All @@ -144,14 +191,27 @@ def mptt_before_delete(mapper, connection, instance, delete=True):
"""
connection.execute(
table.update(
and_(table.c.rgt > rgt, table.c.tree_id == tree_id))
.values(
and_(
table.c.rgt > rgt,
table.c.tree_id == tree_id
)
).values(
lft=case(
[(table.c.lft > lft, table.c.lft - delta)],
[
(
table.c.lft > lft,
table.c.lft - delta
)
],
else_=table.c.lft
),
rgt=case(
[(table.c.rgt >= rgt, table.c.rgt - delta)],
[
(
table.c.rgt >= rgt,
table.c.rgt - delta
)
],
else_=table.c.rgt
)
)
Expand Down Expand Up @@ -189,24 +249,40 @@ def mptt_before_update(mapper, connection, instance):
table.c.parent_id,
table.c.level,
table.c.tree_id
]).where(table_pk == instance.mptt_move_before)
]
).where(
table_pk == instance.mptt_move_before
)
).fetchone()
current_lvl_nodes = connection.execute(
select([table.c.lft, table.c.rgt, table.c.parent_id,
table.c.tree_id]).
where(and_(table.c.level == right_sibling_level,
table.c.tree_id == right_sibling_tree_id,
table.c.lft < right_sibling_left))
select(
[
table.c.lft,
table.c.rgt,
table.c.parent_id,
table.c.tree_id
]
).where(
and_(
table.c.level == right_sibling_level,
table.c.tree_id == right_sibling_tree_id,
table.c.lft < right_sibling_left
)
)
).fetchall()
if current_lvl_nodes:
(left_sibling_left,
left_sibling_right,
left_sibling_parent,
left_sibling_tree_id) = current_lvl_nodes[-1]
(
left_sibling_left,
left_sibling_right,
left_sibling_parent,
left_sibling_tree_id
) = current_lvl_nodes[-1]
instance.parent_id = left_sibling_parent
left_sibling = {'lft': left_sibling_left,
'rgt': left_sibling_right,
'is_parent': False}
left_sibling = {
'lft': left_sibling_left,
'rgt': left_sibling_right,
'is_parent': False
}
# if move_before to top level
elif not right_sibling_parent:
left_sibling_tree_id = right_sibling_tree_id - 1
Expand All @@ -226,7 +302,9 @@ def mptt_before_update(mapper, connection, instance):
table.c.parent_id,
table.c.tree_id
]
).where(table_pk == instance.mptt_move_after)
).where(
table_pk == instance.mptt_move_after
)
).fetchone()
instance.parent_id = left_sibling_parent
left_sibling = {
Expand All @@ -249,7 +327,9 @@ def mptt_before_update(mapper, connection, instance):
table.c.rgt <= instance.right,
table.c.tree_id == instance.tree_id
)
).order_by(table.c.lft)
).order_by(
table.c.lft
)
).fetchall()
subtree = [x[0] for x in subtree]

Expand All @@ -272,7 +352,9 @@ def mptt_before_update(mapper, connection, instance):
table.c.parent_id,
table.c.level
]
).where(table_pk == node_id)
).where(
table_pk == node_id
)
).fetchone()

# if instance just update w/o move
Expand All @@ -299,8 +381,9 @@ def mptt_before_update(mapper, connection, instance):
table.c.tree_id,
table.c.level
]
).where(
table_pk == instance.parent_id
)
.where(table_pk == instance.parent_id)
).fetchone()
if not node_parent_id and node_tree_id == parent_tree_id:
instance.parent_id = None
Expand Down Expand Up @@ -328,7 +411,9 @@ def mptt_before_update(mapper, connection, instance):
table.c.tree_id,
table.c.level
]
).where(table_pk == instance.parent_id)
).where(
table_pk == instance.parent_id
)
).fetchone()
# 'size' of moving node (including all it's sub nodes)
node_size = node_pos_right - node_pos_left + 1
Expand Down Expand Up @@ -363,17 +448,28 @@ def mptt_before_update(mapper, connection, instance):
if left_sibling_tree_id or left_sibling_tree_id == 0:
tree_id = left_sibling_tree_id + 1
connection.execute(
table.update(table.c.tree_id > left_sibling_tree_id)
.values(tree_id=table.c.tree_id + 1)
table.update(
table.c.tree_id > left_sibling_tree_id
).values(
tree_id=table.c.tree_id + 1
)
)
# if just insert
else:
tree_id = connection.scalar(
select([func.max(table.c.tree_id) + 1]))
select(
[
func.max(table.c.tree_id) + 1
]
)
)

connection.execute(
table.update(table_pk.in_(subtree))
.values(
table.update(
table_pk.in_(
subtree
)
).values(
lft=table.c.lft - node_pos_left + 1,
rgt=table.c.rgt - node_pos_left + 1,
level=table.c.level - node_level + default_level,
Expand Down
26 changes: 18 additions & 8 deletions sqlalchemy_mptt/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,14 @@ def parent_id(cls):
if not pk.name:
pk.name = cls.get_pk_name()

return Column("parent_id", pk.type,
ForeignKey('%s.%s' % (cls.__tablename__, pk.name),
ondelete='CASCADE'))
return Column(
"parent_id",
pk.type,
ForeignKey(
'{}.{}'.format(cls.__tablename__, pk.name),
ondelete='CASCADE'
)
)

@declared_attr
def parent(self):
Expand Down Expand Up @@ -188,12 +193,13 @@ def move_before(self, node_id):
def leftsibling_in_level(self):
""" Node to the left of the current node at the same level
For example see :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_leftsibling_in_level`
For example see
:mod:`sqlalchemy_mptt.tests.cases.get_tree.test_leftsibling_in_level`
""" # noqa
table = _get_tree_table(self.__mapper__)
session = Session.object_session(self)
current_lvl_nodes = session.query(table)\
.filter_by(level=self.level).filter_by(tree_id=self.tree_id)\
current_lvl_nodes = session.query(table) \
.filter_by(level=self.level).filter_by(tree_id=self.tree_id) \
.filter(table.c.lft < self.left).order_by(table.c.lft).all()
if current_lvl_nodes:
return current_lvl_nodes[-1]
Expand Down Expand Up @@ -323,8 +329,12 @@ def drilldown_tree(self, session=None, json=False, json_fields=None):
"""
if not session:
session = object_session(self)
return self.get_tree(session, json=json, json_fields=json_fields,
query=self._drilldown_query)
return self.get_tree(
session,
json=json,
json_fields=json_fields,
query=self._drilldown_query
)

def path_to_root(self, session=None):
"""Generate path from a leaf or intermediate node to the root.
Expand Down

0 comments on commit c213176

Please sign in to comment.