Skip to content

Commit

Permalink
Fixes #276: Memory leak in get_lambda_args()
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jul 17, 2017
1 parent 2d3afb2 commit 6495daa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
7 changes: 2 additions & 5 deletions pony/orm/decompiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@

from pony.thirdparty.compiler import ast, parse

from pony.utils import throw
from pony.utils import throw, get_codeobject_id

##ast.And.__repr__ = lambda self: "And(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),)
##ast.Or.__repr__ = lambda self: "Or(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),)

ast_cache = {}

codeobjects = {}

def decompile(x):
cells = {}
t = type(x)
Expand All @@ -28,10 +26,9 @@ def decompile(x):
else:
if x.__closure__: cells = dict(izip(codeobject.co_freevars, x.__closure__))
else: throw(TypeError)
key = id(codeobject)
key = get_codeobject_id(codeobject)
result = ast_cache.get(key)
if result is None:
codeobjects[key] = codeobject
decompiler = Decompiler(codeobject)
result = decompiler.ast, decompiler.external_names
ast_cache[key] = result
Expand Down
23 changes: 20 additions & 3 deletions pony/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,27 @@ def truncate_repr(s, max_len=100):
s = repr(s)
return s if len(s) <= max_len else s[:max_len-3] + '...'

codeobjects = {}

def get_codeobject_id(codeobject):
codeobject_id = id(codeobject)
if codeobject_id not in codeobjects:
codeobjects[codeobject_id] = codeobject
return codeobject_id

lambda_args_cache = {}

def get_lambda_args(func):
names = lambda_args_cache.get(func)
if type(func) is types.FunctionType:
codeobject = func.func_code if PY2 else func.__code__
cache_key = get_codeobject_id(codeobject)
elif isinstance(func, ast.Lambda):
cache_key = func
else: assert False # pragma: no cover

names = lambda_args_cache.get(cache_key)
if names is not None: return names

if type(func) is types.FunctionType:
if hasattr(inspect, 'signature'):
names, argsname, kwname, defaults = [], None, None, None
Expand Down Expand Up @@ -162,7 +178,8 @@ def get_lambda_args(func):
if argsname: throw(TypeError, '*%s is not supported' % argsname)
if kwname: throw(TypeError, '**%s is not supported' % kwname)
if defaults: throw(TypeError, 'Defaults are not supported')
lambda_args_cache[func] = names

lambda_args_cache[cache_key] = names
return names

_cache = {}
Expand Down Expand Up @@ -493,4 +510,4 @@ def concat(*args):
return ''.join(tostring(arg) for arg in args)

def is_utf8(encoding):
return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8')
return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8')

0 comments on commit 6495daa

Please sign in to comment.