Skip to content
Permalink
Browse files

Merge pull request #519 from jerith/jerith/per-method-recursion-guard

Per-method recursion guard
  • Loading branch information
alex committed Mar 17, 2013
2 parents 2405d6d + 942485a commit 67e3e98cdf02774e952b3dd18f3dd21f8ff8a4fa
@@ -31,7 +31,7 @@ def self.[](*args)

def inspect
result = "["
recursion = Thread.current.recursion_guard(self) do
recursion = Thread.current.recursion_guard(:array_inspect, self) do
self.each_with_index do |obj, i|
if i > 0
result << ", "
@@ -143,7 +143,7 @@ def first

def flatten(level = -1)
list = []
recursion = Thread.current.recursion_guard(self) do
Thread.current.recursion_guard(:array_flatten, self) do
self.each do |item|
if level == 0
list << item
@@ -155,9 +155,7 @@ def flatten(level = -1)
end
return list
end
if recursion
raise ArgumentError.new("tried to flatten recursive array")
end
raise ArgumentError.new("tried to flatten recursive array")
end

def flatten!(level = -1)
@@ -183,7 +181,7 @@ def ==(other)
if self.size != other.size
return false
end
Thread.current.recursion_guard(self) do
Thread.current.recursion_guard(:array_equals, self) do
self.each_with_index do |x, i|
if x != other[i]
return false
@@ -203,7 +201,7 @@ def eql?(other)
if self.length != other.length
return false
end
Thread.current.recursion_guard(self) do
Thread.current.recursion_guard(:array_eqlp, self) do
self.each_with_index do |x, i|
if !x.eql?(other[i])
return false
@@ -1,6 +1,3 @@
fails:Array#inspect returns '[]' for an empty Array
fails:Array#inspect calls inspect on its elements and joins the results with commas
fails:Array#inspect represents a recursive element with '[...]'
fails:Array#inspect taints the result if the Array is non-empty and tainted
fails:Array#inspect taints the result if an element is tainted
fails:Array#inspect untrusts the result if the Array is untrusted
@@ -1,7 +1,2 @@
fails:Array#reject returns a new array without elements for which block is true
fails:Array#reject returns self when called on an Array emptied with #shift
fails:Array#reject properly handles recursive arrays
fails:Array#reject does not return subclass instance on Array subclasses
fails:Array#reject does not retain instance variables
fails:Array#reject returns an Enumerator if no block given
fails:Array#reject! returns an Enumerator if no block given
@@ -1,4 +1,3 @@
fails:Hash#== computes equality for complex recursive hashes
fails:Hash#== computes equality for recursive hashes & arrays
fails:Hash#== compares the values in self to values in other hash
fails:Hash#== compares values with == semantics
@@ -15,3 +15,44 @@ def test_thread_local_storage(self, space):
return Thread.current[:a]
""")
assert space.int_w(w_res) == 1

def test_recursion_guard(self, space):
w_res = space.execute("""
def foo(objs, depth = 0)
obj = objs.shift
recursion = Thread.current.recursion_guard(:foo, obj) do
return foo(objs, depth + 1)
end
if recursion
return [depth, obj]
end
end
return foo([:a, :b, :c, :a, :d])
""")
w_depth, w_symbol = space.listview(w_res)
assert space.int_w(w_depth) == 3
assert space.symbol_w(w_symbol) == "a"

def test_recursion_guard_nested(self, space):
w_res = space.execute("""
def foo(objs, depth = 0)
obj = objs.shift
Thread.current.recursion_guard(:foo, obj) do
return bar(objs, depth + 1)
end
return [depth, obj]
end
def bar(objs, depth)
obj = objs.shift
Thread.current.recursion_guard(:bar, obj) do
return foo(objs, depth + 1)
end
return [depth, obj]
end
return foo([:a, :a, :b, :b, :c, :a, :d, :d])
""")
w_depth, w_symbol = space.listview(w_res)
assert space.int_w(w_depth) == 5
assert space.symbol_w(w_symbol) == "a"
@@ -1,12 +1,13 @@
class TestExecutionContext(object):
def test_recursion_guard(self, space):
f = "my_func"
x = object()
y = object()
with space.getexecutioncontext().recursion_guard(x) as in_recursion:
with space.getexecutioncontext().recursion_guard(f, x) as in_recursion:
assert not in_recursion
with space.getexecutioncontext().recursion_guard(y) as ir2:
with space.getexecutioncontext().recursion_guard(f, y) as ir2:
assert not ir2
with space.getexecutioncontext().recursion_guard(x) as ir3:
with space.getexecutioncontext().recursion_guard(f, x) as ir3:
assert ir3
with space.getexecutioncontext().recursion_guard(x) as ir3:
with space.getexecutioncontext().recursion_guard(f, x) as ir3:
assert ir3
@@ -13,7 +13,7 @@ def __init__(self):
self.regexp_match_cell = None
self.w_trace_proc = None
self.in_trace_proc = False
self.recursive_objects = {}
self.recursive_calls = {}
self.catch_names = {}

def settraceproc(self, w_proc):
@@ -75,8 +75,13 @@ def gettoprubyframe(self):
frame = frame.backref()
return frame

def recursion_guard(self, w_obj):
return _RecursionGuardContextManager(self, w_obj)
def recursion_guard(self, func_id, w_obj):
# We need independent recursion detection for different blocks of
# potentially recursive code so that they don't interfere with each
# other and cause false positives. This is only likely to be a problem
# if one recursion-guarded function calls another, but we can't
# guarantee that won't happen.
return _RecursionGuardContextManager(self, func_id, w_obj)

def catch_block(self, name):
return _CatchContextManager(self, name)
@@ -103,21 +108,27 @@ def __exit__(self, exc_type, exc_value, tb):


class _RecursionGuardContextManager(object):
def __init__(self, ec, w_obj):
def __init__(self, ec, func_id, w_obj):
self.ec = ec
if func_id not in self.ec.recursive_calls:
self.ec.recursive_calls[func_id] = {}
self.recursive_objects = self.ec.recursive_calls[func_id]
self.func_id = func_id
self.w_obj = w_obj
self.added = False

def __enter__(self):
if self.w_obj in self.ec.recursive_objects:
if self.w_obj in self.recursive_objects:
return True
self.ec.recursive_objects[self.w_obj] = None
self.recursive_objects[self.w_obj] = None
self.added = True
return False

def __exit__(self, exc_type, exc_value, tb):
if self.added:
del self.ec.recursive_objects[self.w_obj]
del self.recursive_objects[self.w_obj]
if not self.recursive_objects:
del self.ec.recursive_calls[self.func_id]


class _CatchContextManager(object):
@@ -144,11 +144,12 @@ def singleton_method_join(self, space, args_w):
result = []
for w_arg in args_w:
if isinstance(w_arg, W_ArrayObject):
with space.getexecutioncontext().recursion_guard(w_arg) as in_recursion:
ec = space.getexecutioncontext()
with ec.recursion_guard("file_singleton_method_join", w_arg) as in_recursion:
if in_recursion:
raise space.error(space.w_ArgumentError, "recursive array")
string = space.str_w(
W_FileObject.singleton_method_join(self, space, space.listview(w_arg))
space.send(space.getclassfor(W_FileObject), space.newsymbol("join"), space.listview(w_arg))
)
else:
w_string = space.convert_type(w_arg, space.w_string, "to_path", raise_error=False)
@@ -31,13 +31,14 @@ def method_subscript_assign(self, space, key, w_value):
return w_value

@classdef.method("recursion_guard")
def method_recursion_guard(self, space, w_obj, block):
def method_recursion_guard(self, space, w_identifier, w_obj, block):
"""
Detects recursion. If there is none, yield and return false. Else
return true
Calls the block with true if recursion is detected, false otherwise.
It is up to the block to decide what to do in either case.
"""
with space.getexecutioncontext().recursion_guard(w_obj) as in_recursion:
if in_recursion:
return space.w_true
space.invoke_block(block, [])
return space.w_false
ec = space.getexecutioncontext()
identifier = space.str_w(w_identifier)
with ec.recursion_guard(identifier, w_obj) as in_recursion:
if not in_recursion:
space.invoke_block(block, [])
return space.newbool(in_recursion)

0 comments on commit 67e3e98

Please sign in to comment.