Skip to content

Commit

Permalink
feat(frontend): add python bindings for functions to check whether th…
Browse files Browse the repository at this point in the history
…e compiler is GPU enabled and whether a GPU is available on the system.
  • Loading branch information
antoniupop committed Jun 24, 2024
1 parent 3736785 commit 711758c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "concretelang/Common/Keysets.h"
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/GPUDFG.hpp"
#include "concretelang/ServerLib/ServerLib.h"
#include "concretelang/Support/logging.h"
#include <llvm/Support/Debug.h>
Expand Down Expand Up @@ -462,6 +463,14 @@ void initDataflowParallelization() {
mlir::concretelang::dfr::_dfr_set_required(true);
}

bool checkGPURuntimeEnabled() {
return mlir::concretelang::gpu_dfg::check_cuda_runtime_enabled();
}

bool checkCudaDeviceAvailable() {
return mlir::concretelang::gpu_dfg::check_cuda_device_available();
}

std::string roundTrip(const char *module) {
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
Expand Down Expand Up @@ -673,6 +682,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m.def("terminate_df_parallelization", &terminateDataflowParallelization);

m.def("init_df_parallelization", &initDataflowParallelization);
m.def("check_gpu_runtime_enabled", &checkGPURuntimeEnabled);
m.def("check_cuda_device_available", &checkCudaDeviceAvailable);

pybind11::enum_<mlir::concretelang::Backend>(m, "Backend")
.value("CPU", mlir::concretelang::Backend::CPU)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from mlir._mlir_libs._concretelang._compiler import (
terminate_df_parallelization as _terminate_df_parallelization,
init_df_parallelization as _init_df_parallelization,
check_gpu_runtime_enabled as _check_gpu_runtime_enabled,
check_cuda_device_available as _check_cuda_device_available,
)
from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
from mlir._mlir_libs._concretelang._compiler import (
Expand Down Expand Up @@ -49,6 +51,18 @@ def init_dfr():
_init_df_parallelization()


def check_gpu_enabled() -> bool:
"""Check whether the compiler and runtime support GPU offloading.
GPU offloading is not always available, in particular in non-GPU wheels."""
return _check_gpu_runtime_enabled()


def check_gpu_available() -> bool:
"""Check whether a CUDA device is available and online."""
return _check_cuda_device_available()


# Cleanly terminate the dataflow runtime if it has been initialized
# (does nothing otherwise)
atexit.register(_terminate_df_parallelization)
Expand Down

0 comments on commit 711758c

Please sign in to comment.