Clean up cffi resources in file #679

Merged
135 changes: 135 additions & 0 deletions cuda_core/examples/strided_memory_view_cpu.py
@@ -0,0 +1,135 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

# ################################################################################
#
# This demo aims to illustrate two takeaways:
#
# 1. The similarity between CPU and GPU JIT-compilation with C++ sources
# 2. How to use StridedMemoryView to interface with foreign C/C++ functions
#
# To facilitate this demo, we use cffi (https://cffi.readthedocs.io/) for the CPU
# path; it can be easily installed with pip or conda following its instructions.
# We also use NumPy/CuPy as the CPU/GPU array container.
#
# ################################################################################

import importlib
import shutil
import string
import sys
import tempfile

try:
    from cffi import FFI
except ImportError:
    print("cffi is not installed, the CPU example will be skipped", file=sys.stderr)
    FFI = None
import numpy as np

from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory

# ################################################################################
#
# Usually this entire code block is in a separate file, built as a Python extension
# module that can be imported by users at run time. For illustrative purposes we
# use JIT compilation to make this demo self-contained.
#
# Here we assume an in-place operation, equivalent to the following NumPy code:
#
# >>> arr = ...
# >>> assert arr.dtype == np.int32
# >>> assert arr.ndim == 1
# >>> arr += np.arange(arr.size, dtype=arr.dtype)
#
# is implemented at a low level for both CPU and GPU, with the following C
# function signature:
func_name = "inplace_plus_arange_N"
func_sig = f"void {func_name}(int* data, size_t N)"


# Now we are prepared to run the code from the user's perspective!
#
# ################################################################################


# Below, as a user we want to perform the in-place operation described above on a
# CPU array, by calling the corresponding function implemented "elsewhere"
# (here, in the body of the run() function).


# We assume the 0-th argument supports either DLPack or CUDA Array Interface (both
# of which are supported by StridedMemoryView).
@args_viewable_as_strided_memory((0,))
def my_func(arr):
    global cpu_func
    global cpu_prog
    # Create a memory view over arr (assumed to be a 1D array of int32). The data
    # lives in host memory, so there is no stream ordering to take care of and we
    # pass -1 as the stream when creating the view.
    view = arr.view(-1)
    assert isinstance(view, StridedMemoryView)
    assert len(view.shape) == 1
    assert view.dtype == np.int32
    assert not view.is_device_accessible

    size = view.shape[0]
    # DLPack also supports host arrays. Here the data is only host-accessible,
    # so we call the CPU routine directly.
    cpu_func(cpu_prog.cast("int*", view.ptr), size)


def run():
    global my_func
    global cpu_func, cpu_prog  # referenced as globals by my_func above
    if not FFI:
        return
    # Here is a concrete (very naive!) implementation on CPU:
    cpu_code = string.Template(r"""
    extern "C"
    $func_sig {
        for (size_t i = 0; i < N; i++) {
            data[i] += i;
        }
    }
    """).substitute(func_sig=func_sig)
    # This is cffi's way of JIT compiling & loading a CPU function. cffi builds an
    # extension module that has the Python binding to the underlying C function.
    # For more details, please refer to cffi's documentation.
    cpu_prog = FFI()
    cpu_prog.cdef(f"{func_sig};")
    cpu_prog.set_source(
        "_cpu_obj",
        cpu_code,
        source_extension=".cpp",
        extra_compile_args=["-std=c++11"],
    )
    temp_dir = tempfile.mkdtemp()
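    # Building into a temporary directory keeps the generated cffi artifacts
    # (the _cpu_obj sources and the compiled extension module) out of the current
    # working directory so they can be removed once the demo finishes.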
    saved_sys_path = sys.path.copy()
    try:
        cpu_prog.compile(tmpdir=temp_dir)

        sys.path.append(temp_dir)
        cpu_func = getattr(importlib.import_module("_cpu_obj.lib"), func_name)

        # Create input array on CPU
        arr_cpu = np.zeros(1024, dtype=np.int32)
        print(f"before: {arr_cpu[:10]=}")

        # Run the workload
        my_func(arr_cpu)

        # Check the result
        print(f"after: {arr_cpu[:10]=}")
        assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32))
    finally:
        sys.path = saved_sys_path
        # To allow the FFI extension module to unload, we delete the references
        # to cpu_func (and my_func, which holds on to it)
        del cpu_func, my_func
        # Clean up the temp directory
        shutil.rmtree(temp_dir)


if __name__ == "__main__":
run()
@@ -15,15 +15,9 @@
#
# ################################################################################

import importlib
import string
import sys

try:
    from cffi import FFI
except ImportError:
    print("cffi is not installed, the CPU example will be skipped", file=sys.stderr)
    FFI = None
try:
    import cupy as cp
except ImportError:
@@ -52,51 +46,6 @@
func_name = "inplace_plus_arange_N"
func_sig = f"void {func_name}(int* data, size_t N)"

# Here is a concrete (very naive!) implementation on CPU:
if FFI:
    cpu_code = string.Template(r"""
    extern "C"
    $func_sig {
        for (size_t i = 0; i < N; i++) {
            data[i] += i;
        }
    }
    """).substitute(func_sig=func_sig)
    # This is cffi's way of JIT compiling & loading a CPU function. cffi builds an
    # extension module that has the Python binding to the underlying C function.
    # For more details, please refer to cffi's documentation.
    cpu_prog = FFI()
    cpu_prog.cdef(f"{func_sig};")
    cpu_prog.set_source(
        "_cpu_obj",
        cpu_code,
        source_extension=".cpp",
        extra_compile_args=["-std=c++11"],
    )
    cpu_prog.compile()
    cpu_func = getattr(importlib.import_module("_cpu_obj.lib"), func_name)

# Here is a concrete (again, very naive!) implementation on GPU:
if cp:
    gpu_code = string.Template(r"""
    extern "C"
    __global__ $func_sig {
        const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
        const size_t stride_size = gridDim.x * blockDim.x;
        for (size_t i = tid; i < N; i += stride_size) {
            data[i] += i;
        }
    }
    """).substitute(func_sig=func_sig)

    # To know the GPU's compute capability, we need to identify which GPU to use.
    dev = Device(0)
    dev.set_current()
    arch = "".join(f"{i}" for i in dev.compute_capability)
    gpu_prog = Program(gpu_code, code_type="c++", options=ProgramOptions(arch=f"sm_{arch}", std="c++11"))
    mod = gpu_prog.compile(target_type="cubin")
    gpu_ker = mod.get_kernel(func_name)

# Now we are prepared to run the code from the user's perspective!
#
# ################################################################################
@@ -109,60 +58,72 @@
# We assume the 0-th argument supports either DLPack or CUDA Array Interface (both
# of which are supported by StridedMemoryView).
@args_viewable_as_strided_memory((0,))
def my_func(arr, work_stream):
def my_func(arr, work_stream, gpu_ker):
    # Create a memory view over arr (assumed to be a 1D array of int32). The stream
    # ordering is taken care of, so that arr can be safely accessed on our work
    # stream (ordered after a data stream on which arr is potentially prepared).
    view = arr.view(work_stream.handle if work_stream else -1)
    assert isinstance(view, StridedMemoryView)
    assert len(view.shape) == 1
    assert view.dtype == np.int32
    assert view.is_device_accessible

    size = view.shape[0]
    # DLPack also supports host arrays. We want to know if the array data is
    # accessible from the GPU, and dispatch to the right routine accordingly.
    if view.is_device_accessible:
        block = 256
        grid = (size + block - 1) // block
        config = LaunchConfig(grid=grid, block=block)
        launch(work_stream, config, gpu_ker, view.ptr, np.uint64(size))
        # Here we're being conservative and synchronize over our work stream,
        # assuming we do not know the data stream; if we know then we could
        # just order the data stream after the work stream here, e.g.
        #
        #   data_stream.wait(work_stream)
        #
        # without an expensive synchronization (with respect to the host).
        work_stream.sync()
    else:
        cpu_func(cpu_prog.cast("int*", view.ptr), size)


# This takes the CPU path
if FFI:
    # Create input array on CPU
    arr_cpu = np.zeros(1024, dtype=np.int32)
    print(f"before: {arr_cpu[:10]=}")

    # Run the workload
    my_func(arr_cpu, None)

    # Check the result
    print(f"after: {arr_cpu[:10]=}")
    assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32))


# This takes the GPU path
if cp:
    block = 256
    grid = (size + block - 1) // block
    config = LaunchConfig(grid=grid, block=block)
    launch(work_stream, config, gpu_ker, view.ptr, np.uint64(size))
    # Here we're being conservative and synchronize over our work stream,
    # assuming we do not know the data stream; if we know then we could
    # just order the data stream after the work stream here, e.g.
    #
    #   data_stream.wait(work_stream)
    #
    # without an expensive synchronization (with respect to the host).
    work_stream.sync()


def run():
    global my_func
    if not cp:
        return None
    # Here is a concrete (very naive!) implementation on GPU:
    gpu_code = string.Template(r"""
    extern "C"
    __global__ $func_sig {
        const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
        const size_t stride_size = gridDim.x * blockDim.x;
        for (size_t i = tid; i < N; i += stride_size) {
            data[i] += i;
        }
    }
    """).substitute(func_sig=func_sig)

    # To know the GPU's compute capability, we need to identify which GPU to use.
    dev = Device(0)
    dev.set_current()
    arch = "".join(f"{i}" for i in dev.compute_capability)
    gpu_prog = Program(gpu_code, code_type="c++", options=ProgramOptions(arch=f"sm_{arch}", std="c++11"))
    mod = gpu_prog.compile(target_type="cubin")
    gpu_ker = mod.get_kernel(func_name)

    s = dev.create_stream()
    # Create input array on GPU
    arr_gpu = cp.ones(1024, dtype=cp.int32)
    print(f"before: {arr_gpu[:10]=}")
    try:
        # Create input array on GPU
        arr_gpu = cp.ones(1024, dtype=cp.int32)
        print(f"before: {arr_gpu[:10]=}")

        # Run the workload
        my_func(arr_gpu, s, gpu_ker)

        # Check the result
        print(f"after: {arr_gpu[:10]=}")
        assert cp.allclose(arr_gpu, 1 + cp.arange(1024, dtype=cp.int32))
    finally:
        s.close()

    # Run the workload
    my_func(arr_gpu, s)

    # Check the result
    print(f"after: {arr_gpu[:10]=}")
    assert cp.allclose(arr_gpu, 1 + cp.arange(1024, dtype=cp.int32))
    s.close()

if __name__ == "__main__":
    run()
17 changes: 0 additions & 17 deletions cuda_core/tests/conftest.py
@@ -1,9 +1,7 @@
# Copyright 2024 NVIDIA Corporation. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import glob
import os
import sys

try:
    from cuda.bindings import driver
@@ -67,21 +65,6 @@ def pop_all_contexts():
    return pop_all_contexts


# samples relying on cffi could fail as the modules cannot be imported
sys.path.append(os.getcwd())


@pytest.fixture(scope="session", autouse=True)
def clean_up_cffi_files():
    yield
    files = glob.glob(os.path.join(os.getcwd(), "_cpu_obj*"))
    for f in files:
        try:  # noqa: SIM105
            os.remove(f)
        except FileNotFoundError:
            pass  # noqa: SIM105


skipif_testing_with_compute_sanitizer = pytest.mark.skipif(
    os.environ.get("CUDA_PYTHON_TESTING_WITH_COMPUTE_SANITIZER", "0") == "1",
    reason="The compute-sanitizer is running, and this test causes an API error.",