[rhi] Update CommandList dispatch API (taichi-dev#7052)
Issue: taichi-dev#6832

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and quadpixels committed May 13, 2023
1 parent af8e17f commit 6020a5a
Showing 12 changed files with 91 additions and 25 deletions.
5 changes: 3 additions & 2 deletions taichi/rhi/amdgpu/amdgpu_device.h
@@ -32,8 +32,9 @@ class AmdgpuCommandList : public CommandList {
TI_NOT_IMPLEMENTED};
void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final{
TI_NOT_IMPLEMENTED};
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override{
-      TI_NOT_IMPLEMENTED};
+  RhiResult dispatch(uint32_t x,
+                     uint32_t y = 1,
+                     uint32_t z = 1) noexcept override{TI_NOT_IMPLEMENTED};
};

class AmdgpuStream : public Stream {
5 changes: 3 additions & 2 deletions taichi/rhi/cpu/cpu_device.h
@@ -37,8 +37,9 @@ class CpuCommandList : public CommandList {
TI_NOT_IMPLEMENTED};
void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept override{
TI_NOT_IMPLEMENTED};
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override{
-      TI_NOT_IMPLEMENTED};
+  RhiResult dispatch(uint32_t x,
+                     uint32_t y = 1,
+                     uint32_t z = 1) noexcept override{TI_NOT_IMPLEMENTED};
};

class CpuStream : public Stream {
5 changes: 3 additions & 2 deletions taichi/rhi/cuda/cuda_device.h
@@ -37,8 +37,9 @@ class CudaCommandList : public CommandList {
TI_NOT_IMPLEMENTED};
void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept override{
TI_NOT_IMPLEMENTED};
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override{
-      TI_NOT_IMPLEMENTED};
+  RhiResult dispatch(uint32_t x,
+                     uint32_t y = 1,
+                     uint32_t z = 1) noexcept override{TI_NOT_IMPLEMENTED};
};

class CudaStream : public Stream {
49 changes: 42 additions & 7 deletions taichi/rhi/device.h
@@ -360,25 +360,60 @@ class TI_DLL_EXPORT CommandList {
* - (Encouraged behavior) If the `size` is -1 (max of size_t) the underlying
* API might provide a faster code path.
* @params[in] ptr The start of the memory region.
-   * ptr.offset will be aligned down to a multiple of 4 bytes.
+   * - ptr.offset will be aligned down to a multiple of 4 bytes.
* @params[in] size The size of the region.
-   * The size will be clamped to the underlying buffer's size.
+   * - The size will be clamped to the underlying buffer's size.
*/
virtual void buffer_fill(DevicePtr ptr,
size_t size,
uint32_t data) noexcept = 0;

-  virtual void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) = 0;
+  /**
+   * Enqueues a compute operation with {X, Y, Z} amount of workgroups.
+   * The block size / workgroup size is predetermined within the pipeline.
+   * - This is only valid if the pipeline has a predetermined block size
+   * - This API has device-dependent maximum values for X, Y, and Z
+   * - The currently bound pipeline will be dispatched
+   * - The enqueued operation starts in CommandList API ordering.
+   * - The enqueued operation may end out-of-order, but it respects barriers
+   * @params[in] x The number of workgroups in the X dimension
+   * @params[in] y The number of workgroups in the Y dimension
+   * @params[in] z The number of workgroups in the Z dimension
+   * @return The status of this operation
+   * - `success` if the operation is successful
+   * - `invalid_operation` if the current pipeline has a variable block size
+   * - `not_supported` if the requested X, Y, or Z is not supported
+   */
+  virtual RhiResult dispatch(uint32_t x,
+                             uint32_t y = 1,
+                             uint32_t z = 1) noexcept = 0;

struct ComputeSize {
uint32_t x{0};
uint32_t y{0};
uint32_t z{0};
};
-  // Some GPU APIs can set the block (workgroup, threadsgroup) size at
-  // dispatch time.
-  virtual void dispatch(ComputeSize grid_size, ComputeSize block_size) {
-    dispatch(grid_size.x, grid_size.y, grid_size.z);
+
+  /**
+   * Enqueues a compute operation with `grid_size` amount of threads.
+   * The workgroup size is dynamic and specified through `block_size`.
+   * - This is only valid if the pipeline has a variable block size
+   * - This API has device-dependent maximum values for `grid_size`
+   * - This API has device-dependent supported values for `block_size`
+   * - The currently bound pipeline will be dispatched
+   * - The enqueued operation starts in CommandList API ordering.
+   * - The enqueued operation may end out-of-order, but it respects barriers
+   * @params[in] grid_size The number of threads to dispatch
+   * @params[in] block_size The shape of each block / workgroup / threadgroup
+   * @return The status of this operation
+   * - `success` if the operation is successful
+   * - `invalid_operation` if the current pipeline has a predetermined block size
+   * - `not_supported` if the requested sizes are not supported
+   * - `error` if the operation failed due to other reasons
+   */
+  virtual RhiResult dispatch(ComputeSize grid_size,
+                             ComputeSize block_size) noexcept {
+    return RhiResult::not_supported;
}

// These are not implemented in compute only device
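The new signature moves dispatch failures from asserts into return codes. A minimal caller-side sketch of the contract documented above (hedged: `cmd_list` and the branch handling are illustrative, not part of this commit):

```cpp
// Assumes `cmd_list` is a CommandList* with a compute pipeline bound.
RhiResult res = cmd_list->dispatch(/*x=*/256, /*y=*/1, /*z=*/1);
if (res == RhiResult::invalid_operation) {
  // The bound pipeline uses a variable block size; the
  // dispatch(grid_size, block_size) overload must be used instead.
} else if (res == RhiResult::not_supported) {
  // One of x, y, z exceeds the device's workgroup-count limits.
}
```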
6 changes: 5 additions & 1 deletion taichi/rhi/dx/dx_device.cpp
@@ -213,7 +213,9 @@ void Dx11CommandList::buffer_fill(DevicePtr ptr,
// FIXME: what if the default is not a raw buffer?
}

-void Dx11CommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
+RhiResult Dx11CommandList::dispatch(uint32_t x,
+                                    uint32_t y,
+                                    uint32_t z) noexcept {
// Set SPIRV_Cross_NumWorkgroups's CB slot based on the watermark
auto cb_slot = cb_slot_watermark_ + 1;
auto spirv_cross_numworkgroups_cb =
@@ -226,6 +228,8 @@ void Dx11CommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
cb_slot_watermark_ = -1;

d3d11_deferred_context_->Dispatch(x, y, z);
+
+  return RhiResult::success;
}

void Dx11CommandList::begin_renderpass(int x0,
2 changes: 1 addition & 1 deletion taichi/rhi/dx/dx_device.h
@@ -121,7 +121,7 @@ class Dx11CommandList : public CommandList {
void memory_barrier() noexcept final;
void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) noexcept final;
void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final;
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override;
+  RhiResult dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) noexcept final;

// These are not implemented in compute only device
void begin_renderpass(int x0,
15 changes: 10 additions & 5 deletions taichi/rhi/metal/device.cpp
@@ -167,14 +167,16 @@ class CommandListImpl : public CommandList {
finish_encoder(encoder.get());
}

-  void dispatch(uint32_t x, uint32_t y, uint32_t z) override {
+  RhiResult dispatch(uint32_t x, uint32_t y, uint32_t z) noexcept final {
TI_ERROR("Please call dispatch(grid_size, block_size) instead");
}

-  void dispatch(CommandList::ComputeSize grid_size,
-                CommandList::ComputeSize block_size) override {
+  RhiResult dispatch(CommandList::ComputeSize grid_size,
+                     CommandList::ComputeSize block_size) noexcept override {
auto encoder = new_compute_command_encoder(command_buffer_.get());
-    TI_ASSERT(encoder != nullptr);
+    if (encoder == nullptr) {
+      return RhiResult::error;
+    }
metal::set_label(encoder.get(), inflight_label_);
const auto &builder = inflight_compute_builder_.value();
set_compute_pipeline_state(encoder.get(), builder.pipeline);
@@ -183,7 +185,9 @@
};
for (const auto &[idx, b] : builder.binding_map) {
auto *buf = alloc_buf_mapper_->find(b.alloc_id).buffer;
-    TI_ASSERT(buf != nullptr);
+    if (buf == nullptr) {
+      return RhiResult::error;
+    }
set_mtl_buffer(encoder.get(), buf, b.offset, idx);
}
const auto num_blocks_x = ceil_div(grid_size.x, block_size.x);
Expand All @@ -193,6 +197,7 @@ class CommandListImpl : public CommandList {
num_blocks_z, block_size.x, block_size.y,
block_size.z);
finish_encoder(encoder.get());
+    return RhiResult::success;
}

// Graphics commands are not implemented on Metal
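Metal sets the threadgroup size at encode time, so it implements the dynamic overload and derives the threadgroup count with `ceil_div`, as the hunk above shows. A hedged usage sketch (`cmd_list` and the chosen sizes are illustrative):

```cpp
// Dispatch 1024 threads in blocks of 64. Backends that do not override the
// dynamic overload inherit the default, which returns not_supported.
CommandList::ComputeSize grid{1024, 1, 1};   // total threads to cover
CommandList::ComputeSize block{64, 1, 1};    // workgroup / threadgroup shape
if (cmd_list->dispatch(grid, block) == RhiResult::not_supported) {
  // This backend only supports predetermined block sizes; fall back to
  // the fixed overload with ceil_div(1024, 64) = 16 workgroups.
}
```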
3 changes: 2 additions & 1 deletion taichi/rhi/opengl/opengl_device.cpp
@@ -382,12 +382,13 @@ void GLCommandList::buffer_fill(DevicePtr ptr,
recorded_commands_.push_back(std::move(cmd));
}

-void GLCommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
+RhiResult GLCommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) noexcept {
auto cmd = std::make_unique<CmdDispatch>();
cmd->x = x;
cmd->y = y;
cmd->z = z;
recorded_commands_.push_back(std::move(cmd));
+  return RhiResult::success;
}

void GLCommandList::begin_renderpass(int x0,
2 changes: 1 addition & 1 deletion taichi/rhi/opengl/opengl_device.h
@@ -85,7 +85,7 @@ class GLCommandList : public CommandList {
void memory_barrier() noexcept final;
void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) noexcept final;
void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final;
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override;
+  RhiResult dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) noexcept final;

// These are not implemented in compute only device
void begin_renderpass(int x0,
13 changes: 12 additions & 1 deletion taichi/rhi/vulkan/vulkan_device.cpp
@@ -1080,8 +1080,17 @@ void VulkanCommandList::buffer_fill(DevicePtr ptr,
buffer_->refs.push_back(buffer);
}

-void VulkanCommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
+RhiResult VulkanCommandList::dispatch(uint32_t x,
+                                      uint32_t y,
+                                      uint32_t z) noexcept {
+  auto &dev_props = ti_device_->get_vk_physical_device_props();
+  if (x > dev_props.limits.maxComputeWorkGroupCount[0] ||
+      y > dev_props.limits.maxComputeWorkGroupCount[1] ||
+      z > dev_props.limits.maxComputeWorkGroupCount[2]) {
+    return RhiResult::not_supported;
+  }
vkCmdDispatch(buffer_->buffer, x, y, z);
+  return RhiResult::success;
}

vkapi::IVkCommandBuffer VulkanCommandList::vk_command_buffer() {
@@ -1553,6 +1562,8 @@ void VulkanDevice::init_vulkan_structs(Params &params) {
create_vma_allocator();
RHI_ASSERT(new_descriptor_pool() == RhiResult::success &&
"Failed to allocate initial descriptor pool");

vkGetPhysicalDeviceProperties(physical_device_, &vk_device_properties_);
}

VulkanDevice::~VulkanDevice() {
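The Vulkan backend caches `VkPhysicalDeviceProperties` once in `init_vulkan_structs` so that each dispatch only does a plain comparison against the device limits. A standalone sketch of that validation (the helper name `fits_compute_limits` is illustrative):

```cpp
#include <vulkan/vulkan.h>

// Mirrors the check added to VulkanCommandList::dispatch above: a workgroup
// count is valid only if every dimension fits maxComputeWorkGroupCount.
static bool fits_compute_limits(const VkPhysicalDeviceLimits &limits,
                                uint32_t x, uint32_t y, uint32_t z) {
  return x <= limits.maxComputeWorkGroupCount[0] &&
         y <= limits.maxComputeWorkGroupCount[1] &&
         z <= limits.maxComputeWorkGroupCount[2];
}
```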
7 changes: 6 additions & 1 deletion taichi/rhi/vulkan/vulkan_device.h
@@ -399,7 +399,7 @@ class VulkanCommandList : public CommandList {
void memory_barrier() noexcept final;
void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) noexcept final;
void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final;
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override;
+  RhiResult dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) noexcept final;
void begin_renderpass(int x0,
int y0,
int x1,
@@ -715,13 +715,18 @@ class TI_DLL_EXPORT VulkanDevice : public GraphicsDevice {
return vk_caps_;
}
+
+  const VkPhysicalDeviceProperties &get_vk_physical_device_props() const {
+    return vk_device_properties_;
+  }

private:
friend VulkanSurface;

void create_vma_allocator();
[[nodiscard]] RhiResult new_descriptor_pool();

VulkanCapabilities vk_caps_;
+  VkPhysicalDeviceProperties vk_device_properties_;

VkInstance instance_{VK_NULL_HANDLE};
VkDevice device_{VK_NULL_HANDLE};
4 changes: 3 additions & 1 deletion taichi/runtime/gfx/runtime.cpp
@@ -518,7 +518,9 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) {
RhiResult status = current_cmdlist_->bind_shader_resources(bindings.get());
TI_ERROR_IF(status != RhiResult::success,
"Resource binding error : RhiResult({})", status);
-    current_cmdlist_->dispatch(group_x);
+    status = current_cmdlist_->dispatch(group_x);
+    TI_ERROR_IF(status != RhiResult::success, "Dispatch error : RhiResult({})",
+                status);
current_cmdlist_->memory_barrier();
}

