Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,13 @@ void Context::report_shader_dispatch_end() {
vkapi::DescriptorSet Context::get_descriptor_set(
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& local_workgroup_size,
const vkapi::SpecVarList& additional_constants) {
const vkapi::SpecVarList& additional_constants,
const uint32_t push_constants_size) {
VkDescriptorSetLayout shader_layout =
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);

VkPipelineLayout pipeline_layout =
pipeline_layout_cache().retrieve(shader_layout);
pipeline_layout_cache().retrieve(shader_layout, push_constants_size);

vkapi::SpecVarList spec_constants = {
SV(local_workgroup_size[0u]),
Expand All @@ -105,7 +106,7 @@ vkapi::DescriptorSet Context::get_descriptor_set(
spec_constants.append(additional_constants);

VkPipeline pipeline = pipeline_cache().retrieve(
{pipeline_layout_cache().retrieve(shader_layout),
{pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
shader_cache().retrieve(shader_descriptor),
spec_constants});

Expand Down Expand Up @@ -151,7 +152,7 @@ void Context::register_shader_dispatch(
const VkDescriptorSetLayout shader_layout =
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
const VkPipelineLayout pipeline_layout =
pipeline_layout_cache().retrieve(shader_layout);
pipeline_layout_cache().retrieve(shader_layout, push_constants_size);
cmd_.set_push_constants(
pipeline_layout, push_constants_data, push_constants_size);
}
Expand Down
9 changes: 6 additions & 3 deletions backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,13 @@ class Context final {
vkapi::DescriptorSet get_descriptor_set(
const vkapi::ShaderInfo&,
const utils::uvec3&,
const vkapi::SpecVarList&);
const vkapi::SpecVarList&,
const uint32_t push_constants_size);

// Convenience overload of get_descriptor_set() that takes only the shader
// descriptor and the local work group size. It forwards to the overload that
// also accepts a vkapi::SpecVarList, passing an empty list (`{}`).
// NOTE(review): this span looks like a rendered diff in which both the removed
// and the added version of the forwarding call appear; as written, only the
// first `return` is reachable. The second call additionally passes 0u as a
// fourth argument (presumably the push constants size -- confirm against the
// declaration of the four-argument overload).
inline vkapi::DescriptorSet get_descriptor_set(
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& local_work_group_size) {
return get_descriptor_set(shader_descriptor, local_work_group_size, {});
return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u);
}

void register_shader_dispatch(
Expand Down Expand Up @@ -333,8 +334,10 @@ inline bool Context::submit_compute_job(
dispatch_id);

// Factor out template parameter independent code to minimize code bloat.
// Note that push constants are not exposed yet via this API, therefore the
// push constants size is assumed to be 0.
vkapi::DescriptorSet descriptor_set = get_descriptor_set(
shader, local_work_group_size, specialization_constants);
shader, local_work_group_size, specialization_constants, 0u);

detail::bind(
descriptor_set,
Expand Down
5 changes: 3 additions & 2 deletions backends/vulkan/runtime/api/containers/ParamsBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class ParamsBuffer final {
vulkan_buffer_(
context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}

// NOTE(review): this span looks like a rendered diff. The next two lines
// appear to be the removed form of the constructor header (it declares a
// template parameter `Block` that none of its arguments use); the added form
// follows below.
template <typename Block>
ParamsBuffer(Context* context_p, const VkDeviceSize nbytes)
// The last bool argument, though unused, is required to disambiguate this
// constructor from the one above.
// Stores `context_p` and creates the backing buffer through
// `context_p_->adapter_ptr()->vma().create_uniform_buffer(nbytes)`, so the
// buffer is sized by `nbytes` rather than by a parameter block.
ParamsBuffer(Context* context_p, const VkDeviceSize nbytes, const bool unused)
: context_p_(context_p),
vulkan_buffer_(
context_p_->adapter_ptr()->vma().create_uniform_buffer(nbytes)) {}
Expand Down
8 changes: 4 additions & 4 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ utils::GPUMemoryLayout vTensor::estimate_memory_layout() const {

const vkapi::BufferBindInfo vTensor::sizes_ubo() {
if (!uniforms_.buffer()) {
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize);
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true);
}
if (sizes_uniform_offset_ == kUniformOffsetUnset) {
VK_CHECK_COND(
Expand All @@ -674,7 +674,7 @@ const vkapi::BufferBindInfo vTensor::sizes_ubo() {

const vkapi::BufferBindInfo vTensor::strides_ubo() {
if (!uniforms_.buffer()) {
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize);
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true);
}
if (unsqueezed_strides_offset_ == kUniformOffsetUnset) {
VK_CHECK_COND(
Expand All @@ -691,7 +691,7 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {

const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
if (!uniforms_.buffer()) {
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize);
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true);
}
if (logical_limits_uniform_offset_ == kUniformOffsetUnset) {
VK_CHECK_COND(
Expand All @@ -707,7 +707,7 @@ const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {

const vkapi::BufferBindInfo vTensor::numel_ubo() {
if (!uniforms_.buffer()) {
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize);
uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true);
}
if (numel_uniform_offset_ == kUniformOffsetUnset) {
VK_CHECK_COND(
Expand Down
23 changes: 12 additions & 11 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,30 +60,31 @@ void DispatchNode::encode(ComputeGraph* graph) {

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
uint32_t push_constants_offset = 0;

for (const auto& push_constant : push_constants_) {
push_constants_offset += push_constant.write(
push_constants_data.data(),
push_constants_offset,
kMaxPushConstantSize);
}

context->report_shader_dispatch_start(
shader_.kernel_name,
global_workgroup_size_,
local_workgroup_size_,
node_id_);

vkapi::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph, args_, pipeline_barrier, descriptor_set, idx);

bind_params_to_descriptor_set(params_, descriptor_set, idx);

std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
uint32_t push_constants_offset = 0;

for (const auto& push_constant : push_constants_) {
push_constants_offset += push_constant.write(
push_constants_data.data(),
push_constants_offset,
kMaxPushConstantSize);
}
context->register_shader_dispatch(
descriptor_set,
pipeline_barrier,
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ void PrepackNode::encode(ComputeGraph* graph) {

{
vkapi::PipelineBarrier pipeline_barrier{};
vkapi::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
shader_, local_workgroup_size_, spec_vars_, 0u);

uint32_t idx = 0;
bind_tensor_to_descriptor_set(
Expand Down
31 changes: 25 additions & 6 deletions backends/vulkan/runtime/vk_api/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,29 @@ bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {

PipelineLayout::PipelineLayout(
VkDevice device,
VkDescriptorSetLayout descriptor_layout)
VkDescriptorSetLayout descriptor_layout,
const uint32_t push_constants_size)
: device_(device), handle_{VK_NULL_HANDLE} {
// TODO: Enable push constants
VkPushConstantRange pc_range{
VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags
0u, // offset
push_constants_size, // size
};
uint32_t num_push_constants = 0u;
VkPushConstantRange* pc_ranges_ptr = nullptr;
if (push_constants_size > 0u) {
num_push_constants = 1u;
pc_ranges_ptr = &pc_range;
}

const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
nullptr, // pNext
0u, // flags
1u, // setLayoutCount
&descriptor_layout, // pSetLayouts
0u, // pushConstantRangeCount
nullptr, // pPushConstantRanges
num_push_constants, // pushConstantRangeCount
pc_ranges_ptr, // pPushConstantRanges
};

VK_CHECK(vkCreatePipelineLayout(
Expand Down Expand Up @@ -344,12 +356,19 @@ PipelineLayoutCache::~PipelineLayoutCache() {
}

VkPipelineLayout PipelineLayoutCache::retrieve(
const PipelineLayoutCache::Key& key) {
const VkDescriptorSetLayout layout,
const uint32_t push_constants_size) {
PipelineLayoutCache::Key key{layout, push_constants_size};
std::lock_guard<std::mutex> lock(cache_mutex_);

auto it = cache_.find(key);
if (cache_.cend() == it) {
it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first;
it = cache_
.insert(
{key,
PipelineLayoutCache::Value(
device_, layout, push_constants_size)})
.first;
}

return it->second.handle();
Expand Down
16 changes: 10 additions & 6 deletions backends/vulkan/runtime/vk_api/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags);

class PipelineLayout final {
public:
explicit PipelineLayout(VkDevice, VkDescriptorSetLayout);
explicit PipelineLayout(VkDevice, VkDescriptorSetLayout, const uint32_t);

PipelineLayout(const PipelineLayout&) = delete;
PipelineLayout& operator=(const PipelineLayout&) = delete;
Expand Down Expand Up @@ -193,13 +193,17 @@ class PipelineLayoutCache final {
PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete;

~PipelineLayoutCache();

using Key = VkDescriptorSetLayout;
using Key = std::pair<VkDescriptorSetLayout, uint32_t>;
using Value = PipelineLayout;

// Hash functor supplied as the hasher of the `cache_` unordered_map.
struct Hasher {
// NOTE(review): this span looks like a rendered diff. The next two lines
// appear to be the removed overload, which hashed a bare
// VkDescriptorSetLayout; the added overload for the pair key follows.
inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const {
return std::hash<VkDescriptorSetLayout>()(descriptor_layout);
// Hashes a (descriptor set layout, uint32_t) pair.
inline size_t operator()(
std::pair<VkDescriptorSetLayout, uint32_t> key) const {
// Start from a zero seed and fold in the std::hash of each pair member in
// turn via utils::hash_combine, so both members contribute to the result.
size_t seed = 0;
seed = utils::hash_combine(
seed, std::hash<VkDescriptorSetLayout>()(key.first));
seed = utils::hash_combine(seed, std::hash<uint32_t>()(key.second));
return seed;
}
};

Expand All @@ -212,7 +216,7 @@ class PipelineLayoutCache final {
std::unordered_map<Key, Value, Hasher> cache_;

public:
VkPipelineLayout retrieve(const Key&);
VkPipelineLayout retrieve(const VkDescriptorSetLayout, const uint32_t);
void purge();
};

Expand Down