Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak in Rule function builder #1521

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 43 additions & 17 deletions src/werkzeug/routing.py
Expand Up @@ -102,7 +102,6 @@
import sys
import types
import uuid
from functools import partial
from pprint import pformat
from threading import Lock

Expand Down Expand Up @@ -719,6 +718,14 @@ def get_converter(self, variable_name, converter_name, args, kwargs):
raise LookupError("the converter %r does not exist" % converter_name)
return self.map.converters[converter_name](self.map, *args, **kwargs)

def _encode_query_vars(self, query_vars):
return url_encode(
query_vars,
charset=self.map.charset,
sort=self.map.sort_parameters,
key=self.map.sort_key,
)

def compile(self):
"""Compiles the regular expression and stores it."""
assert self.map is not None, "rule not bound"
Expand Down Expand Up @@ -764,8 +771,8 @@ def _build_regex(rule):
if not self.is_leaf:
self._trace.append((False, "/"))

self._build = self._compile_builder(False)
self._build_unknown = self._compile_builder(True)
self._build = self._compile_builder(False).__get__(self, None)
self._build_unknown = self._compile_builder(True).__get__(self, None)

if self.build_only:
return
Expand Down Expand Up @@ -840,9 +847,14 @@ def __init__(self, rule):
self.const_table = {}
self.var = []
self.var_table = {}
self.names = []
self.names_table = {}
self.argdefs = ()
self.defaults = dict(self.rule.defaults or {})

# initialize `self` as the first argument
self.get_var(".self")

def get_const(self, x):
"""Return a constant ID for an object, adding it to the pool
if not already present.
Expand All @@ -866,6 +878,18 @@ def get_var(self, x):
self.var.append(x)
return self.var_table[x]

def get_name(self, x):
"""Return an id for a `name`, adding it to the name pool. This is
used to later populate `co_names` (for looking up attributes).

We use this for `self._XXX` attributes.
"""
x = str(x)
if x not in self.names_table:
self.names_table[x] = len(self.names)
self.names.append(x)
return self.names_table[x]

def add_defaults(self):
"""A rule builder is allowed to receive any of its defaults
as arguments. We don't bother to check that they match
Expand Down Expand Up @@ -946,9 +970,7 @@ def build_string(self, n):
"CALL_FUNCTION", 1
)

def emit_build(
self, ind, opl, append_unknown=False, encode_query_vars=None, kwargs=None
):
def emit_build(self, ind, opl, append_unknown=False, kwargs=None):
ops = b""
n = len(opl)
stack = 0
Expand All @@ -959,7 +981,14 @@ def emit_build(
ops += self.build_op("LOAD_CONST", self.get_const(elem))
stack_overhead = 0
continue
ops += self.build_op("LOAD_CONST", self.get_const(op))

# self._converters["argname"].to_url
ops += self.build_op("LOAD_FAST", self.get_var(".self"))
ops += self.build_op("LOAD_ATTR", self.get_name("_converters"))
ops += self.build_op("LOAD_CONST", self.get_const(elem))
ops += self.build_op("BINARY_SUBSCR")
ops += self.build_op("LOAD_ATTR", self.get_name("to_url"))

ops += self.build_op("LOAD_FAST", self.get_var(elem))
ops += self.build_op("CALL_FUNCTION", 1)
stack_overhead = 2
Expand All @@ -985,7 +1014,11 @@ def emit_build(
# assemble this in its own buffers because we need to
# jump over it
uops = bytearray() # run if kwargs. TOS=kwargs
uops += self.build_op("LOAD_CONST", self.get_const(encode_query_vars))

# self._encode_query_vars
uops += self.build_op("LOAD_FAST", self.get_var(".self"))
uops += self.build_op("LOAD_ATTR", self.get_name("_encode_query_vars"))

uops += self.build_op("ROT_TWO")
uops += self.build_op("CALL_FUNCTION", 1)
uops += self.build_op("LOAD_CONST", self.get_const("?"))
Expand Down Expand Up @@ -1036,13 +1069,6 @@ def compile(self, append_unknown=True):
dom_ops = []
url_ops = []
opl = dom_ops
if append_unknown:
encode_query_vars = partial(
url_encode,
charset=self.rule.map.charset,
sort=self.rule.map.sort_parameters,
key=self.rule.map.sort_key,
)
for is_dynamic, data in self.rule._trace:
if data == "|" and opl is dom_ops:
opl = url_ops
Expand Down Expand Up @@ -1092,7 +1118,7 @@ def compile(self, append_unknown=True):
stack += 1
if append_unknown:
ps, rv = self.emit_build(
len(ops), url_ops, append_unknown, encode_query_vars, argcount
len(ops), url_ops, append_unknown, argcount
)
else:
ps, rv = self.emit_build(len(ops), url_ops)
Expand All @@ -1107,7 +1133,7 @@ def compile(self, append_unknown=True):
flags, # flags
bytes(ops), # codestring
tuple(self.consts), # constants
(), # names
tuple(self.names), # names
tuple(self.var), # varnames
"<werkzeug routing>", # filename, coverage ignores "<"
"<builder:%r>" % self.rule.rule, # name
Expand Down
26 changes: 26 additions & 0 deletions tests/test_routing.py
Expand Up @@ -8,6 +8,7 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
import gc
import uuid

import pytest
Expand Down Expand Up @@ -1057,3 +1058,28 @@ def test_error_message_suggestion():
with pytest.raises(r.BuildError) as excinfo:
adapter.build("world", {"id": 2}, method="POST")
assert "Did you mean to use methods ['GET', 'HEAD']?" in str(excinfo.value)


def test_no_memory_leak_from_Rule_builder():
"""See #1520"""

# generate a bunch of objects that *should* get collected
for _ in range(100):
r.Map([r.Rule("/a/<string:b>")])

# ensure that the garbage collection has had a chance to collect cyclic
# objects
for _ in range(5):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gc.collect() returns the number of objects collected so i think this can be replaced by

while gc.collect():
    pass

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would cause the test to never stop instead of failing though. 5 seems like a safe bet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gc.collect() returns the number of objects collected so i think this can be replaced by

while gc.collect():
    pass

Not necessarily, the first generation could collect zero objects (fast gc) and then later generations could collect cyclic gc

gc.collect()

# assert they got collected!
count = sum(1 for obj in gc.get_objects() if isinstance(obj, r.Rule))
assert count == 0


def test_build_url_with_arg_self():
map = r.Map([r.Rule("/foo/<string:self>", endpoint="foo")])
adapter = map.bind("example.org", "/", subdomain="blah")

ret = adapter.build("foo", {"self": "bar"})
assert ret == "http://example.org/foo/bar"