Skip to content
Merged
66 changes: 66 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,72 @@ VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)

#undef VALUE_PTR_CLASS_IMPL

//
// TmpTensor
//

// Constructs a temporary tensor matching the 5-argument `add_tensor` overload.
// Reserves a shared object index (recycling one from the graph's pool if
// available) and adds a tensor bound to that shared object. Note: `sobj_idx`
// is initialized before `vref` (member declaration order in the header
// guarantees this), so the index is valid when `add_tensor` is called.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(
          sizes,
          dtype,
          storage_type,
          memory_layout,
          sobj_idx)) {}

// Overload matching `add_tensor(sizes, dtype, storage_type,
// shared_object_idx)`; `sobj_idx` is initialized before `vref` per member
// declaration order.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, storage_type, sobj_idx)) {}

// Overload matching `add_tensor(sizes, dtype, memory_layout,
// shared_object_idx)`; `sobj_idx` is initialized before `vref` per member
// declaration order.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, memory_layout, sobj_idx)) {}

// Minimal overload matching `add_tensor(sizes, dtype, shared_object_idx)`;
// `sobj_idx` is initialized before `vref` per member declaration order.
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, sobj_idx)) {}

TmpTensor::~TmpTensor() {
  // The temporary tensor's lifetime has expired; nothing to recycle if the
  // shared object index is invalid.
  if (sobj_idx < 0) {
    return;
  }
  // Return the shared object index to the graph's pool so that subsequent
  // TmpTensor instances can reuse the underlying memory.
  graph_p->tmp_shared_object_idxs_.emplace(sobj_idx);
}

// Returns the first recycled shared object index if one is available;
// otherwise returns the index one past the end of `shared_objects_`, which
// requests a new shared object to be created.
int64_t TmpTensor::get_sobj_idx() {
  if (!graph_p->tmp_shared_object_idxs_.empty()) {
    // Reuse the most recently returned shared object index (LIFO order).
    const int64_t recycled_idx = graph_p->tmp_shared_object_idxs_.top();
    graph_p->tmp_shared_object_idxs_.pop();
    return recycled_idx;
  }
  // No recycled indices available; request a new shared object.
  return static_cast<int64_t>(graph_p->shared_objects_.size());
}

//
// ComputeGraph
//
Expand Down
81 changes: 81 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#include <optional>
#include <stack>

#include <executorch/backends/vulkan/runtime/api/api.h>

Expand Down Expand Up @@ -67,6 +68,79 @@ DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt);

#undef DECL_VALUE_PTR_CLASS

//
// TmpTensor
//

/*
 * This struct is used to recycle the memory of temporary tensors that are
 * created during the execution of a node. Upon construction, this struct will
 * check the `tmp_shared_object_idxs_` of the provided `ComputeGraph` instance
 * if any shared objects are available; if not, then a new one is created. A
 * tensor value is then added to the `ComputeGraph` instance with the requested
 * specifications. Upon destruction, the shared object index of the temporary
 * tensor is returned to `tmp_shared_object_idxs_`.
 *
 * Note that instances of this struct can be used as if they were `ValueRef`
 * due to the implementation of a custom casting operator.
 *
 * This class should only be used to create tensors whose lifetimes exist only
 * in a well-defined scope (i.e. within a function).
 */
struct TmpTensor {
  // The graph that owns the temporary tensor; must outlive this instance.
  ComputeGraph* graph_p;
  // Index of the shared object backing the tensor's memory; returned to the
  // graph's recycling pool upon destruction.
  int64_t sobj_idx;
  // Reference to the tensor value added to the graph.
  ValueRef vref;

  //
  // Match all available overloads of `add_tensor`
  //

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype,
      const utils::StorageType storage_type,
      const utils::GPUMemoryLayout memory_layout);

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype,
      const utils::StorageType storage_type);

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype,
      const utils::GPUMemoryLayout memory_layout);

  TmpTensor(
      ComputeGraph* const graph_ptr,
      const std::vector<int64_t>& sizes,
      const vkapi::ScalarType dtype);

  // No copy construction or assignment; copying would cause the shared object
  // index to be returned to the pool twice.
  TmpTensor(const TmpTensor& other) = delete;
  TmpTensor& operator=(const TmpTensor& other) = delete;

  // No move construction or assignment
  TmpTensor(TmpTensor&& other) = delete;
  TmpTensor& operator=(TmpTensor&& other) = delete;

  // Custom cast to ValueRef so instances can be passed directly to APIs that
  // accept a ValueRef
  operator ValueRef() const {
    return vref;
  }

  ~TmpTensor();

 private:
  // Helper function to get first available shared object index or request a
  // new one to be created.
  int64_t get_sobj_idx();
};

//
// ComputeGraph
//
Expand Down Expand Up @@ -94,7 +168,12 @@ class ComputeGraph final {
vkapi::DescriptorPoolConfig execute_descriptor_counts_;

std::unique_ptr<api::Context> context_;

std::vector<SharedObject> shared_objects_;
// This stack is used by `TmpTensor` instances to recycle shared objects
// for temporary tensors. See the comments of `TmpTensor` for more details
std::stack<int64_t> tmp_shared_object_idxs_;

std::vector<Value> values_;
std::vector<api::ParamsBuffer> param_ubos_;

Expand Down Expand Up @@ -593,6 +672,8 @@ class ComputeGraph final {
friend class BoolListPtr;
friend class ValueListPtr;
friend class SymIntPtr;

friend struct TmpTensor;
};

template <typename T>
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/containers/Value.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ inline bool is_valid(ValueRef value_ref) {
// Bundles a graph tensor value with its associated staging value so both can
// be passed around together.
struct IOValueRef {
  // Reference to the tensor value in the graph.
  ValueRef value;
  // Reference to the associated staging value — presumably the host-visible
  // buffer used to transfer data to/from `value`; confirm against
  // add_input_tensor / set_output_tensor call sites.
  ValueRef staging;

  // Custom cast to ValueRef; allows an IOValueRef to be used anywhere a plain
  // ValueRef is expected, resolving to the tensor value (not the staging one).
  operator ValueRef() const {
    return value;
  };
};

/*
Expand Down
99 changes: 99 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,105 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
}
}

// Verifies that TmpTensor instances recycle shared object indices: temporaries
// in a scope return their indices to the pool on destruction, and temporaries
// in a subsequent scope reuse those same indices. Also checks the computed
// output values end-to-end.
TEST(VulkanComputeGraphTest, test_simple_graph_with_tmp_tensors) {
  GraphConfig config;
  ComputeGraph graph(config);

  std::vector<int64_t> size_big = {8, 64, 124};
  std::vector<int64_t> size_small = {8, 1, 124};

  // Build graph

  IOValueRef a = graph.add_input_tensor(
      size_big, vkapi::kFloat, /*shared_object_idx = */ 0);
  IOValueRef b = graph.add_input_tensor(
      size_small, vkapi::kFloat, /*shared_object_idx = */ 1);

  IOValueRef out = {};

  out.value =
      graph.add_tensor(size_big, vkapi::kFloat, /*shared_object_idx = */ 2);

  // Perform the following compute
  //
  // a, b, out;
  // {
  //   inter;
  //   {
  //     tmp = a + b
  //     tmp2 = tmp + a
  //     inter = tmp2 + b
  //   }
  //   {
  //     tmp = inter + b;
  //     tmp2 = tmp + a
  //     out = tmp2 + b;
  //   }
  // }
  {
    // Shared object indices 0-2 are taken by a, b, and out above, so the
    // first temporary is assigned a newly created shared object at index 3.
    TmpTensor inter(&graph, size_big, vkapi::kFloat);
    EXPECT_TRUE(inter.sobj_idx == 3);
    {
      // No recycled indices are available yet; new shared objects 4 and 5 are
      // created for the two temporaries in this scope.
      TmpTensor tmp(&graph, size_big, vkapi::kFloat);
      EXPECT_TRUE(tmp.sobj_idx == 4);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {a, b, kDummyValueRef, tmp});

      TmpTensor tmp2(&graph, size_big, vkapi::kFloat);
      EXPECT_TRUE(tmp2.sobj_idx == 5);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp, a, kDummyValueRef, tmp2});

      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp2, b, kDummyValueRef, inter});
    }
    {
      // The previous scope's temporaries were destroyed in reverse
      // construction order (tmp2 pushed 5, then tmp pushed 4), so index 4 is
      // on top of the recycling stack and is reused first, followed by 5.
      TmpTensor tmp(&graph, size_big, vkapi::kFloat);
      EXPECT_TRUE(tmp.sobj_idx == 4);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {inter, b, kDummyValueRef, tmp});

      TmpTensor tmp2(&graph, size_big, vkapi::kFloat);
      EXPECT_TRUE(tmp2.sobj_idx == 5);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp, a, kDummyValueRef, tmp2});

      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp2, b, kDummyValueRef, out});
    }
  }

  out.staging = graph.set_output_tensor(out.value);

  graph.prepare();
  graph.encode_execute();

  // Run graph

  for (float i = 5.0f; i < 30.0f; i += 10.0f) {
    // Mirror the chain of adds above on the host to derive the expected
    // output value.
    float val_a = i + 2.0f;
    float val_b = i + 1.5f;
    float val_tmp = val_a + val_b;
    float val_tmp2 = val_tmp + val_a;
    float val_inter = val_tmp2 + val_b;
    float val_tmp_2 = val_inter + val_b;
    float val_tmp2_2 = val_tmp_2 + val_a;
    float val_out = val_tmp2_2 + val_b;

    fill_vtensor(graph, a, val_a);
    fill_vtensor(graph, b, val_b);

    graph.execute();

    EXTRACT_TENSOR(out);

    // Sanity check that the values are correct
    for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) {
      CHECK_VALUE(data_out, i, val_out);
    }
  }
}

TEST(VulkanComputeGraphTest, test_large_graph) {
auto build_start_time = std::chrono::system_clock::now();
GraphConfig config;
Expand Down
Loading