Skip to content

Commit

Permalink
[build] Improve TI_WITH_CUDA guards for CUDA related test cases (taic…
Browse files Browse the repository at this point in the history
…hi-dev#6698)

Issue: #

### Brief Summary
  • Loading branch information
jim19930609 authored and quadpixels committed May 13, 2023
1 parent 534bfe0 commit 19e200c
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 11 deletions.
4 changes: 0 additions & 4 deletions c_api/include/taichi/taichi_cuda.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
#pragma once

#ifndef TI_WITH_CUDA
#define TI_WITH_CUDA 1
#endif // TI_WITH_CUDA

#include <taichi/taichi.h>

#ifdef __cplusplus
Expand Down
3 changes: 3 additions & 0 deletions c_api/src/c_api_test_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include "c_api_test_utils.h"
#include "taichi_llvm_impl.h"
#include "taichi/platform/cuda/detect_cuda.h"

#ifdef TI_WITH_CUDA
#include "taichi/rhi/cuda/cuda_driver.h"
#endif

#ifdef TI_WITH_VULKAN
#include "taichi/rhi/vulkan/vulkan_loader.h"
Expand Down
9 changes: 8 additions & 1 deletion tests/cpp/aot/llvm/bitmasked_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/llvm/llvm_aot_module_loader.h"

#ifdef TI_WITH_CUDA

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"

#endif

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
Expand Down Expand Up @@ -86,6 +91,7 @@ TEST(LlvmAotTest, CpuBitmasked) {
}

TEST(LlvmAotTest, CudaBitmasked) {
#ifdef TI_WITH_CUDA
if (is_cuda_api_available()) {
CompileConfig cfg;
cfg.arch = Arch::cuda;
Expand All @@ -108,6 +114,7 @@ TEST(LlvmAotTest, CudaBitmasked) {

run_bitmasked_tests(mod.get(), &exec, result_buffer);
}
#endif
}

} // namespace taichi::lang
9 changes: 8 additions & 1 deletion tests/cpp/aot/llvm/dynamic_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/llvm/llvm_aot_module_loader.h"

#ifdef TI_WITH_CUDA

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"

#endif

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
Expand Down Expand Up @@ -86,6 +91,7 @@ TEST(LlvmAotTest, CpuDynamic) {
}

TEST(LlvmAotTest, CudaDynamic) {
#ifdef TI_WITH_CUDA
if (is_cuda_api_available()) {
CompileConfig cfg;
cfg.arch = Arch::cuda;
Expand All @@ -108,6 +114,7 @@ TEST(LlvmAotTest, CudaDynamic) {

run_dynamic_tests(mod.get(), &exec, result_buffer);
}
#endif
}

} // namespace taichi::lang
9 changes: 8 additions & 1 deletion tests/cpp/aot/llvm/field_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/llvm/llvm_aot_module_loader.h"

#ifdef TI_WITH_CUDA

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"

#endif

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
Expand Down Expand Up @@ -118,6 +123,7 @@ TEST(LlvmAotTest, CpuField) {
}

TEST(LlvmAotTest, CudaField) {
#ifdef TI_WITH_CUDA
if (is_cuda_api_available()) {
CompileConfig cfg;
cfg.arch = Arch::cuda;
Expand All @@ -140,6 +146,7 @@ TEST(LlvmAotTest, CudaField) {

run_field_tests(mod.get(), &exec, result_buffer);
}
#endif
}

} // namespace taichi::lang
9 changes: 8 additions & 1 deletion tests/cpp/aot/llvm/field_cgraph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/llvm/llvm_aot_module_loader.h"

#ifdef TI_WITH_CUDA

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"

#endif

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
Expand Down Expand Up @@ -64,6 +69,7 @@ TEST(LlvmCGraph, CpuField) {
}

TEST(LlvmCGraph, CudaField) {
#ifdef TI_WITH_CUDA
if (is_cuda_api_available()) {
CompileConfig cfg;
cfg.arch = Arch::cuda;
Expand All @@ -86,6 +92,7 @@ TEST(LlvmCGraph, CudaField) {

run_graph_tests(mod.get(), &exec, result_buffer);
}
#endif
}

} // namespace taichi::lang
9 changes: 8 additions & 1 deletion tests/cpp/aot/llvm/graph_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/llvm/llvm_aot_module_loader.h"

#ifdef TI_WITH_CUDA

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"

#endif

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
Expand Down Expand Up @@ -70,6 +75,7 @@ TEST(LlvmCGraph, RunGraphCpu) {
}

TEST(LlvmCGraph, RunGraphCuda) {
#ifdef TI_WITH_CUDA
if (is_cuda_api_available()) {
CompileConfig cfg;
cfg.arch = Arch::cuda;
Expand Down Expand Up @@ -126,4 +132,5 @@ TEST(LlvmCGraph, RunGraphCuda) {
EXPECT_EQ(cpu_data[i], 3 * i + base0 + base1 + base2);
}
}
#endif
}
9 changes: 8 additions & 1 deletion tests/cpp/aot/llvm/kernel_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/dx12/aot_module_loader_impl.h"

#ifdef TI_WITH_CUDA

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"

#endif

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
Expand Down Expand Up @@ -61,6 +66,7 @@ TEST(LlvmAotTest, CpuKernel) {
}

TEST(LlvmAotTest, CudaKernel) {
#ifdef TI_WITH_CUDA
if (is_cuda_api_available()) {
CompileConfig cfg;
cfg.arch = Arch::cuda;
Expand Down Expand Up @@ -108,6 +114,7 @@ TEST(LlvmAotTest, CudaKernel) {
EXPECT_EQ(cpu_data[i], i + vec[0]);
}
}
#endif
}

#ifdef TI_WITH_DX12
Expand Down
9 changes: 8 additions & 1 deletion tests/cpp/aot/llvm/mpm88_graph_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/llvm/llvm_aot_module_loader.h"

#ifdef TI_WITH_CUDA

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"

#endif

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
Expand Down Expand Up @@ -104,6 +109,7 @@ TEST(LlvmCGraph, Mpm88Cpu) {
}

TEST(LlvmCGraph, Mpm88Cuda) {
#ifdef TI_WITH_CUDA
if (is_cuda_api_available()) {
CompileConfig cfg;
cfg.arch = Arch::cuda;
Expand Down Expand Up @@ -185,4 +191,5 @@ TEST(LlvmCGraph, Mpm88Cuda) {
g_update->run(args);
exec.synchronize();
}
#endif
}

0 comments on commit 19e200c

Please sign in to comment.