Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,22 @@ ExecuteNode::ExecuteNode(
graph.update_descriptor_counts(shader, /*execute = */ true);
}

ExecuteNode::ExecuteNode(
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
: shader_(),
global_workgroup_size_({0u, 0u, 0u}),
local_workgroup_size_({0u, 0u, 0u}),
args_(),
params_(),
spec_vars_(),
resize_fn_(resize_fn),
resize_args_(resize_args) {}

void ExecuteNode::encode(ComputeGraph* graph) {
if (!shader_) {
return;
}
api::Context* const context = graph->context();
vkapi::PipelineBarrier pipeline_barrier{};

Expand Down
16 changes: 15 additions & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ExecuteNode final {
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;

ExecuteNode(
explicit ExecuteNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
Expand All @@ -59,6 +59,15 @@ class ExecuteNode final {
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});

/*
* This overload of the ExecuteNode constructor is used to register ops which
* update a tensor view. No shader is dispatched, but the node still needs to
* update the view's sizes and strides after a resize.
*/
explicit ExecuteNode(
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});

~ExecuteNode() = default;

void encode(ComputeGraph* graph);
Expand All @@ -83,6 +92,11 @@ class ExecuteNode final {
const vkapi::SpecVarList spec_vars_;
const ResizeFunction resize_fn_;
const std::vector<ValueRef> resize_args_;

public:
operator bool() const {
return shader_;
}
};

} // namespace vkcompute
8 changes: 6 additions & 2 deletions backends/vulkan/runtime/vk_api/Shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class ShaderLayout final {

struct ShaderInfo final {
struct {
const uint32_t* bin;
uint32_t size;
const uint32_t* bin = nullptr;
uint32_t size = 0u;
} src_code;

std::string kernel_name{""};
Expand All @@ -71,6 +71,10 @@ struct ShaderInfo final {
const uint32_t,
std::vector<VkDescriptorType>,
const utils::uvec3 tile_size);

operator bool() const {
return src_code.bin != nullptr;
};
};

bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/test/utils/test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,9 @@ void execute_graph_and_check_output(
}
}
}

bool check_close(float a, float b, float atol, float rtol) {
float max = std::max(std::abs(a), std::abs(b));
float diff = std::abs(a - b);
return diff <= (atol + rtol * max);
}
6 changes: 6 additions & 0 deletions backends/vulkan/test/utils/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,9 @@ void print_vector(
}
std::cout << std::endl;
}

//
// Misc. Utilities
//

bool check_close(float a, float b, float atol = 1e-4, float rtol = 1e-5);
22 changes: 21 additions & 1 deletion backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ std::vector<int64_t> get_reference_strides(
return {};
}

TEST_F(VulkanComputeAPITest, empty_init_shader_info_test) {
vkapi::ShaderInfo empty_shader_info;
EXPECT_FALSE(empty_shader_info);
EXPECT_TRUE(empty_shader_info.src_code.bin == nullptr);
EXPECT_TRUE(empty_shader_info.src_code.size == 0u);
}

TEST_F(VulkanComputeAPITest, calculate_tensor_strides_test) {
for (const auto& sizes : standard_sizes_to_test) {
if (sizes.size() < 3) {
Expand Down Expand Up @@ -601,7 +608,7 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) {
EXPECT_TRUE(data_out.size() == ref_out.size());

for (size_t i = 0; i < data_out.size(); ++i) {
EXPECT_TRUE(data_out[i] == ref_out[i]);
EXPECT_TRUE(check_close(data_out[i], ref_out[i]));
}
}

Expand Down Expand Up @@ -975,6 +982,19 @@ TEST(VulkanComputeGraphTest, test_values_string) {
EXPECT_TRUE(stored == "hello, world");
}

TEST(VulkanComputeGraphTest, empty_init_executenode_test) {
ExecuteNode node(nullptr, {});
EXPECT_FALSE(node);

GraphConfig config;
ComputeGraph graph(config);

// Encode an empty ExecuteNode and check that command buffer encoding does not
// crash.
graph.execute_nodes().emplace_back(new ExecuteNode(nullptr, {}));
EXPECT_NO_FATAL_FAILURE(graph.encode_execute());
}

TEST(VulkanComputeGraphTest, test_zero_dim_tensor) {
GraphConfig config;
ComputeGraph graph(config);
Expand Down