Skip to content

Commit

Permalink
feat(compile): add python bindings for functions to check whether the…
Browse files Browse the repository at this point in the history
… compiler is GPU enabled and whether a GPU is available on the system.
  • Loading branch information
antoniupop committed Jun 21, 2024
1 parent c6c1e99 commit 5b2dd07
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,14 @@ void initDataflowParallelization() {
mlir::concretelang::dfr::_dfr_set_required(true);
}

bool checkGPURuntimeEnabled() {
return mlir::concretelang::dfr::check_cuda_runtime_enabled();
}

bool checkCudaDeviceAvailable() {
return mlir::concretelang::dfr::check_cuda_device_available();
}

std::string roundTrip(const char *module) {
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
Expand Down Expand Up @@ -673,6 +681,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m.def("terminate_df_parallelization", &terminateDataflowParallelization);

m.def("init_df_parallelization", &initDataflowParallelization);
m.def("check_gpu_runtime_enabled", &checkGPURuntimeEnabled);
m.def("check_cuda_device_available", &checkCudaDeviceAvailable);

pybind11::enum_<mlir::concretelang::Backend>(m, "Backend")
.value("CPU", mlir::concretelang::Backend::CPU)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from mlir._mlir_libs._concretelang._compiler import (
terminate_df_parallelization as _terminate_df_parallelization,
init_df_parallelization as _init_df_parallelization,
check_gpu_runtime_enabled as _check_gpu_runtime_enabled,
check_cuda_device_available as _check_cuda_device_available,
)
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
from mlir._mlir_libs._concretelang._compiler import (
Expand Down Expand Up @@ -48,6 +50,16 @@ def init_dfr():
and the runtime is needed"""
_init_df_parallelization()

def check_gpu_enabled() -> bool:
"""Check whether the compiler and runtime support GPU offloading.
GPU offloading is not always available, in particular in non-GPU wheels."""
return _check_gpu_runtime_enabled()

def check_gpu_available() -> bool:
"""Check whether a CUDA device is available and online."""
return _check_cuda_device_available()


# Cleanly terminate the dataflow runtime if it has been initialized
# (does nothing otherwise)
Expand Down

0 comments on commit 5b2dd07

Please sign in to comment.