Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
# Basic configuration settings for shaders
DEFAULT_ENV: Dict[str, Any] = {
"PRECISION": "highp",
# B is shorthand for "binding". This is used to automatically increment the
# layout binding index when declaring layout bindings. Note that a container
# type is used because integers are immutable in Python.
"B": [0],
}

# Establishes relationships between different tensor types and different GLSL types
Expand Down Expand Up @@ -179,8 +183,14 @@ def get_access_qualifier(access_type: Optional[str]) -> str:
raise AssertionError(f"Invalid access type: {access_type}")


def get_slot_val(slot: Union[int, List[int]]) -> int:
    """Resolve a layout binding slot to its integer value.

    A slot is either a plain int, or a single-element list used as a mutable
    binding counter (see the "B" entry in DEFAULT_ENV); for a list, the
    current counter value is returned.
    """
    return slot[0] if isinstance(slot, list) else slot


def layout_declare_buffer(
slot: int,
slot: Union[int, List[int]],
access_type: str,
var_name: str,
dtype: str,
Expand All @@ -192,15 +202,18 @@ def layout_declare_buffer(
array_type = buffer_scalar_type(dtype)

out_str = f"""
layout(set = 0, binding = {slot}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{
layout(set = 0, binding = {get_slot_val(slot)}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{
{array_type} {var_name}[];
}};
"""

if isinstance(slot, list):
slot[0] = slot[0] + 1
return out_str


def layout_declare_image(
slot: int,
slot: Union[int, List[int]],
access_type: str,
var_name: str,
dtype: str,
Expand All @@ -209,11 +222,16 @@ def layout_declare_image(
) -> str:
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
image_type = TYPE_MAPPINGS["IMAGE_T"][image_ndim][dtype]
return f"layout(set = 0, binding = {slot}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};"

ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};"

if isinstance(slot, list):
slot[0] = slot[0] + 1
return ret_str


def layout_declare_sampler(
slot: int,
slot: Union[int, List[int]],
access_type: str,
var_name: str,
dtype: str,
Expand All @@ -222,11 +240,16 @@ def layout_declare_sampler(
image_ndim: int = 3,
) -> str:
sampler_type = TYPE_MAPPINGS["SAMPLER_T"][image_ndim][dtype]
return f"layout(set = 0, binding = {slot}) uniform {precision} {sampler_type} {var_name};"

ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} {sampler_type} {var_name};"

if isinstance(slot, list):
slot[0] = slot[0] + 1
return ret_str


def layout_declare_tensor(
slot: int,
slot: Union[int, List[int]],
access_type: str,
var_name: str,
dtype: str,
Expand Down Expand Up @@ -262,7 +285,9 @@ def layout_declare_tensor(
)


def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str:
def layout_declare_ubo(
slot: Union[int, List[int]], *args, precision: str = "PRECISION"
) -> str:
assert len(args) % 2 == 0

var_list = list(zip(args[::2], args[1::2]))
Expand All @@ -272,12 +297,14 @@ def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str:
ubo_name += var_name + "_"

out_str = f"""
layout(set = 0, binding = {slot}) uniform {precision} restrict readonly {ubo_name}UBO {{
layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{
"""
for type_name, var_name in var_list:
out_str += f"{type_name} {var_name};\n"
out_str += "};"

if isinstance(slot, list):
slot[0] = slot[0] + 1
return out_str


Expand Down
66 changes: 66 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,72 @@ VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)

#undef VALUE_PTR_CLASS_IMPL

//
// TmpTensor
//

// Constructs a temporary tensor with explicit storage type and memory layout.
// Note: sobj_idx is initialized (recycling a pooled shared object index if one
// is available) before vref, because the add_tensor call that initializes vref
// consumes sobj_idx; this relies on member declaration order in TmpTensor.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(
          sizes,
          dtype,
          storage_type,
          memory_layout,
          sobj_idx)) {}

// Constructs a temporary tensor with an explicit storage type; delegates to
// the add_tensor overload that does not take a memory layout. sobj_idx is
// initialized before vref (declaration order) since add_tensor consumes it.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, storage_type, sobj_idx)) {}

// Constructs a temporary tensor with an explicit memory layout; delegates to
// the add_tensor overload that does not take a storage type. sobj_idx is
// initialized before vref (declaration order) since add_tensor consumes it.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, memory_layout, sobj_idx)) {}

// Constructs a temporary tensor from sizes and dtype only; storage type and
// memory layout are whatever the matching add_tensor overload selects.
// sobj_idx is initialized before vref (declaration order) since add_tensor
// consumes it.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, sobj_idx)) {}

TmpTensor::~TmpTensor() {
  // The temporary tensor's lifetime is over; hand its shared object index
  // back to the graph's recycling pool so later TmpTensor instances can
  // reuse the memory. A negative index means there is nothing to return.
  if (sobj_idx < 0) {
    return;
  }
  graph_p->tmp_shared_object_idxs_.emplace(sobj_idx);
}

int64_t TmpTensor::get_sobj_idx() {
  // Prefer recycling: if a previously released temporary shared object is
  // sitting in the pool, claim it and remove it from the pool.
  if (!graph_p->tmp_shared_object_idxs_.empty()) {
    const int64_t recycled_idx = graph_p->tmp_shared_object_idxs_.top();
    graph_p->tmp_shared_object_idxs_.pop();
    return recycled_idx;
  }
  // Otherwise, the index one past the current end of shared_objects_ requests
  // that a new shared object be created.
  return graph_p->shared_objects_.size();
}

//
// ComputeGraph
//
Expand Down
81 changes: 81 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#include <optional>
#include <stack>

#include <executorch/backends/vulkan/runtime/api/api.h>

Expand Down Expand Up @@ -67,6 +68,79 @@ DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt);

#undef DECL_VALUE_PTR_CLASS

//
// TmpTensor
//

/*
* This struct is used to recycle the memory of temporary tensors that are
* created during the execution of a node. Upon construction, this struct will
* check the `tmp_shared_object_idxs_` of the provided `ComputeGraph` instance
* if any shared objects are available; if not, then a new one is created. A
* tensor value is then added to the `ComputeGraph` instance with the requested
* specifications. Upon destruction, the shared object index of the temporary
* tensor is returned to `tmp_shared_object_idxs_`.
*
* Note that instances of this struct can be used as if they were `ValueRef` due
* to implementation of a custom casting operator.
*
* This class should only be used to create tensors whose lifetimes exist only
* in a well defined scope (i.e. within a function).
*/
struct TmpTensor {
  // Non-owning pointer to the graph that holds the temporary tensor's value.
  ComputeGraph* graph_p;
  // Index of the shared object backing this tensor's memory. Declared before
  // vref on purpose: members initialize in declaration order, and the
  // constructors pass sobj_idx into add_tensor when initializing vref. When
  // non-negative, the destructor returns it to tmp_shared_object_idxs_.
  int64_t sobj_idx;
  // Reference to the tensor value added to the graph on construction.
  ValueRef vref;

  //
  // Match all available overloads of `add_tensor`
  //

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype,
      const utils::StorageType storage_type,
      const utils::GPUMemoryLayout memory_layout);

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype,
      const utils::StorageType storage_type);

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype,
      const utils::GPUMemoryLayout memory_layout);

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype);

  // No copy construction or assignment
  // NOTE(review): presumably deleted so two instances can never share one
  // sobj_idx and double-return it to the pool on destruction — confirm.
  TmpTensor(TmpTensor& other) = delete;
  TmpTensor& operator=(TmpTensor& other) = delete;

  // No move construction or assignment
  TmpTensor(TmpTensor&& other) = delete;
  TmpTensor& operator=(TmpTensor&& other) = delete;

  // Custom cast to ValueRef, so a TmpTensor can be passed wherever a plain
  // ValueRef is expected.
  operator ValueRef() const {
    return vref;
  };

  // Returns sobj_idx to the graph's tmp_shared_object_idxs_ pool (see .cpp).
  ~TmpTensor();

 private:
  // Helper function to get first available shared object index or request a new
  // one to be created.
  int64_t get_sobj_idx();
};

//
// ComputeGraph
//
Expand Down Expand Up @@ -94,7 +168,12 @@ class ComputeGraph final {
vkapi::DescriptorPoolConfig execute_descriptor_counts_;

std::unique_ptr<api::Context> context_;

std::vector<SharedObject> shared_objects_;
// This stack is used by `TmpTensor` instances to recycle shared objects
// for temporary tensors. See the comments of `TmpTensor` for more details
std::stack<int64_t> tmp_shared_object_idxs_;

std::vector<Value> values_;
std::vector<api::ParamsBuffer> param_ubos_;

Expand Down Expand Up @@ -593,6 +672,8 @@ class ComputeGraph final {
friend class BoolListPtr;
friend class ValueListPtr;
friend class SymIntPtr;

friend struct TmpTensor;
};

template <typename T>
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/containers/Value.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ inline bool is_valid(ValueRef value_ref) {
// Pairs a graph value with an associated staging value. The implicit
// conversion yields `value` (not `staging`), so an IOValueRef can be used
// directly wherever a plain ValueRef is expected.
struct IOValueRef {
  ValueRef value;
  // NOTE(review): presumably a staging buffer used to transfer `value` to or
  // from the host — confirm against usage in the graph runtime.
  ValueRef staging;

  // Custom cast to ValueRef
  operator ValueRef() const {
    return value;
  };
};

/*
Expand Down
Loading
Loading