diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index 021165c5c..afb229538 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -302,6 +302,14 @@ def _init(cls, module, code_type, *, symbol_mapping: Optional[dict] = None): return self + @classmethod + def _reduce_helper(self, module, code_type, symbol_mapping): + # just for forwarding kwargs + return ObjectCode._init(module, code_type, symbol_mapping=symbol_mapping) + + def __reduce__(self): + return ObjectCode._reduce_helper, (self._module, self._code_type, self._sym_map) + @staticmethod def from_cubin(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode": """Create an :class:`ObjectCode` instance from an existing cubin. diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index cc8620906..d85a4745e 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import ctypes +import pickle # nosec B403, B301 import warnings import pytest @@ -245,3 +246,13 @@ def test_num_args_error_handling(deinit_all_contexts_function, cuda12_prerequisi with pytest.raises(CUDAError): # assignment resolves linter error "B018: useless expression" _ = krn.num_arguments + + +def test_module_serialization_roundtrip(get_saxpy_kernel): + _, objcode = get_saxpy_kernel + result = pickle.loads(pickle.dumps(objcode)) # nosec B403, B301 + + assert isinstance(result, ObjectCode) + assert objcode.code == result.code + assert objcode._sym_map == result._sym_map + assert objcode._code_type == result._code_type