Skip to content

Commit

Permalink
[mypyc] Fix using package imported inside a function (#9782)
Browse files Browse the repository at this point in the history
This fixes an issue where this code resulted in an unbound local `p`
error:

```
def f() -> None:
    import p.submodule
    print(p.x)  # Runtime error here
```

We now look up `p` from the global modules dictionary instead
of trying to use an undefined local variable.
  • Loading branch information
JukkaL committed Dec 8, 2020
1 parent 98eee40 commit eac1897
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 11 deletions.
21 changes: 17 additions & 4 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD
from mypyc.primitives.registry import CFunctionDescription, builtin_names
from mypyc.primitives.generic_ops import iter_op
from mypyc.primitives.misc_ops import new_slice_op, ellipsis_op, type_op
from mypyc.primitives.misc_ops import new_slice_op, ellipsis_op, type_op, get_module_dict_op
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op, dict_get_item_op
from mypyc.primitives.set_ops import new_set_op, set_add_op, set_update_op
from mypyc.primitives.str_ops import str_slice_op
from mypyc.primitives.int_ops import int_comparison_op_mapping
Expand Down Expand Up @@ -85,8 +85,21 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
expr.node.name),
expr.node.line)

# TODO: Behavior currently only defined for Var and FuncDef node types.
return builder.read(builder.get_assignment_target(expr), expr.line)
# TODO: Behavior currently only defined for Var, FuncDef and MypyFile node types.
if isinstance(expr.node, MypyFile):
# Load reference to a module imported inside function from
# the modules dictionary. It would be closer to Python
# semantics to access modules imported inside functions
# via local variables, but this is tricky since the mypy
# AST doesn't include a Var node for the module. We
# instead load the module separately on each access.
mod_dict = builder.call_c(get_module_dict_op, [], expr.line)
obj = builder.call_c(dict_get_item_op,
[mod_dict, builder.load_static_unicode(expr.node.fullname)],
expr.line)
return obj
else:
return builder.read(builder.get_assignment_target(expr), expr.line)

return builder.load_global(expr)

Expand Down
4 changes: 4 additions & 0 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def transform_import(builder: IRBuilder, node: Import) -> None:
# that mypy couldn't find, since it doesn't analyze module references
# from those properly.

# TODO: Don't add local imports to the global namespace

# Miscompiling imports inside of functions, like below in import from.
if as_name:
name = as_name
Expand All @@ -140,8 +142,10 @@ def transform_import(builder: IRBuilder, node: Import) -> None:

# Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :(
mod_dict = builder.call_c(get_module_dict_op, [], node.line)
# Get top-level module/package object.
obj = builder.call_c(dict_get_item_op,
[mod_dict, builder.load_static_unicode(base)], node.line)

builder.gen_method_call(
globals, '__setitem__', [builder.load_static_unicode(name), obj],
result_type=None, line=node.line)
Expand Down
51 changes: 51 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3637,3 +3637,54 @@ L0:
c = r2
r3 = (c, b, a)
return r3

[case testLocalImportSubmodule]
def f() -> int:
import p.m
return p.x
[file p/__init__.py]
x = 1
[file p/m.py]
[out]
def f():
r0 :: dict
r1, r2 :: object
r3 :: bit
r4 :: str
r5 :: object
r6 :: dict
r7 :: str
r8 :: object
r9 :: str
r10 :: int32
r11 :: bit
r12 :: dict
r13 :: str
r14 :: object
r15 :: str
r16 :: object
r17 :: int
L0:
r0 = __main__.globals :: static
r1 = p.m :: module
r2 = load_address _Py_NoneStruct
r3 = r1 != r2
if r3 goto L2 else goto L1 :: bool
L1:
r4 = load_global CPyStatic_unicode_1 :: static ('p.m')
r5 = PyImport_Import(r4)
p.m = r5 :: module
L2:
r6 = PyImport_GetModuleDict()
r7 = load_global CPyStatic_unicode_2 :: static ('p')
r8 = CPyDict_GetItem(r6, r7)
r9 = load_global CPyStatic_unicode_2 :: static ('p')
r10 = CPyDict_SetItem(r0, r9, r8)
r11 = r10 >= 0 :: signed
r12 = PyImport_GetModuleDict()
r13 = load_global CPyStatic_unicode_2 :: static ('p')
r14 = CPyDict_GetItem(r12, r13)
r15 = load_global CPyStatic_unicode_3 :: static ('x')
r16 = CPyObject_GetAttr(r14, r15)
r17 = unbox(int, r16)
return r17
49 changes: 42 additions & 7 deletions mypyc/test-data/run-imports.test
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,45 @@ import testmodule

def f(x: int) -> int:
return testmodule.factorial(5)

def g(x: int) -> int:
from welp import foo
return foo(x)

def test_import_basics() -> None:
assert f(5) == 120
assert g(5) == 5

def test_import_submodule_within_function() -> None:
import pkg.mod
assert pkg.x == 1
assert pkg.mod.y == 2

def test_import_as_submodule_within_function() -> None:
import pkg.mod as mm
assert mm.y == 2

# TODO: Don't add local imports to globals()
#
# def test_local_import_not_in_globals() -> None:
# import nob
# assert 'nob' not in globals()

def test_import_module_without_stub_in_function() -> None:
# 'virtualenv' must not have a stub in typeshed for this test case
import virtualenv # type: ignore
# TODO: We shouldn't add local imports to globals()
# assert 'virtualenv' not in globals()
assert isinstance(virtualenv.__name__, str)

def test_import_as_module_without_stub_in_function() -> None:
# 'virtualenv' must not have a stub in typeshed for this test case
import virtualenv as vv # type: ignore
assert 'virtualenv' not in globals()
# TODO: We shouldn't add local imports to globals()
# assert 'vv' not in globals()
assert isinstance(vv.__name__, str)

[file testmodule.py]
def factorial(x: int) -> int:
if x == 0:
Expand All @@ -17,13 +53,12 @@ def factorial(x: int) -> int:
[file welp.py]
def foo(x: int) -> int:
return x
[file driver.py]
from native import f, g
print(f(5))
print(g(5))
[out]
120
5
[file pkg/__init__.py]
x = 1
[file pkg/mod.py]
y = 2
[file nob.py]
z = 3

[case testImportMissing]
# The unchecked module is configured by the test harness to not be
Expand Down

0 comments on commit eac1897

Please sign in to comment.