Skip to content

Commit

Permalink
gh-106727: Make inspect.getsource smarter for class for same name d…
Browse files Browse the repository at this point in the history
…efinitions (#106815)
  • Loading branch information
gaogaotiantian committed Jul 18, 2023
1 parent 505eede commit 663854d
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 12 deletions.
57 changes: 46 additions & 11 deletions Lib/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,9 +1034,13 @@ class ClassFoundException(Exception):

class _ClassFinder(ast.NodeVisitor):

def __init__(self, qualname):
def __init__(self, cls, tree, lines, qualname):
self.stack = []
self.cls = cls
self.tree = tree
self.lines = lines
self.qualname = qualname
self.lineno_found = []

def visit_FunctionDef(self, node):
self.stack.append(node.name)
Expand All @@ -1057,11 +1061,48 @@ def visit_ClassDef(self, node):
line_number = node.lineno

# decrement by one since lines starts with indexing by zero
line_number -= 1
raise ClassFoundException(line_number)
self.lineno_found.append((line_number - 1, node.end_lineno))
self.generic_visit(node)
self.stack.pop()

def get_lineno(self):
self.visit(self.tree)
lineno_found_number = len(self.lineno_found)
if lineno_found_number == 0:
raise OSError('could not find class definition')
elif lineno_found_number == 1:
return self.lineno_found[0][0]
else:
# We have multiple candidates for the class definition.
# Now we have to guess.

# First, let's see if there are any method definitions
for member in self.cls.__dict__.values():
if isinstance(member, types.FunctionType):
for lineno, end_lineno in self.lineno_found:
if lineno <= member.__code__.co_firstlineno <= end_lineno:
return lineno

class_strings = [(''.join(self.lines[lineno: end_lineno]), lineno)
for lineno, end_lineno in self.lineno_found]

# Maybe the class has a docstring and it's unique?
if self.cls.__doc__:
ret = None
for candidate, lineno in class_strings:
if self.cls.__doc__.strip() in candidate:
if ret is None:
ret = lineno
else:
break
else:
if ret is not None:
return ret

# We are out of ideas, just return the last one found, which is
# slightly better than previous ones
return self.lineno_found[-1][0]


def findsource(object):
"""Return the entire source file and starting line number for an object.
Expand Down Expand Up @@ -1098,14 +1139,8 @@ def findsource(object):
qualname = object.__qualname__
source = ''.join(lines)
tree = ast.parse(source)
class_finder = _ClassFinder(qualname)
try:
class_finder.visit(tree)
except ClassFoundException as e:
line_number = e.args[0]
return lines, line_number
else:
raise OSError('could not find class definition')
class_finder = _ClassFinder(object, tree, lines, qualname)
return lines, class_finder.get_lineno()

if ismethod(object):
object = object.__func__
Expand Down
20 changes: 20 additions & 0 deletions Lib/test/inspect_fodder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,23 @@ def complex_decorated(foo=0, bar=lambda: 0):
nested_lambda = (
lambda right: [].map(
lambda length: ()))

# line 294
if True:
class cls296:
def f():
pass
else:
class cls296:
def g():
pass

# line 304
if False:
class cls310:
def f():
pass
else:
class cls310:
def g():
pass
5 changes: 4 additions & 1 deletion Lib/test/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,6 @@ def test_class_decorator(self):
self.assertSourceEqual(mod2.cls196.cls200, 198, 201)

def test_class_inside_conditional(self):
self.assertSourceEqual(mod2.cls238, 238, 240)
self.assertSourceEqual(mod2.cls238.cls239, 239, 240)

def test_multiple_children_classes(self):
Expand All @@ -975,6 +974,10 @@ def test_nested_class_definition_inside_async_function(self):
self.assertSourceEqual(mod2.cls226, 231, 235)
self.assertSourceEqual(asyncio.run(mod2.cls226().func232()), 233, 234)

def test_class_definition_same_name_diff_methods(self):
self.assertSourceEqual(mod2.cls296, 296, 298)
self.assertSourceEqual(mod2.cls310, 310, 312)

class TestNoEOL(GetSourceBase):
def setUp(self):
self.tempdir = TESTFN + '_dir'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :func:`inspect.getsource` smarter for class for same name definitions

0 comments on commit 663854d

Please sign in to comment.