|
32 | 32 | 'Yury Selivanov <yselivanov@sprymix.com>') |
33 | 33 |
|
34 | 34 | import abc |
| 35 | +import ast |
35 | 36 | import dis |
36 | 37 | import collections.abc |
37 | 38 | import enum |
@@ -770,6 +771,42 @@ def getmodule(object, _filename=None): |
770 | 771 | if builtinobject is object: |
771 | 772 | return builtin |
772 | 773 |
|
| 774 | + |
| 775 | +class ClassFoundException(Exception): |
| 776 | + pass |
| 777 | + |
| 778 | + |
| 779 | +class _ClassFinder(ast.NodeVisitor): |
| 780 | + |
| 781 | + def __init__(self, qualname): |
| 782 | + self.stack = [] |
| 783 | + self.qualname = qualname |
| 784 | + |
| 785 | + def visit_FunctionDef(self, node): |
| 786 | + self.stack.append(node.name) |
| 787 | + self.stack.append('<locals>') |
| 788 | + self.generic_visit(node) |
| 789 | + self.stack.pop() |
| 790 | + self.stack.pop() |
| 791 | + |
| 792 | + visit_AsyncFunctionDef = visit_FunctionDef |
| 793 | + |
| 794 | + def visit_ClassDef(self, node): |
| 795 | + self.stack.append(node.name) |
| 796 | + if self.qualname == '.'.join(self.stack): |
| 797 | + # Return the decorator for the class if present |
| 798 | + if node.decorator_list: |
| 799 | + line_number = node.decorator_list[0].lineno |
| 800 | + else: |
| 801 | + line_number = node.lineno |
| 802 | + |
| 803 | + # decrement by one since lines starts with indexing by zero |
| 804 | + line_number -= 1 |
| 805 | + raise ClassFoundException(line_number) |
| 806 | + self.generic_visit(node) |
| 807 | + self.stack.pop() |
| 808 | + |
| 809 | + |
773 | 810 | def findsource(object): |
774 | 811 | """Return the entire source file and starting line number for an object. |
775 | 812 |
|
@@ -802,47 +839,14 @@ def findsource(object): |
802 | 839 | return lines, 0 |
803 | 840 |
|
804 | 841 | if isclass(object): |
805 | | - # Lazy import ast because it's relatively heavy and |
806 | | - # it's not used for other than this part. |
807 | | - import ast |
808 | | - |
809 | | - class ClassFinder(ast.NodeVisitor): |
810 | | - |
811 | | - def visit_FunctionDef(self, node): |
812 | | - stack.append(node.name) |
813 | | - stack.append('<locals>') |
814 | | - self.generic_visit(node) |
815 | | - stack.pop() |
816 | | - stack.pop() |
817 | | - |
818 | | - visit_AsyncFunctionDef = visit_FunctionDef |
819 | | - |
820 | | - def visit_ClassDef(self, node): |
821 | | - nonlocal line_number |
822 | | - stack.append(node.name) |
823 | | - if qualname == '.'.join(stack): |
824 | | - # Return the decorator for the class if present |
825 | | - if node.decorator_list: |
826 | | - line_number = node.decorator_list[0].lineno |
827 | | - else: |
828 | | - line_number = node.lineno |
829 | | - |
830 | | - # decrement by one since lines starts with indexing by zero |
831 | | - line_number -= 1 |
832 | | - raise StopIteration(line_number) |
833 | | - self.generic_visit(node) |
834 | | - stack.pop() |
835 | | - |
836 | | - stack = [] |
837 | | - line_number = None |
838 | 842 | qualname = object.__qualname__ |
839 | 843 | source = ''.join(lines) |
840 | 844 | tree = ast.parse(source) |
841 | | - class_finder = ClassFinder() |
| 845 | + class_finder = _ClassFinder(qualname) |
842 | 846 | try: |
843 | 847 | class_finder.visit(tree) |
844 | | - except StopIteration as e: |
845 | | - line_number = e.value |
| 848 | + except ClassFoundException as e: |
| 849 | + line_number = e.args[0] |
846 | 850 | return lines, line_number |
847 | 851 | else: |
848 | 852 | raise OSError('could not find class definition') |
|
0 commit comments