diff --git a/ChangeLog b/ChangeLog index e5bc5dc9f..c10ae3152 100644 --- a/ChangeLog +++ b/ChangeLog @@ -16,6 +16,10 @@ Release date: TBA Closes PyCQA/pylint#7488. +* Create ``ContextManagerModel`` and let ``GeneratorModel`` inherit from it. + + Refs PyCQA/pylint#2567 + What's New in astroid 2.12.11? ============================== diff --git a/astroid/interpreter/objectmodel.py b/astroid/interpreter/objectmodel.py index 879ee7f25..1f41a1112 100644 --- a/astroid/interpreter/objectmodel.py +++ b/astroid/interpreter/objectmodel.py @@ -588,6 +588,48 @@ def attr___self__(self): attr_im_self = attr___self__ +class ContextManagerModel(ObjectModel): + """Model for context managers. + + Based on 3.3.9 of the Data Model documentation: + https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers + """ + + @property + def attr___enter__(self) -> bases.BoundMethod: + """Representation of the base implementation of __enter__. + + As per Python documentation: + Enter the runtime context related to this object. The with statement + will bind this method's return value to the target(s) specified in the + as clause of the statement, if any. + """ + node: nodes.FunctionDef = builder.extract_node("""def __enter__(self): ...""") + # We set the parent as being the ClassDef of 'object' as that + # is where this method originally comes from + node.parent = AstroidManager().builtins_module["object"] + + return bases.BoundMethod(proxy=node, bound=_get_bound_node(self)) + + @property + def attr___exit__(self) -> bases.BoundMethod: + """Representation of the base implementation of __exit__. + + As per Python documentation: + Exit the runtime context related to this object. The parameters describe the + exception that caused the context to be exited. If the context was exited + without an exception, all three arguments will be None. + """ + node: nodes.FunctionDef = builder.extract_node( + """def __exit__(self, exc_type, exc_value, traceback): ...""" + ) + # We set the parent as being the ClassDef of 'object' as that + # is where this method originally comes from + node.parent = AstroidManager().builtins_module["object"] + + return bases.BoundMethod(proxy=node, bound=_get_bound_node(self)) + + class BoundMethodModel(FunctionModel): @property def attr___func__(self): @@ -598,7 +640,7 @@ def attr___self__(self): return self._instance.bound -class GeneratorModel(FunctionModel): +class GeneratorModel(FunctionModel, ContextManagerModel): def __new__(cls, *args, **kwargs): # Append the values from the GeneratorType unto this object. ret = super().__new__(cls, *args, **kwargs) diff --git a/tests/unittest_object_model.py b/tests/unittest_object_model.py index 3dbe5026b..9d412b786 100644 --- a/tests/unittest_object_model.py +++ b/tests/unittest_object_model.py @@ -571,6 +571,45 @@ def test(a: 1, b: 2, /, c: 3): pass self.assertEqual(annotations.getitem(astroid.Const("c")).value, 3) +class TestContextManagerModel: + def test_model(self) -> None: + """We use a generator to test this model.""" + ast_nodes = builder.extract_node( + """ + def test(): + "a" + yield + + gen = test() + gen.__enter__ #@ + gen.__exit__ #@ + """ + ) + assert isinstance(ast_nodes, list) + + enter = next(ast_nodes[0].infer()) + assert isinstance(enter, astroid.BoundMethod) + # Test that the method is correctly bound + assert isinstance(enter.bound, bases.Generator) + assert enter.bound._proxied.qname() == "builtins.generator" + # Test that thet FunctionDef accepts no arguments except self + # NOTE: This probably shouldn't be double proxied, but this is a + # quirck of the current model implementations. + assert isinstance(enter._proxied._proxied, nodes.FunctionDef) + assert len(enter._proxied._proxied.args.args) == 1 + assert enter._proxied._proxied.args.args[0].name == "self" + + exit_node = next(ast_nodes[1].infer()) + assert isinstance(exit_node, astroid.BoundMethod) + # Test that the FunctionDef accepts the arguments as defiend in the ObjectModel + assert isinstance(exit_node._proxied._proxied, nodes.FunctionDef) + assert len(exit_node._proxied._proxied.args.args) == 4 + assert exit_node._proxied._proxied.args.args[0].name == "self" + assert exit_node._proxied._proxied.args.args[1].name == "exc_type" + assert exit_node._proxied._proxied.args.args[2].name == "exc_value" + assert exit_node._proxied._proxied.args.args[3].name == "traceback" + + class GeneratorModelTest(unittest.TestCase): def test_model(self) -> None: ast_nodes = builder.extract_node( @@ -585,6 +624,8 @@ def test(): gen.gi_code #@ gen.gi_frame #@ gen.send #@ + gen.__enter__ #@ + gen.__exit__ #@ """ ) assert isinstance(ast_nodes, list) @@ -605,6 +646,12 @@ def test(): send = next(ast_nodes[4].infer()) self.assertIsInstance(send, astroid.BoundMethod) + enter = next(ast_nodes[5].infer()) + assert isinstance(enter, astroid.BoundMethod) + + exit_node = next(ast_nodes[6].infer()) + assert isinstance(exit_node, astroid.BoundMethod) + class ExceptionModelTest(unittest.TestCase): @staticmethod