diff --git a/tensorpipe/channel/cuda_ipc/context.cc b/tensorpipe/channel/cuda_ipc/context.cc index e2f7e0714..34c4cd4e6 100644 --- a/tensorpipe/channel/cuda_ipc/context.cc +++ b/tensorpipe/channel/cuda_ipc/context.cc @@ -45,8 +45,10 @@ std::shared_ptr makeCudaIpcChannel() { return std::make_shared(); } -// TODO: Make separate CUDA channel registry. -// TP_REGISTER_CREATOR(TensorpipeChannelRegistry, cuda_ipc, makeCudaIpcChannel); +TP_REGISTER_CREATOR( + TensorpipeCudaChannelRegistry, + cuda_ipc, + makeCudaIpcChannel); } // namespace diff --git a/tensorpipe/channel/registry.cc b/tensorpipe/channel/registry.cc index 429232dc7..a4bf9c2de 100644 --- a/tensorpipe/channel/registry.cc +++ b/tensorpipe/channel/registry.cc @@ -11,3 +11,9 @@ TP_DEFINE_SHARED_REGISTRY( TensorpipeChannelRegistry, tensorpipe::channel::CpuContext); + +#if TENSORPIPE_HAS_CUDA +TP_DEFINE_SHARED_REGISTRY( + TensorpipeCudaChannelRegistry, + tensorpipe::channel::CudaContext); +#endif // TENSORPIPE_HAS_CUDA diff --git a/tensorpipe/channel/registry.h b/tensorpipe/channel/registry.h index 07291bd4d..1e6f4f753 100644 --- a/tensorpipe/channel/registry.h +++ b/tensorpipe/channel/registry.h @@ -8,6 +8,8 @@ #pragma once +#include + #include #include #include @@ -15,3 +17,9 @@ TP_DECLARE_SHARED_REGISTRY( TensorpipeChannelRegistry, tensorpipe::channel::CpuContext); + +#if TENSORPIPE_HAS_CUDA +TP_DECLARE_SHARED_REGISTRY( + TensorpipeCudaChannelRegistry, + tensorpipe::channel::CudaContext); +#endif // TENSORPIPE_HAS_CUDA