Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,7 @@ def get_mypy_config(
mypyc_sources = all_sources

if compiler_options.separate:
mypyc_sources = [
src for src in mypyc_sources if src.path and not src.path.endswith("__init__.py")
]
mypyc_sources = [src for src in mypyc_sources if src.path]

if not mypyc_sources:
return mypyc_sources, all_sources, options
Expand All @@ -243,6 +241,10 @@ def get_mypy_config(
return mypyc_sources, all_sources, options


def is_package_source(source: BuildSource) -> bool:
return source.path is not None and os.path.split(source.path)[1] == "__init__.py"


def generate_c_extension_shim(
full_module_name: str, module_name: str, dir_name: str, group_name: str
) -> str:
Expand Down Expand Up @@ -388,7 +390,7 @@ def build_using_shared_lib(
# since this seems to be needed for it to end up in the right place.
full_module_name = source.module
assert source.path
if os.path.split(source.path)[1] == "__init__.py":
if is_package_source(source):
full_module_name += ".__init__"
extensions.append(
get_extension()(
Expand Down Expand Up @@ -534,6 +536,7 @@ def mypyc_build(
use_shared_lib = (
len(mypyc_sources) > 1
or any("." in x.module for x in mypyc_sources)
or any(is_package_source(x) for x in mypyc_sources)
or always_use_shared_lib
)

Expand Down
36 changes: 34 additions & 2 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from mypyc.codegen.literals import Literals
from mypyc.common import (
EXT_SUFFIX,
IS_FREE_THREADED,
MODULE_PREFIX,
PREFIX,
Expand Down Expand Up @@ -1286,11 +1287,42 @@ def emit_module_init_func(
f"if (unlikely({module_static} == NULL))",
" goto fail;",
)

emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
emitter.emit_line("int rv = 0;")
if self.group_name:
shared_lib_mod_name = shared_lib_name(self.group_name)
emitter.emit_line("PyObject *mod_dict = PyImport_GetModuleDict();")
emitter.emit_line("PyObject *shared_lib = NULL;")
emitter.emit_line(
f'rv = PyDict_GetItemStringRef(mod_dict, "{shared_lib_mod_name}", &shared_lib);'
)
emitter.emit_line("if (rv < 0) goto fail;")
emitter.emit_line(
'PyObject *shared_lib_file = PyObject_GetAttrString(shared_lib, "__file__");'
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use PyDict_GetItemStringRef instead, as this returns a borrowed reference, and we decref it below.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might be thinking of PyDict_GetItemString? this one returns a new reference according to docs https://docs.python.org/3/c-api/object.html#c.PyObject_GetAttrString

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was confusing the two.

)
emitter.emit_line("if (shared_lib_file == NULL) goto fail;")
else:
emitter.emit_line(
f'PyObject *shared_lib_file = PyUnicode_FromString("{module_name + EXT_SUFFIX}");'
Comment thread
p-sawicki marked this conversation as resolved.
)
emitter.emit_line("if (shared_lib_file == NULL) CPyError_OutOfMemory();")
emitter.emit_line(f'PyObject *ext_suffix = PyUnicode_FromString("{EXT_SUFFIX}");')
emitter.emit_line("if (ext_suffix == NULL) CPyError_OutOfMemory();")
is_pkg = int(self.source_paths[module_name].endswith("__init__.py"))
emitter.emit_line(f"Py_ssize_t is_pkg = {is_pkg};")

emitter.emit_line(
f"rv = CPyImport_SetDunderAttrs({module_static}, modname, shared_lib_file, ext_suffix, is_pkg);"
)
emitter.emit_line("Py_DECREF(ext_suffix);")
emitter.emit_line("Py_DECREF(shared_lib_file);")
emitter.emit_line("if (rv < 0) goto fail;")

# Register in sys.modules early so that circular imports via
# CPyImport_ImportNative can detect that this module is already
# being initialized and avoid re-executing the module body.
emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
emitter.emit_line(
f"if (PyObject_SetItem(PyImport_GetModuleDict(), modname, {module_static}) < 0)"
)
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,8 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
CPyModule **module_static,
PyObject *shared_lib_file, PyObject *ext_suffix,
Py_ssize_t is_package);
int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
PyObject *ext_suffix, Py_ssize_t is_package);

PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
PyObject *func);
Expand Down
109 changes: 67 additions & 42 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,47 @@ static int CPyImport_InitSpecClasses(void) {
return 0;
}

// Set __package__ before executing the module body so it is available
// during module initialization. For a package, __package__ is the module
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
// top-level non-package module, it is "".
static int CPyImport_SetModulePackage(PyObject *modobj, PyObject *module_name,
Py_ssize_t is_package) {
PyObject *pkg = NULL;
int rc = PyObject_GetOptionalAttrString(modobj, "__package__", &pkg);
if (rc < 0) {
return -1;
}
if (pkg != NULL && pkg != Py_None) {
Py_DECREF(pkg);
return 0;
}
Py_XDECREF(pkg);

PyObject *package_name = NULL;
if (is_package) {
package_name = module_name;
Py_INCREF(package_name);
} else {
Py_ssize_t name_len = PyUnicode_GetLength(module_name);
if (name_len < 0) {
return -1;
}
Py_ssize_t dot = PyUnicode_FindChar(module_name, '.', 0, name_len, -1);
if (dot >= 0) {
package_name = PyUnicode_Substring(module_name, 0, dot);
} else {
package_name = PyUnicode_FromString("");
}
}
if (package_name == NULL) {
return -1;
}
rc = PyObject_SetAttrString(modobj, "__package__", package_name);
Py_DECREF(package_name);
return rc;
}

// Derive and set __file__ on modobj from the shared library path, module name,
// and extension suffix. Returns 0 on success, -1 on error.
static int CPyImport_SetModuleFile(PyObject *modobj, PyObject *module_name,
Expand Down Expand Up @@ -1509,47 +1550,7 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
goto fail;
}

// Set __package__ before executing the module body so it is available
// during module initialization. For a package, __package__ is the module
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
// top-level non-package module, it is "".
{
PyObject *pkg = NULL;
if (PyObject_GetOptionalAttrString(modobj, "__package__", &pkg) < 0) {
goto fail;
}
if (pkg == NULL || pkg == Py_None) {
Py_XDECREF(pkg);
PyObject *package_name;
if (is_package) {
package_name = module_name;
Py_INCREF(package_name);
} else if (dot >= 0) {
package_name = PyUnicode_Substring(module_name, 0, dot);
} else {
package_name = PyUnicode_FromString("");
if (package_name == NULL) {
CPyError_OutOfMemory();
}
}
if (PyObject_SetAttrString(modobj, "__package__", package_name) < 0) {
Py_DECREF(package_name);
goto fail;
}
Py_DECREF(package_name);
} else {
Py_DECREF(pkg);
}
}

if (CPyImport_SetModuleFile(modobj, module_name, shared_lib_file, ext_suffix,
is_package) < 0) {
goto fail;
}
if (is_package && CPyImport_SetModulePath(modobj) < 0) {
goto fail;
}
if (CPyImport_SetModuleSpec(modobj, module_name, is_package) < 0) {
if (CPyImport_SetDunderAttrs(modobj, module_name, shared_lib_file, ext_suffix, is_package) < 0) {
goto fail;
}

Expand Down Expand Up @@ -1577,10 +1578,34 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
PyErr_Restore(exc_type, exc_val, exc_tb);
Py_XDECREF(parent_module);
Py_XDECREF(child_name);
Py_DECREF(modobj);
Py_CLEAR(*module_static);
return NULL;
}

int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
PyObject *ext_suffix, Py_ssize_t is_package)
{
int res = CPyImport_SetModulePackage(module, module_name, is_package);
if (res < 0) {
return res;
}

res = CPyImport_SetModuleFile(module, module_name, shared_lib_file, ext_suffix,
is_package);
if (res < 0) {
return res;
}

if (is_package) {
res = CPyImport_SetModulePath(module);
if (res < 0) {
return res;
}
}

return CPyImport_SetModuleSpec(module, module_name, is_package);
}

#if CPY_3_14_FEATURES

#include "internal/pycore_object.h"
Expand Down
14 changes: 14 additions & 0 deletions mypyc/test-data/commandline.test
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,20 @@ print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__module__)
B
pkg2.mod2

[case testCompilePackageOnlyInitPy]
# cmd: pkg/__init__.py
import os.path
import pkg

print(pkg.x)
assert os.path.splitext(pkg.__file__)[1] != ".py"

[file pkg/__init__.py]
x: int = 1

[out]
1

[case testStrictBytesRequired]
# cmd: --no-strict-bytes a.py

Expand Down
32 changes: 32 additions & 0 deletions mypyc/test-data/run-multimodule.test
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,38 @@ globals()['A'] = None
[file driver.py]
import other_main

[case testNonNativeImportInPackageFile]
# The import is really non-native only in separate compilation mode where __init__.py and
# other_cache.py are in different libraries and the import uses the standard Python procedure.
# Python imports are resolved using __path__ and __spec__ from the package file so this checks
# that they are set up correctly.
[file other/__init__.py]
from other.other_cache import Cache

x = 1
[file other/other_cache.py]
class Cache:
pass

[file driver.py]
import other

[case testRelativeImportInPackageFile]
# Relative imports from a compiled package __init__ depend on package metadata being
# available while the package module body is executing.
[file other/__init__.py]
assert __package__ == "other"
from .other_cache import Cache

x = 1
[file other/other_cache.py]
class Cache:
pass

[file driver.py]
import other
assert other.Cache.__name__ == "Cache"

[case testMultiModuleSameNames]
# Use same names in both modules
import other
Expand Down
Loading