Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,7 @@ void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) {
const int dim1_whcn = sizes_.size() - 1 - dim1;
if (packed_dim_ == dim0_whcn) {
packed_dim_ = dim1_whcn;
}
if (packed_dim_ == dim1_whcn) {
} else if (packed_dim_ == dim1_whcn) {
packed_dim_ = dim0_whcn;
}

Expand All @@ -719,6 +718,12 @@ void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) {
VK_CHECK_COND(dim0_whcn < 3 && dim1_whcn < 3);
std::iter_swap(
axis_map_.begin() + dim0_whcn, axis_map_.begin() + dim1_whcn);
// Update the "identity" of the concatted dimension
if (axis_map_.at(3) == dim0_whcn) {
axis_map_.at(3) = dim1_whcn;
} else if (axis_map_.at(3) == dim1_whcn) {
axis_map_.at(3) = dim0_whcn;
}
}
update_metadata();
}
Expand Down
26 changes: 26 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,32 @@ std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
VK_THROW("Could not get sizes of value with type ", val.type());
}

int64_t ComputeGraph::dim_of(const ValueRef idx) const {
const Value& val = values_.at(idx);
if (val.isTensor()) {
return val.toConstTensor().dim();
} else if (val.isTensorRef()) {
return val.toConstTensorRef().sizes.size();
}
VK_THROW("Could not get dim of value with type ", val.type());
}

std::vector<int64_t> ComputeGraph::dim_order_of(const ValueRef idx) const {
const Value& val = values_.at(idx);
if (val.isTensor()) {
return val.toConstTensor().dim_order();
}
VK_THROW("Could not get dim order of value with type ", val.type());
}

std::vector<int64_t> ComputeGraph::strides_of(const ValueRef idx) const {
const Value& val = values_.at(idx);
if (val.isTensor()) {
return val.toConstTensor().strides();
}
VK_THROW("Could not get strides of value with type ", val.type());
}

vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
const Value& val = values_.at(idx);
if (val.isTensor()) {
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ class ComputeGraph final {
VK_THROW("Could not get sizes of value with type ", val.type());
}

// Returns the number of dimensions of the value at `idx`. Per the definition
// in ComputeGraph.cpp, tensors report dim() and tensor refs report the length
// of their sizes list; any other value type raises via VK_THROW.
int64_t dim_of(const ValueRef idx) const;

// Returns the dim order of the tensor at `idx`. Per the definition in
// ComputeGraph.cpp, only tensor values are accepted; any other value type
// raises via VK_THROW.
std::vector<int64_t> dim_order_of(const ValueRef idx) const;

// Returns the strides of the tensor at `idx`. Per the definition in
// ComputeGraph.cpp, only tensor values are accepted; any other value type
// raises via VK_THROW.
std::vector<int64_t> strides_of(const ValueRef idx) const;

vkapi::ScalarType dtype_of(const ValueRef idx) const;

inline const utils::ivec3& logical_limits_of(const ValueRef idx) const {
Expand Down
85 changes: 85 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Transpose.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Transpose.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <algorithm>

namespace vkcompute {

// Resize callback for the transpose view node.
//
// Recomputes the output sizes from the input tensor's current sizes by
// swapping the two transposed dimensions, then applies them to the output
// through virtual_resize().
//
// extra_args layout: [out tensor, in tensor, dim0 scalar, dim1 scalar].
// `args` is not used by this callback.
void resize_transpose_view_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;

  vTensorPtr out_tensor = graph->get_tensor(extra_args[0]);
  vTensorPtr in_tensor = graph->get_tensor(extra_args[1]);

  const int64_t first_dim = graph->extract_scalar<int64_t>(extra_args[2]);
  const int64_t second_dim = graph->extract_scalar<int64_t>(extra_args[3]);

  // Start from a copy of the (possibly resized) input sizes and exchange the
  // two entries that the transpose swaps.
  std::vector<int64_t> transposed_sizes = in_tensor->sizes();
  const auto sizes_begin = transposed_sizes.begin();
  std::iter_swap(sizes_begin + first_dim, sizes_begin + second_dim);

  out_tensor->virtual_resize(transposed_sizes);
}

// Validates the arguments of a transpose view before the node is added.
//
// Checks that:
//   - `out_ref` is a view of `in_ref` (as reported by graph.val_is_view_of)
//   - `dim0` and `dim1` both lie in [0, in_ndim), where in_ndim is the
//     number of dimensions of the input as reported by graph.dim_of
//
// Negative dimension indices are rejected rather than wrapped around.
// Failures are reported through VK_CHECK_COND.
void check_transpose_view_args(
ComputeGraph& graph,
ValueRef in_ref,
const int64_t dim0,
const int64_t dim1,
ValueRef out_ref) {
// The output must alias the input; this op does not produce a new tensor.
VK_CHECK_COND(
graph.val_is_view_of(out_ref, in_ref),
"output tensor must be a view of the input tensor");

// Both dims are bounds-checked against the input's dimensionality.
const int64_t in_ndim = graph.dim_of(in_ref);
VK_CHECK_COND(
dim0 >= 0 && dim0 < in_ndim, "dim0 is not in the range of [0, in_ndim)");
VK_CHECK_COND(
dim1 >= 0 && dim1 < in_ndim, "dim1 is not in the range of [0, in_ndim)");
}

// Adds a transpose to the graph as a view operation.
//
// The two dimension scalars are read from the graph, the arguments are
// validated, and the output tensor is transposed in place through
// virtual_transpose(). An ExecuteNode carrying only the resize callback is
// then appended so the output sizes are recomputed when the graph is resized.
void add_transpose_view_node(
    ComputeGraph& graph,
    ValueRef input_ref,
    ValueRef dim0_ref,
    ValueRef dim1_ref,
    ValueRef out_ref) {
  const int64_t first_dim = graph.extract_scalar<int64_t>(dim0_ref);
  const int64_t second_dim = graph.extract_scalar<int64_t>(dim1_ref);

  // Validate before touching the output tensor's metadata.
  check_transpose_view_args(graph, input_ref, first_dim, second_dim, out_ref);

  graph.get_tensor(out_ref)->virtual_transpose(first_dim, second_dim);

  // Resize args layout expected by resize_transpose_view_node:
  // [out, in, dim0, dim1].
  graph.execute_nodes().emplace_back(new ExecuteNode(
      resize_transpose_view_node, {out_ref, input_ref, dim0_ref, dim1_ref}));
}

// Operator entry point for transpose.
// Unpacks `args`, laid out as [input, dim0, dim1, out], and forwards them to
// add_transpose_view_node().
void transpose(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef input = args[0];
  const ValueRef dim0 = args[1];
  const ValueRef dim1 = args[2];
  const ValueRef out = args[3];

  add_transpose_view_node(graph, input, dim0, dim1, out);
}

// Registers `transpose` under the operator name aten.transpose.int, the name
// callers use to look it up (e.g. VK_GET_OP_FN("aten.transpose.int")).
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.transpose.int, transpose);
}

} // namespace vkcompute
26 changes: 26 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <vector>

namespace vkcompute {

// Adds a transpose to `graph` as a view operation.
//
//   input_ref - the tensor being transposed
//   dim0_ref  - scalar value holding the first dimension to swap
//   dim1_ref  - scalar value holding the second dimension to swap
//   out_ref   - the output tensor; must be a view of `input_ref`
//
// Both dimensions must lie in [0, ndim) of the input. The implementation
// transposes `out_ref` through virtual_transpose() and appends an ExecuteNode
// whose resize callback recomputes the output sizes from the input sizes.
void add_transpose_view_node(
ComputeGraph& graph,
ValueRef input_ref,
ValueRef dim0_ref,
ValueRef dim1_ref,
ValueRef out_ref);

} // namespace vkcompute
25 changes: 25 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,31 @@ def get_slice_inputs():
return test_suite


@register_test_suite(["aten.transpose.int"])
def get_transpose_inputs():
    """Build the test suite for aten.transpose.int, marked as a view op.

    Each case is an (input sizes, dim0, dim1) triple. The suite covers 2D, 3D
    and 4D inputs for float tensors, with buffer and 3D-texture storage and
    width-packed and channels-packed layouts.
    """
    Test = namedtuple("VkTransposeViewTest", ["self", "dim0", "dim1"])
    Test.__new__.__defaults__ = (None, 0, 1)

    # (input sizes, dim0, dim1)
    sizes_and_dims = [
        ([M1, M2], 0, 1),
        ([M1, S2, M], 0, 1),
        ([M1, S2, M], 0, 2),
        ([M1, S2, M], 2, 1),
        ([S, M, S2, M2], 3, 2),
        ([S, M, S2, M2], 1, 2),
        ([S, M, S2, M2], 3, 1),
    ]

    suite = VkTestSuite(
        [
            tuple(Test(self=sizes, dim0=first, dim1=second))
            for sizes, first, second in sizes_and_dims
        ]
    )

    suite.dtypes = ["at::kFloat"]
    suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    # The output is created as a view of the input rather than a new tensor.
    suite.is_view_op = True
    return suite


@register_test_suite("aten.index_select.default")
def get_index_select_inputs():
Test = namedtuple("VkIndexSelectTest", ["self", "dim", "index"])
Expand Down
11 changes: 10 additions & 1 deletion backends/vulkan/test/op_tests/utils/gen_computegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def create_value_for( # noqa: C901
return ret_str

prepack = self.prepack_ref(ref)
ref_is_view = self.suite_def.is_view_op and ref.is_out

cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
if not include_declarations:
Expand Down Expand Up @@ -362,7 +363,15 @@ def create_value_for( # noqa: C901
ret_str = f"IOValueRef {ref.name};\n"
ret_str += f"{ref.name}.value = {self.graph}{self.dot}"

if ref.src_cpp_type == AT_TENSOR and not prepack:
if ref.src_cpp_type == AT_TENSOR and ref_is_view:
input_name = None
for _name, ref in self.refs.items():
if ref.is_in and ref.src_cpp_type == AT_TENSOR:
input_name = ref.name

assert input_name is not None
ret_str += f"add_tensor_view({input_name}.value);"
elif ref.src_cpp_type == AT_TENSOR and not prepack:
ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
ret_str += f"{ref.src_cpp_name}.sizes().vec(), "
ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n"
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/test/op_tests/utils/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self, input_cases: List[Any]):
self.atol: str = "1e-5"
self.rtol: str = "1e-5"

self.is_view_op: bool = False

def supports_prepack(self):
return len(self.prepacked_args) > 0

Expand Down
95 changes: 93 additions & 2 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ TEST_F(VulkanComputeAPITest, virtual_transpose_test) {
// (dim0, dim1), new_sizes, new_dim_order, new_axis_map, new_packed_dim_idx
std::vector<std::vector<std::vector<int64_t>>> test_cases = {
{{2, 3}, {7, 9, 13, 11}, {0, 1, 3, 2}, {1, 0, 2, 2}, {1}},
{{2, 1}, {7, 11, 9, 13}, {0, 2, 1, 3}, {0, 2, 1, 2}, {0}},
{{1, 3}, {7, 13, 11, 9}, {0, 3, 2, 1}, {2, 1, 0, 2}, {2}},
{{2, 1}, {7, 11, 9, 13}, {0, 2, 1, 3}, {0, 2, 1, 1}, {0}},
{{1, 3}, {7, 13, 11, 9}, {0, 3, 2, 1}, {2, 1, 0, 0}, {2}},
};

for (const auto& test_case : test_cases) {
Expand Down Expand Up @@ -3039,3 +3039,94 @@ TEST(VulkanComputeGraphOpsTest, int4pack_mm_test) {
test_int4pack_mm({37, 256, 19}, 64, storage_type);
}
}

void test_transpose_view_mm(
const int B,
const int M,
const int K,
const int N,
utils::StorageType storage_type) {
GraphConfig config;
config.set_storage_type_override(storage_type);
ComputeGraph graph(config);

std::vector<int64_t> mat1_size = {M, K};
std::vector<int64_t> mat2_t_size = {N, K};
std::vector<int64_t> out_size = {M, N};

std::vector<int64_t> mat1_small_size = {M - 4, K - 3};
std::vector<int64_t> mat2_t_small_size = {N - 1, K - 3};

if (B > 1) {
mat1_size.resize(3);
mat1_size = {B, M, K};
mat2_t_size.resize(3);
mat2_t_size = {B, N, K};
out_size.resize(3);
out_size = {B, M, N};

mat1_small_size.resize(3);
mat1_small_size = {B, M - 4, K - 3};
mat2_t_small_size.resize(3);
mat2_t_small_size = {B, N - 1, K - 3};
}

// Build graph

IOValueRef mat1 =
graph.add_input_tensor(mat1_size, vkapi::kFloat, utils::kWidthPacked);
IOValueRef mat2_transpose =
graph.add_input_tensor(mat2_t_size, vkapi::kFloat, utils::kWidthPacked);

ValueRef mat2 = graph.add_tensor_view(mat2_transpose.value);

ValueRef dim0;
ValueRef dim1;

if (B > 1) {
dim0 = graph.add_scalar<int64_t>(1);
dim1 = graph.add_scalar<int64_t>(2);
} else {
dim0 = graph.add_scalar<int64_t>(0);
dim1 = graph.add_scalar<int64_t>(1);
}

IOValueRef out;
out.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kWidthPacked);

VK_GET_OP_FN("aten.transpose.int")
(graph, {mat2_transpose.value, dim0, dim1, mat2});
VK_GET_OP_FN("aten.mm.default")(graph, {mat1.value, mat2, out.value});

out.staging = graph.set_output_tensor(out.value);

graph.prepare();
graph.encode_prepack();
graph.prepack();
graph.encode_execute();

for (int i = 1; i < 4; i++) {
float val_mat1 = i;
float val_mat2 = i + 1;
float val_out = K * (val_mat1 * val_mat2);

// Try at full size
graph.resize_input(0, mat1_size);
graph.resize_input(1, mat2_t_size);
graph.propagate_resize();
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});

// Try at reduced sizes
val_out = (K - 3) * (val_mat1 * val_mat2);
graph.resize_input(0, mat1_small_size);
graph.resize_input(1, mat2_t_small_size);
graph.propagate_resize();
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});
}
}

// Runs the transpose-view + mm test in its batched form for both buffer and
// 3D-texture storage.
TEST(VulkanComputeGraphOpsTest, test_transpose_with_mm) {
  constexpr int kBatch = 2;
  constexpr int kM = 7;
  constexpr int kK = 17;
  constexpr int kN = 5;

  for (auto storage_type : {utils::kBuffer, utils::kTexture3D}) {
    test_transpose_view_mm(kBatch, kM, kK, kN, storage_type);
  }
}
Loading