Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion backends/vulkan/runtime/api/containers/ParamsBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,29 @@ class ParamsBuffer final {
}
// Fill the uniform buffer with data in block
{
vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::MemoryAccessType::WRITE);
vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kWrite);
Block* data_ptr = mapping.template data<Block>();

*data_ptr = block;
}
}

template <typename T>
T read() const {
T val;
if (sizeof(val) != nbytes_) {
VK_THROW(
"Attempted to store value from ParamsBuffer to type of different size");
}
// Read value from uniform buffer and store in val
{
vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kRead);
T* data_ptr = mapping.template data<T>();

val = *data_ptr;
}
return val;
}
};

} // namespace api
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,10 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
get_symint(idx)->set(val);
}

int32_t ComputeGraph::read_symint(const ValueRef idx) {
return get_symint(idx)->get();
}

SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
if (idx >= shared_objects_.size()) {
shared_objects_.resize(static_cast<size_t>(idx + 1));
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ class ComputeGraph final {

void set_symint(const ValueRef idx, const int32_t val);

int32_t read_symint(const ValueRef idx);

/*
* Convenience function to add an input tensor along with its staging buffer
*/
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/containers/SymInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ void SymInt::set(const int32_t val) {
gpu_buffer.update(val);
}

int32_t SymInt::get() {
return gpu_buffer.read<int32_t>();
}

void SymInt::operator=(const int32_t val) {
gpu_buffer.update(val);
}
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/containers/SymInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct SymInt final {

void set(const int32_t val);

int32_t get();

void operator=(const int32_t val);
};

Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,9 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) {
int scalar_val = i - 3.0f;
graph.set_symint(scalar, scalar_val);

int32_t scalar_val_read = graph.read_symint(scalar);
EXPECT_TRUE(scalar_val_read == scalar_val);

float val_a = i + 2.0f;
float val_out = val_a + scalar_val;

Expand Down
Loading