diff --git a/mypyc/build.py b/mypyc/build.py
index f87a238299b9..439734e39b9e 100644
--- a/mypyc/build.py
+++ b/mypyc/build.py
@@ -215,9 +215,7 @@ def get_mypy_config(
         mypyc_sources = all_sources
 
     if compiler_options.separate:
-        mypyc_sources = [
-            src for src in mypyc_sources if src.path and not src.path.endswith("__init__.py")
-        ]
+        mypyc_sources = [src for src in mypyc_sources if src.path]
 
     if not mypyc_sources:
         return mypyc_sources, all_sources, options
@@ -243,6 +241,12 @@ def get_mypy_config(
     return mypyc_sources, all_sources, options
 
 
+def is_package_source(source: BuildSource) -> bool:
+    """Return True if the source's path names a package ``__init__.py`` file."""
+    path = source.path
+    return path is not None and os.path.basename(path) == "__init__.py"
+
+
 def generate_c_extension_shim(
     full_module_name: str, module_name: str, dir_name: str, group_name: str
 ) -> str:
@@ -388,7 +392,7 @@ def build_using_shared_lib(
         # since this seems to be needed for it to end up in the right place.
         full_module_name = source.module
         assert source.path
-        if os.path.split(source.path)[1] == "__init__.py":
+        if is_package_source(source):
             full_module_name += ".__init__"
         extensions.append(
             get_extension()(
@@ -534,6 +538,7 @@ def mypyc_build(
     use_shared_lib = (
         len(mypyc_sources) > 1
         or any("."
in x.module for x in mypyc_sources)
+        or any(is_package_source(x) for x in mypyc_sources)
         or always_use_shared_lib
     )
 
diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py
index 2111f1208609..563662476fbf 100644
--- a/mypyc/codegen/emitmodule.py
+++ b/mypyc/codegen/emitmodule.py
@@ -46,6 +46,7 @@
 )
 from mypyc.codegen.literals import Literals
 from mypyc.common import (
+    EXT_SUFFIX,
     IS_FREE_THREADED,
     MODULE_PREFIX,
     PREFIX,
@@ -1286,11 +1287,52 @@ def emit_module_init_func(
             f"if (unlikely({module_static} == NULL))",
             "    goto fail;",
         )
+
+        emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
+        emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
+        emitter.emit_line("int rv = 0;")
+        if self.group_name:
+            # Grouped build: look up the shared library module in sys.modules and
+            # use its __file__ as the base for this module's __file__.
+            shared_lib_mod_name = shared_lib_name(self.group_name)
+            emitter.emit_line("PyObject *mod_dict = PyImport_GetModuleDict();")
+            emitter.emit_line("PyObject *shared_lib = NULL;")
+            emitter.emit_line(
+                f'rv = PyDict_GetItemStringRef(mod_dict, "{shared_lib_mod_name}", &shared_lib);'
+            )
+            emitter.emit_line("if (rv < 0) goto fail;")
+            # PyDict_GetItemStringRef returns 0 and leaves shared_lib NULL, without
+            # setting an exception, when the key is missing; never dereference NULL.
+            emitter.emit_line(
+                "if (rv == 0) { PyErr_SetString(PyExc_ImportError, "
+                f'"{shared_lib_mod_name} is not in sys.modules"); goto fail; }}'
+            )
+            emitter.emit_line(
+                'PyObject *shared_lib_file = PyObject_GetAttrString(shared_lib, "__file__");'
+            )
+            # shared_lib is a strong reference; release it once __file__ has been read.
+            emitter.emit_line("Py_DECREF(shared_lib);")
+            emitter.emit_line("if (shared_lib_file == NULL) goto fail;")
+        else:
+            emitter.emit_line(
+                f'PyObject *shared_lib_file = PyUnicode_FromString("{module_name + EXT_SUFFIX}");'
+            )
+            emitter.emit_line("if (shared_lib_file == NULL) CPyError_OutOfMemory();")
+        emitter.emit_line(f'PyObject *ext_suffix = PyUnicode_FromString("{EXT_SUFFIX}");')
+        emitter.emit_line("if (ext_suffix == NULL) CPyError_OutOfMemory();")
+        is_pkg = int(self.source_paths[module_name].endswith("__init__.py"))
+        emitter.emit_line(f"Py_ssize_t is_pkg = {is_pkg};")
+
+        emitter.emit_line(
+            f"rv = CPyImport_SetDunderAttrs({module_static}, modname, shared_lib_file, ext_suffix, is_pkg);"
+        )
+        emitter.emit_line("Py_DECREF(ext_suffix);")
+        emitter.emit_line("Py_DECREF(shared_lib_file);")
+        emitter.emit_line("if (rv < 0) goto fail;")
+
         # Register in sys.modules early so that circular imports via
         # CPyImport_ImportNative can detect that this module is already
         # being initialized and avoid re-executing the module body.
-        emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
-        emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
         emitter.emit_line(
             f"if (PyObject_SetItem(PyImport_GetModuleDict(), modname, {module_static}) < 0)"
         )
diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h
index 3a7a08a5dc6a..89ef4d0749a4 100644
--- a/mypyc/lib-rt/CPy.h
+++ b/mypyc/lib-rt/CPy.h
@@ -967,6 +967,8 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
                                  CPyModule **module_static,
                                  PyObject *shared_lib_file, PyObject *ext_suffix,
                                  Py_ssize_t is_package);
+int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
+                             PyObject *ext_suffix, Py_ssize_t is_package);
 
 PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
                                              PyObject *func);
diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c
index a13243fc40d6..6f4843132537 100644
--- a/mypyc/lib-rt/misc_ops.c
+++ b/mypyc/lib-rt/misc_ops.c
@@ -1225,6 +1225,47 @@ static int CPyImport_InitSpecClasses(void) {
     return 0;
 }
 
+// Set __package__ before executing the module body so it is available
+// during module initialization. For a package, __package__ is the module
+// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
+// top-level non-package module, it is "".
+static int CPyImport_SetModulePackage(PyObject *modobj, PyObject *module_name,
+                                      Py_ssize_t is_package) {
+    PyObject *pkg = NULL;
+    int rc = PyObject_GetOptionalAttrString(modobj, "__package__", &pkg);
+    if (rc < 0) {
+        return -1;
+    }
+    if (pkg != NULL && pkg != Py_None) {
+        Py_DECREF(pkg);
+        return 0;
+    }
+    Py_XDECREF(pkg);
+
+    PyObject *package_name = NULL;
+    if (is_package) {
+        package_name = module_name;
+        Py_INCREF(package_name);
+    } else {
+        Py_ssize_t name_len = PyUnicode_GetLength(module_name);
+        if (name_len < 0) {
+            return -1;
+        }
+        Py_ssize_t dot = PyUnicode_FindChar(module_name, '.', 0, name_len, -1);
+        if (dot >= 0) {
+            package_name = PyUnicode_Substring(module_name, 0, dot);
+        } else if (dot == -1) {  // -2 is an error: keep NULL so we return -1 below
+            package_name = PyUnicode_FromString("");
+        }
+    }
+    if (package_name == NULL) {
+        return -1;
+    }
+    rc = PyObject_SetAttrString(modobj, "__package__", package_name);
+    Py_DECREF(package_name);
+    return rc;
+}
+
 // Derive and set __file__ on modobj from the shared library path, module name,
 // and extension suffix. Returns 0 on success, -1 on error.
 static int CPyImport_SetModuleFile(PyObject *modobj, PyObject *module_name,
@@ -1509,47 +1550,7 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
         goto fail;
     }
 
-    // Set __package__ before executing the module body so it is available
-    // during module initialization. For a package, __package__ is the module
-    // name itself. For a non-package submodule "a.b.c", it is "a.b". For a
-    // top-level non-package module, it is "".
-    {
-        PyObject *pkg = NULL;
-        if (PyObject_GetOptionalAttrString(modobj, "__package__", &pkg) < 0) {
-            goto fail;
-        }
-        if (pkg == NULL || pkg == Py_None) {
-            Py_XDECREF(pkg);
-            PyObject *package_name;
-            if (is_package) {
-                package_name = module_name;
-                Py_INCREF(package_name);
-            } else if (dot >= 0) {
-                package_name = PyUnicode_Substring(module_name, 0, dot);
-            } else {
-                package_name = PyUnicode_FromString("");
-                if (package_name == NULL) {
-                    CPyError_OutOfMemory();
-                }
-            }
-            if (PyObject_SetAttrString(modobj, "__package__", package_name) < 0) {
-                Py_DECREF(package_name);
-                goto fail;
-            }
-            Py_DECREF(package_name);
-        } else {
-            Py_DECREF(pkg);
-        }
-    }
-
-    if (CPyImport_SetModuleFile(modobj, module_name, shared_lib_file, ext_suffix,
-                                is_package) < 0) {
-        goto fail;
-    }
-    if (is_package && CPyImport_SetModulePath(modobj) < 0) {
-        goto fail;
-    }
-    if (CPyImport_SetModuleSpec(modobj, module_name, is_package) < 0) {
+    if (CPyImport_SetDunderAttrs(modobj, module_name, shared_lib_file, ext_suffix, is_package) < 0) {
         goto fail;
     }
 
@@ -1577,10 +1578,31 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
     PyErr_Restore(exc_type, exc_val, exc_tb);
     Py_XDECREF(parent_module);
     Py_XDECREF(child_name);
-    Py_DECREF(modobj);
+    Py_CLEAR(*module_static);
     return NULL;
 }
 
+// Set the import-related dunder attributes on a native module by calling, in
+// order, CPyImport_SetModulePackage, CPyImport_SetModuleFile,
+// CPyImport_SetModulePath (packages only) and CPyImport_SetModuleSpec.
+// Stops at the first helper that reports a negative status and returns that
+// status; otherwise returns the result of CPyImport_SetModuleSpec.
+int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
+                             PyObject *ext_suffix, Py_ssize_t is_package) {
+    int status = CPyImport_SetModulePackage(module, module_name, is_package);
+    if (status >= 0) {
+        status = CPyImport_SetModuleFile(module, module_name, shared_lib_file, ext_suffix,
+                                         is_package);
+    }
+    if (status >= 0 && is_package) {
+        status = CPyImport_SetModulePath(module);
+    }
+    if (status >= 0) {
+        status = CPyImport_SetModuleSpec(module, module_name, is_package);
+    }
+    return status;
+}
+
 #if CPY_3_14_FEATURES
 
 #include "internal/pycore_object.h"
diff --git a/mypyc/test-data/commandline.test b/mypyc/test-data/commandline.test
index
892666910073..f89accd5b816 100644
--- a/mypyc/test-data/commandline.test
+++ b/mypyc/test-data/commandline.test
@@ -313,6 +313,20 @@ print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__module__)
 B
 pkg2.mod2
 
+[case testCompilePackageOnlyInitPy]
+# cmd: pkg/__init__.py
+import os.path
+import pkg
+
+print(pkg.x)
+assert os.path.splitext(pkg.__file__)[1] != ".py"
+
+[file pkg/__init__.py]
+x: int = 1
+
+[out]
+1
+
 [case testStrictBytesRequired]
 # cmd: --no-strict-bytes a.py
 
diff --git a/mypyc/test-data/run-multimodule.test b/mypyc/test-data/run-multimodule.test
index 8d9505a67c52..2a29d7257009 100644
--- a/mypyc/test-data/run-multimodule.test
+++ b/mypyc/test-data/run-multimodule.test
@@ -473,6 +473,38 @@ globals()['A'] = None
 [file driver.py]
 import other_main
 
+[case testNonNativeImportInPackageFile]
+# The import is non-native only in separate compilation mode, where __init__.py and
+# other_cache.py are in different libraries and the import uses the standard Python
+# procedure. Python imports are resolved using __path__ and __spec__ from the package
+# file, so this checks that they are set up correctly.
+[file other/__init__.py]
+from other.other_cache import Cache
+
+x = 1
+[file other/other_cache.py]
+class Cache:
+    pass
+
+[file driver.py]
+import other
+
+[case testRelativeImportInPackageFile]
+# Relative imports from a compiled package __init__ depend on package metadata being
+# available while the package module body is executing.
+[file other/__init__.py]
+assert __package__ == "other"
+from .other_cache import Cache
+
+x = 1
+[file other/other_cache.py]
+class Cache:
+    pass
+
+[file driver.py]
+import other
+assert other.Cache.__name__ == "Cache"
+
 [case testMultiModuleSameNames]
 # Use same names in both modules
 import other