Skip to content

Commit

Permalink
Move CreateTypedKernel functions to kernel.h.
Browse files Browse the repository at this point in the history
The functions don't use any of StreamExecutor's member data, so they more logically belong with TypedKernel.

PiperOrigin-RevId: 609129818
  • Loading branch information
klucke authored and tensorflower-gardener committed Feb 21, 2024
1 parent 0904e4e commit 664db4f
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 57 deletions.
9 changes: 5 additions & 4 deletions third_party/xla/xla/service/gpu/buffer_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ static absl::StatusOr<bool> DeviceCompare(se::Stream* stream,

TF_ASSIGN_OR_RETURN(
ComparisonKernelT<ElementT> comparison_kernel,
(executor->CreateTypedKernel<se::DeviceMemory<ElementT>,
se::DeviceMemory<ElementT>, float, uint64_t,
se::DeviceMemory<uint64_t>>(kernel_name,
kernel_symbol)));
(se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
float, uint64_t,
se::DeviceMemory<uint64_t>>::Create(executor,
kernel_name,
kernel_symbol)));

const se::DeviceDescription& gpu_device_info =
executor->GetDeviceDescription();
Expand Down
7 changes: 3 additions & 4 deletions third_party/xla/xla/service/gpu/kernels/topk_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ absl::Status TypedTopK(se::Stream* stream, se::DeviceMemoryBase data,
TF_ASSIGN_OR_RETURN(void* kernel_symbol, GetKernel<T>(num_elements, k));
TF_ASSIGN_OR_RETURN(
auto kernel,
(executor
->CreateTypedKernel<se::DeviceMemory<T>, size_t, se::DeviceMemory<T>,
se::DeviceMemory<uint32_t>, size_t>(
"topk", kernel_symbol)));
(se::TypedKernel<se::DeviceMemory<T>, size_t, se::DeviceMemory<T>,
se::DeviceMemory<uint32_t>,
size_t>::Create(executor, "topk", kernel_symbol)));

TF_RETURN_IF_ERROR(stream->ThenLaunch(
se::ThreadDim(num_threads, 1, 1), se::BlockDim(batch_size, 1, 1),
Expand Down
8 changes: 5 additions & 3 deletions third_party/xla/xla/service/gpu/make_batch_pointers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ absl::Status MakeBatchPointers(se::Stream* stream,
#else

TF_ASSIGN_OR_RETURN(
auto kernel, (executor->CreateTypedKernel<se::DeviceMemoryBase, size_t,
size_t, se::DeviceMemoryBase>(
"make_batch_pointers", make_batch_pointers::kernel())));
auto kernel,
(se::TypedKernel<
se::DeviceMemoryBase, size_t, size_t,
se::DeviceMemoryBase>::Create(executor, "make_batch_pointers",
make_batch_pointers::kernel())));

TF_RETURN_IF_ERROR(
stream->ThenLaunch(se::ThreadDim(kThreads, 1, 1),
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/gpu/asm_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ absl::StatusOr<TypedKernel<Args...>*> LoadKernelOrGetPtr(
if (it == kernel_ptr_cache.end()) {
TF_ASSIGN_OR_RETURN(
TypedKernel<Args...> loaded,
executor->CreateTypedKernel<Args...>(kernel_name, ptx, cubin_data));
(TypedKernel<Args...>::Create(executor, kernel_name, ptx, cubin_data)));
it =
kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first;
}
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,9 @@ absl::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones() const {
#elif TENSORFLOW_USE_ROCM
TF_ASSIGN_OR_RETURN(
ComparisonKernelT loaded_kernel,
(executor->CreateTypedKernel<DeviceMemory<uint8>, uint8, uint64_t,
DeviceMemory<uint64_t>>("redzone_checker",
kernel_symbol())));
(TypedKernel<DeviceMemory<uint8>, uint8, uint64_t,
DeviceMemory<uint64_t>>::Create(executor, "redzone_checker",
kernel_symbol())));
// CUDA side returns a pointer => hence get a pointer to the loaded kernel
auto* kernel_ptr = &loaded_kernel;
#endif // GOOGLE_CUDA
Expand Down
39 changes: 39 additions & 0 deletions third_party/xla/xla/stream_executor/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,21 @@ class TypedKernel {
return TypedKernel(std::move(kernel));
}

// Creates a kernel which can be launched with `stream.ThenLaunch(...)` from a
// PTX (and optional CUBIN), such that the types of the arguments provided for
// launch would have to match types of the arguments provided at creation
// time. The canonical storage for both ptx and cubin_data should outlive the
// lifetime of the kernel.
static absl::StatusOr<TypedKernel> Create(
StreamExecutor *executor, absl::string_view kernel_name,
absl::string_view ptx, absl::Span<const uint8_t> cubin_data);

// Creates a kernel which can be launched with `stream.ThenLaunch(...)` from
// an in-process symbol pointer.
static absl::StatusOr<TypedKernel> Create(StreamExecutor *executor,
absl::string_view kernel_name,
void *symbol);

TypedKernel() = default;

Kernel &operator*() { return *kernel_; }
Expand Down Expand Up @@ -723,6 +738,30 @@ std::unique_ptr<KernelArgsPackedArrayBase> PackKernelArgs(
return std::make_unique<PackedArgs>(std::forward<Args>(args)..., shmem_bytes);
}

// Builds a TypedKernel from PTX text plus an optional precompiled CUBIN.
// The `Args...` pack fixes the number of kernel parameters recorded in the
// loader spec; the storage behind `ptx` and `cubin_data` must outlive the
// returned kernel.
template <typename... Args>
inline absl::StatusOr<TypedKernel<Args...>> TypedKernel<Args...>::Create(
    StreamExecutor *executor, absl::string_view kernel_name,
    absl::string_view ptx, absl::Span<const uint8_t> cubin_data) {
  MultiKernelLoaderSpec spec(kNumberOfParameters);
  spec.AddCudaPtxInMemory(ptx, kernel_name);

  // Attach the CUBIN only when the caller actually supplied one.
  if (!cubin_data.empty()) {
    const char *cubin_bytes =
        reinterpret_cast<const char *>(cubin_data.data());
    spec.AddCudaCubinInMemory(cubin_bytes, kernel_name);
  }

  // Delegate to the loader-spec-based factory.
  return Create(executor, spec);
}

// Builds a TypedKernel from an in-process symbol pointer by wrapping the
// symbol in a loader spec and deferring to the spec-based Create overload.
template <typename... Args>
inline absl::StatusOr<TypedKernel<Args...>> TypedKernel<Args...>::Create(
    StreamExecutor *executor, absl::string_view kernel_name, void *symbol) {
  MultiKernelLoaderSpec spec(kNumberOfParameters);
  spec.AddInProcessSymbol(symbol, kernel_name);
  return Create(executor, spec);
}

} // namespace stream_executor

#endif // XLA_STREAM_EXECUTOR_KERNEL_H_
42 changes: 0 additions & 42 deletions third_party/xla/xla/stream_executor/stream_executor_pimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,22 +278,6 @@ class StreamExecutor {
// Returns a borrowed pointer to the underlying StreamExecutor implementation.
internal::StreamExecutorInterface* implementation();

// Creates a kernel which can be launched with `stream.ThenLaunch(...)` from a
// PTX (and optional CUBIN), such that the types of the arguments provided for
// launch would have to match types of the arguments provided at creation
// time. The canonical storage for both ptx and cubin_data should outlive the
// lifetime of the kernel.
template <typename... Args>
absl::StatusOr<TypedKernel<Args...>> CreateTypedKernel(
absl::string_view kernel_name, absl::string_view ptx,
absl::Span<const uint8_t> cubin_data);

// Creates a kernel which can be launched with `stream.ThenLaunch(...)` from
// an in-process symbol pointer.
template <typename... Args>
absl::StatusOr<TypedKernel<Args...>> CreateTypedKernel(
absl::string_view kernel_name, void* symbol);

// Warning: use Stream::ThenLaunch instead, this method is not for general
// consumption. However, this is the only way to launch a kernel for which
// the type signature is only known at runtime; say, if an application
Expand Down Expand Up @@ -372,8 +356,6 @@ class StreamExecutor {
private:
friend class Event;
friend class Stream;
template <typename... Params>
friend class TypedKernel;
friend class HostMemoryAllocation;

// Deallocates a region of host memory allocated by HostMemoryAllocate().
Expand Down Expand Up @@ -551,30 +533,6 @@ class ScopedModuleHandle {
////////////
// Inlines

// Creates a TypedKernel on this executor from PTX text and an optional
// CUBIN; the canonical storage for both must outlive the kernel.
template <typename... Args>
inline absl::StatusOr<TypedKernel<Args...>> StreamExecutor::CreateTypedKernel(
    absl::string_view kernel_name, absl::string_view ptx,
    absl::Span<const uint8_t> cubin_data) {
  MultiKernelLoaderSpec spec(TypedKernel<Args...>::kNumberOfParameters);
  spec.AddCudaPtxInMemory(ptx, kernel_name);

  // A CUBIN is optional; register it only when one was provided.
  if (!cubin_data.empty()) {
    const char* cubin_bytes = reinterpret_cast<const char*>(cubin_data.data());
    spec.AddCudaCubinInMemory(cubin_bytes, kernel_name);
  }

  return TypedKernel<Args...>::Create(this, spec);
}

// Creates a TypedKernel on this executor from an in-process symbol pointer.
template <typename... Args>
inline absl::StatusOr<TypedKernel<Args...>> StreamExecutor::CreateTypedKernel(
    absl::string_view kernel_name, void* symbol) {
  MultiKernelLoaderSpec spec(TypedKernel<Args...>::kNumberOfParameters);
  spec.AddInProcessSymbol(symbol, kernel_name);
  return TypedKernel<Args...>::Create(this, spec);
}

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64_t element_count,
int64_t memory_space) {
Expand Down

0 comments on commit 664db4f

Please sign in to comment.