Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
compute_graph->encode_prepack();
compute_graph->prepack();

// TODO(ssjia): remove this once we can batch compile compute pipelines
// during prepare().
compute_graph->encode_execute();

return Error::Ok;
Expand Down Expand Up @@ -567,9 +569,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
}
}

// propagate_resize() will re-encode the command buffer so that push
// constants are updated and DynamicDispatchNode can update the compute
// shader, global workgroup size, and local workgroup size to perform the
// model inference.
if (should_propagate_resize) {
compute_graph->propagate_resize();
}

compute_graph->execute();

for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,11 +678,12 @@ void ComputeGraph::encode_execute() {
}
}

void ComputeGraph::execute() const {
void ComputeGraph::execute() {
vkapi::VulkanFence fence = context_->fences().get_fence();
context_->submit_cmd_to_gpu(fence.get_submit_handle());
fence.wait();
context_->fences().return_fence(fence);
execute_count_++;
}

void ComputeGraph::resize_input(
Expand All @@ -696,6 +697,7 @@ void ComputeGraph::propagate_resize() {
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->trigger_resize(this);
}
encode_execute();
}

} // namespace vkcompute
7 changes: 6 additions & 1 deletion backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class ComputeGraph final {

protected:
size_t values_in_use_ = 0;
size_t execute_count_ = 0;

public:
//
Expand Down Expand Up @@ -745,7 +746,7 @@ class ComputeGraph final {
//

void encode_execute();
void execute() const;
void execute();

//
// Dynamic Shape support
Expand All @@ -762,6 +763,10 @@ class ComputeGraph final {
return context_->adapter_ptr()->supports_int16_shader_types();
}

inline size_t execute_count() const {
return execute_count_;
}

/*
* Check whether the GPU supports 8 bit buffers.
*/
Expand Down
26 changes: 14 additions & 12 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) {

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
uint32_t push_constants_offset = 0;

for (const auto& push_constant : push_constants_) {
push_constants_offset += push_constant.write(
push_constants_data.data(),
push_constants_offset,
kMaxPushConstantSize);
}
write_push_constant_data();

context->report_shader_dispatch_start(
shader_.kernel_name,
Expand All @@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
node_id_);

vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
shader_, local_workgroup_size_, spec_vars_, push_constants_offset_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
Expand All @@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) {
pipeline_barrier,
shader_,
global_workgroup_size_,
push_constants_data.data(),
push_constants_offset);
push_constants_data_.data(),
push_constants_offset_);

context->report_shader_dispatch_end();
}

void DispatchNode::write_push_constant_data() {
push_constants_offset_ = 0;
for (const auto& push_constant : push_constants_) {
push_constants_offset_ += push_constant.write(
push_constants_data_.data(),
push_constants_offset_,
kMaxPushConstantSize);
}
}

} // namespace vkcompute
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode {
const vkapi::SpecVarList spec_vars_;
const std::vector<PushConstantDataInfo> push_constants_;

// For push constants
std::array<uint8_t, kMaxPushConstantSize> push_constants_data_{};
uint32_t push_constants_offset_ = 0;

void write_push_constant_data();

public:
operator bool() const {
return shader_;
Expand Down
58 changes: 51 additions & 7 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode(
const ResizeFunction& resize_fn)
: DispatchNode(
graph,
pick_shader_fn(&graph, args, resize_args),
pick_global_wg_fn(&graph, args, resize_args),
pick_local_wg_fn(&graph, args, resize_args),
vkapi::ShaderInfo(),
{1u, 1u, 1u},
{1u, 1u, 1u},
args,
params,
push_constants,
Expand All @@ -36,13 +36,57 @@ DynamicDispatchNode::DynamicDispatchNode(
resize_fn),
pick_shader_fn_(pick_shader_fn),
pick_global_wg_fn_(pick_global_wg_fn),
pick_local_wg_fn_(pick_local_wg_fn) {
shader_ = pick_shader_fn(&graph, args, resize_args);
global_workgroup_size_ =
pick_global_wg_fn(&graph, shader_, args, resize_args);
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
&graph, shader_, global_workgroup_size_, args, resize_args));
}

DynamicDispatchNode::DynamicDispatchNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const PickGlobalFn& pick_global_wg_fn,
const PickLocalFn& pick_local_wg_fn,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const std::vector<PushConstantDataInfo>& push_constants,
const vkapi::SpecVarList& spec_vars,
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn)
: DispatchNode(
graph,
shader,
pick_global_wg_fn(&graph, shader, args, resize_args),
pick_local_wg_fn(
&graph,
shader,
pick_global_wg_fn(&graph, shader, args, resize_args),
args,
resize_args),
args,
params,
push_constants,
spec_vars,
resize_args,
resize_fn),
pick_shader_fn_{nullptr},
pick_global_wg_fn_(pick_global_wg_fn),
pick_local_wg_fn_(pick_local_wg_fn) {}

void DynamicDispatchNode::encode(ComputeGraph* graph) {
shader_ = pick_shader_fn_(graph, args_, resize_args_);
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
local_workgroup_size_ =
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
if (pick_shader_fn_) {
shader_ = pick_shader_fn_(graph, args_, resize_args_);
}
if (pick_global_wg_fn_) {
global_workgroup_size_ =
pick_global_wg_fn_(graph, shader_, args_, resize_args_);
}
if (pick_local_wg_fn_) {
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
graph, shader_, global_workgroup_size_, args_, resize_args_));
}
DispatchNode::encode(graph);
}

Expand Down
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class DynamicDispatchNode final : public DispatchNode {
const std::vector<ValueRef>&)>;
using PickGlobalFn = const std::function<utils::uvec3(
ComputeGraph*,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;
using PickLocalFn = const std::function<utils::uvec3(
ComputeGraph*,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;

Expand All @@ -51,6 +54,18 @@ class DynamicDispatchNode final : public DispatchNode {
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn = nullptr);

explicit DynamicDispatchNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const PickGlobalFn& pick_global_wg_fn,
const PickLocalFn& pick_local_wg_fn,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const std::vector<PushConstantDataInfo>& push_constants,
const vkapi::SpecVarList& spec_vars,
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn = nullptr);

~DynamicDispatchNode() override = default;

void encode(ComputeGraph* graph) override;
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ExecuteNode {
(void)graph;
}

inline void trigger_resize(ComputeGraph* graph) {
virtual inline void trigger_resize(ComputeGraph* graph) {
if (resize_fn_ != nullptr) {
resize_fn_(graph, args_, resize_args_);
}
Expand Down
16 changes: 11 additions & 5 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <gtest/gtest.h>

#include <bitset>
#include <iomanip>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -1660,9 +1661,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
for (auto& new_sizes : new_sizes_list) {
graph.get_tensor(a.value)->virtual_resize(new_sizes);
graph.get_tensor(b.value)->virtual_resize(new_sizes);
graph.get_tensor(c)->virtual_resize(new_sizes);
graph.get_tensor(d.value)->virtual_resize(new_sizes);
graph.get_tensor(e)->virtual_resize(new_sizes);
graph.propagate_resize();

float val_a = new_sizes[1] + 4.0f;
float val_b = new_sizes[2] + 1.5f;
Expand Down Expand Up @@ -3315,17 +3315,23 @@ vkapi::ShaderInfo pick_dynamic_dispatch_shader(

utils::uvec3 pick_dynamic_dispatch_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
const std::vector<ValueRef>& resize_args) {
(void)shader;
const ValueRef out = args[0].refs[0];

return graph->logical_limits_of(out);
}

utils::uvec3 pick_dynamic_dispatch_local_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
const std::vector<ValueRef>& resize_args) {
(void)graph;
(void)shader;
(void)global_workgroup_size;
return {64, 1, 1};
}

Expand Down
Loading