[rhi] Update CommandList dispatch API #7052

Merged · 10 commits · Jan 5, 2023
5 changes: 3 additions & 2 deletions taichi/rhi/amdgpu/amdgpu_device.h
@@ -32,8 +32,9 @@ class AmdgpuCommandList : public CommandList {
       TI_NOT_IMPLEMENTED};
   void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final{
       TI_NOT_IMPLEMENTED};
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override{
-      TI_NOT_IMPLEMENTED};
+  RhiResult dispatch(uint32_t x,
+                     uint32_t y = 1,
+                     uint32_t z = 1) noexcept override{TI_NOT_IMPLEMENTED};
 };

 class AmdgpuStream : public Stream {
5 changes: 3 additions & 2 deletions taichi/rhi/cpu/cpu_device.h
@@ -37,8 +37,9 @@ class CpuCommandList : public CommandList {
       TI_NOT_IMPLEMENTED};
   void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept override{
       TI_NOT_IMPLEMENTED};
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override{
-      TI_NOT_IMPLEMENTED};
+  RhiResult dispatch(uint32_t x,
+                     uint32_t y = 1,
+                     uint32_t z = 1) noexcept override{TI_NOT_IMPLEMENTED};
 };

 class CpuStream : public Stream {
5 changes: 3 additions & 2 deletions taichi/rhi/cuda/cuda_device.h
@@ -37,8 +37,9 @@ class CudaCommandList : public CommandList {
       TI_NOT_IMPLEMENTED};
   void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept override{
       TI_NOT_IMPLEMENTED};
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override{
-      TI_NOT_IMPLEMENTED};
+  RhiResult dispatch(uint32_t x,
+                     uint32_t y = 1,
+                     uint32_t z = 1) noexcept override{TI_NOT_IMPLEMENTED};
 };

 class CudaStream : public Stream {
49 changes: 42 additions & 7 deletions taichi/rhi/device.h
@@ -360,25 +360,60 @@ class TI_DLL_EXPORT CommandList {
    * - (Encouraged behavior) If the `size` is -1 (max of size_t) the underlying
    *   API might provide a faster code path.
    * @params[in] ptr The start of the memory region.
-   *             ptr.offset will be aligned down to a multiple of 4 bytes.
+   * - ptr.offset will be aligned down to a multiple of 4 bytes.
    * @params[in] size The size of the region.
-   *             The size will be clamped to the underlying buffer's size.
+   * - The size will be clamped to the underlying buffer's size.
    */
   virtual void buffer_fill(DevicePtr ptr,
                            size_t size,
                            uint32_t data) noexcept = 0;

-  virtual void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) = 0;
+  /**
+   * Enqueues a compute operation with {X, Y, Z} workgroups.
+   * The block size / workgroup size is pre-determined within the pipeline.
+   * - This is only valid if the pipeline has a predetermined block size
+   * - This API has device-dependent max values for X, Y, Z
+   * - The currently bound pipeline will be dispatched
+   * - The enqueued operation starts in CommandList API ordering
+   * - The enqueued operation may end out-of-order, but it respects barriers
+   * @params[in] x The number of workgroups in the X dimension
+   * @params[in] y The number of workgroups in the Y dimension
+   * @params[in] z The number of workgroups in the Z dimension
+   * @return The status of this operation
+   * - `success` if the operation is successful
+   * - `invalid_operation` if the current pipeline has a variable block size
+   * - `not_supported` if the requested X, Y, or Z is not supported
+   */
+  virtual RhiResult dispatch(uint32_t x,
+                             uint32_t y = 1,
+                             uint32_t z = 1) noexcept = 0;

   struct ComputeSize {
     uint32_t x{0};
     uint32_t y{0};
     uint32_t z{0};
   };
-  // Some GPU APIs can set the block (workgroup, threadsgroup) size at
-  // dispatch time.
-  virtual void dispatch(ComputeSize grid_size, ComputeSize block_size) {
-    dispatch(grid_size.x, grid_size.y, grid_size.z);
+
+  /**
+   * Enqueues a compute operation with `grid_size` threads.
+   * The workgroup size is dynamic and specified through `block_size`.
+   * - This is only valid if the pipeline has a variable block size
+   * - This API has device-dependent max values for `grid_size`
+   * - This API has device-dependent supported values for `block_size`
+   * - The currently bound pipeline will be dispatched
+   * - The enqueued operation starts in CommandList API ordering
+   * - The enqueued operation may end out-of-order, but it respects barriers
+   * @params[in] grid_size The total number of threads to dispatch
+   * @params[in] block_size The shape of each block / workgroup / threadgroup
+   * @return The status of this operation
+   * - `success` if the operation is successful
+   * - `invalid_operation` if the current pipeline has a predetermined block size
+   * - `not_supported` if the requested sizes are not supported
+   * - `error` if the operation failed due to other reasons
+   */
+  virtual RhiResult dispatch(ComputeSize grid_size,
+                             ComputeSize block_size) noexcept {
+    return RhiResult::not_supported;
   }

   // These are not implemented in compute only device
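A minimal usage sketch of the updated API (not part of the diff; `record_dispatches`, `on_error`, and the sizes are illustrative). The point of the change is that callers branch on the returned `RhiResult` instead of relying on asserts inside the backends:

// Sketch only: assumes a CommandList with a compute pipeline already bound.
void record_dispatches(CommandList *cmdlist) {
  // Fixed-block-size pipeline: dispatch by workgroup count (y and z default to 1).
  if (cmdlist->dispatch(/*x=*/256) != RhiResult::success) {
    on_error();  // e.g. not_supported when 256 exceeds a device limit
  }

  // Variable-block-size pipeline (e.g. Metal): dispatch by total thread count.
  CommandList::ComputeSize grid{65536, 1, 1};   // total threads to launch
  CommandList::ComputeSize block{128, 1, 1};    // threads per workgroup
  if (cmdlist->dispatch(grid, block) != RhiResult::success) {
    on_error();  // backends without dynamic block size return not_supported
  }
}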
6 changes: 5 additions & 1 deletion taichi/rhi/dx/dx_device.cpp
@@ -213,7 +213,9 @@ void Dx11CommandList::buffer_fill(DevicePtr ptr,
   // FIXME: what if the default is not a raw buffer?
 }

-void Dx11CommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
+RhiResult Dx11CommandList::dispatch(uint32_t x,
+                                    uint32_t y,
+                                    uint32_t z) noexcept {
   // Set SPIRV_Cross_NumWorkgroups's CB slot based on the watermark
   auto cb_slot = cb_slot_watermark_ + 1;
   auto spirv_cross_numworkgroups_cb =
@@ -226,6 +228,8 @@ void Dx11CommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
   cb_slot_watermark_ = -1;

   d3d11_deferred_context_->Dispatch(x, y, z);
+
+  return RhiResult::success;
 }

 void Dx11CommandList::begin_renderpass(int x0,
2 changes: 1 addition & 1 deletion taichi/rhi/dx/dx_device.h
@@ -121,7 +121,7 @@ class Dx11CommandList : public CommandList {
   void memory_barrier() noexcept final;
   void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) noexcept final;
   void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final;
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override;
+  RhiResult dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) noexcept final;

   // These are not implemented in compute only device
   void begin_renderpass(int x0,
15 changes: 10 additions & 5 deletions taichi/rhi/metal/device.cpp
@@ -167,14 +167,16 @@ class CommandListImpl : public CommandList {
     finish_encoder(encoder.get());
   }

-  void dispatch(uint32_t x, uint32_t y, uint32_t z) override {
+  RhiResult dispatch(uint32_t x, uint32_t y, uint32_t z) noexcept final {
     TI_ERROR("Please call dispatch(grid_size, block_size) instead");
   }

-  void dispatch(CommandList::ComputeSize grid_size,
-                CommandList::ComputeSize block_size) override {
+  RhiResult dispatch(CommandList::ComputeSize grid_size,
+                     CommandList::ComputeSize block_size) noexcept override {
     auto encoder = new_compute_command_encoder(command_buffer_.get());
-    TI_ASSERT(encoder != nullptr);
+    if (encoder == nullptr) {
+      return RhiResult::error;
+    }
     metal::set_label(encoder.get(), inflight_label_);
     const auto &builder = inflight_compute_builder_.value();
     set_compute_pipeline_state(encoder.get(), builder.pipeline);
@@ -183,7 +185,9 @@ class CommandListImpl : public CommandList {
     };
     for (const auto &[idx, b] : builder.binding_map) {
       auto *buf = alloc_buf_mapper_->find(b.alloc_id).buffer;
-      TI_ASSERT(buf != nullptr);
+      if (buf == nullptr) {
+        return RhiResult::error;
+      }
       set_mtl_buffer(encoder.get(), buf, b.offset, idx);
     }
     const auto num_blocks_x = ceil_div(grid_size.x, block_size.x);
@@ -193,6 +197,7 @@ class CommandListImpl : public CommandList {
                              num_blocks_z, block_size.x, block_size.y,
                              block_size.z);
     finish_encoder(encoder.get());
+    return RhiResult::success;
   }

   // Graphics commands are not implemented on Metal
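The Metal path converts the thread-granular `grid_size` into workgroup counts with `ceil_div`, rounding up so a partially filled trailing workgroup still covers every thread. A sketch of that helper under the obvious assumption (only the name appears in the diff):

inline uint32_t ceil_div(uint32_t total, uint32_t per_group) {
  // Round-up integer division: ceil_div(65536, 128) == 512, ceil_div(100, 32) == 4.
  return (total + per_group - 1) / per_group;
}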
3 changes: 2 additions & 1 deletion taichi/rhi/opengl/opengl_device.cpp
@@ -382,12 +382,13 @@ void GLCommandList::buffer_fill(DevicePtr ptr,
   recorded_commands_.push_back(std::move(cmd));
 }

-void GLCommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
+RhiResult GLCommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) noexcept {
   auto cmd = std::make_unique<CmdDispatch>();
   cmd->x = x;
   cmd->y = y;
   cmd->z = z;
   recorded_commands_.push_back(std::move(cmd));
+  return RhiResult::success;
 }

 void GLCommandList::begin_renderpass(int x0,
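Note that the GL backend only records the dispatch; the actual GL call runs when the command list is submitted, which is why record-time `dispatch` can return `success` unconditionally. A sketch of that record/replay pattern (the base class and `execute` hook are assumptions for illustration; only `CmdDispatch` appears in the diff):

struct CmdDispatch : public GLCommandList::Cmd {
  uint32_t x{0}, y{0}, z{0};
  void execute() override {
    glDispatchCompute(x, y, z);  // deferred until submission
  }
};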
2 changes: 1 addition & 1 deletion taichi/rhi/opengl/opengl_device.h
@@ -85,7 +85,7 @@ class GLCommandList : public CommandList {
   void memory_barrier() noexcept final;
   void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) noexcept final;
   void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final;
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override;
+  RhiResult dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) noexcept final;

   // These are not implemented in compute only device
   void begin_renderpass(int x0,
13 changes: 12 additions & 1 deletion taichi/rhi/vulkan/vulkan_device.cpp
@@ -1080,8 +1080,17 @@ void VulkanCommandList::buffer_fill(DevicePtr ptr,
   buffer_->refs.push_back(buffer);
 }

-void VulkanCommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
+RhiResult VulkanCommandList::dispatch(uint32_t x,
+                                      uint32_t y,
+                                      uint32_t z) noexcept {
+  auto &dev_props = ti_device_->get_vk_physical_device_props();
+  if (x > dev_props.limits.maxComputeWorkGroupCount[0] ||
+      y > dev_props.limits.maxComputeWorkGroupCount[1] ||
+      z > dev_props.limits.maxComputeWorkGroupCount[2]) {
+    return RhiResult::not_supported;
+  }
   vkCmdDispatch(buffer_->buffer, x, y, z);
+  return RhiResult::success;
 }

 vkapi::IVkCommandBuffer VulkanCommandList::vk_command_buffer() {
@@ -1553,6 +1562,8 @@ void VulkanDevice::init_vulkan_structs(Params &params) {
   create_vma_allocator();
   RHI_ASSERT(new_descriptor_pool() == RhiResult::success &&
              "Failed to allocate initial descriptor pool");
+
+  vkGetPhysicalDeviceProperties(physical_device_, &vk_device_properties_);
 }

 VulkanDevice::~VulkanDevice() {
7 changes: 6 additions & 1 deletion taichi/rhi/vulkan/vulkan_device.h
@@ -399,7 +399,7 @@ class VulkanCommandList : public CommandList {
   void memory_barrier() noexcept final;
   void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) noexcept final;
   void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) noexcept final;
-  void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override;
+  RhiResult dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) noexcept final;
   void begin_renderpass(int x0,
                         int y0,
                         int x1,
@@ -715,13 +715,18 @@ class TI_DLL_EXPORT VulkanDevice : public GraphicsDevice {
     return vk_caps_;
   }

+  const VkPhysicalDeviceProperties &get_vk_physical_device_props() const {
+    return vk_device_properties_;
+  }
+
  private:
  friend VulkanSurface;

  void create_vma_allocator();
  [[nodiscard]] RhiResult new_descriptor_pool();

  VulkanCapabilities vk_caps_;
+ VkPhysicalDeviceProperties vk_device_properties_;

  VkInstance instance_{VK_NULL_HANDLE};
  VkDevice device_{VK_NULL_HANDLE};
4 changes: 3 additions & 1 deletion taichi/runtime/gfx/runtime.cpp
@@ -513,7 +513,9 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) {
   RhiResult status = current_cmdlist_->bind_shader_resources(bindings.get());
   TI_ERROR_IF(status != RhiResult::success,
               "Resource binding error : RhiResult({})", status);
-  current_cmdlist_->dispatch(group_x);
+  status = current_cmdlist_->dispatch(group_x);
+  TI_ERROR_IF(status != RhiResult::success, "Dispatch error : RhiResult({})",
+              status);
   current_cmdlist_->memory_barrier();
 }