Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,17 @@ void vTensorStorage::transition(
last_access_.access = cur_access;
}

/*
 * Reports whether this storage is a copy of `other`.
 *
 * Returns false unless both storages use buffer storage; in that case the
 * answer is delegated to the underlying buffer's own is_copy_of() check.
 */
bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
  // Copies are only enabled for buffer storage at the moment, so both sides
  // must be buffer-backed before the buffers themselves are compared.
  const bool both_use_buffers = storage_type_ == utils::kBuffer &&
      other.storage_type_ == utils::kBuffer;
  return both_use_buffers && buffer_.is_copy_of(other.buffer_);
}

void vTensorStorage::discard_and_reallocate(
const std::vector<int64_t>& padded_sizes,
const utils::GPUMemoryLayout gpu_memory_layout,
Expand Down
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/api/containers/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ class vTensorStorage final {
return image_.format();
}

/*
 * Used for checking if this vTensorStorage is a copy of another instance.
 * Returns true only when this storage is a copy of `other`; the definition in
 * Tensor.cpp returns false whenever either storage is not buffer-backed.
 */
bool is_copy_of(const vTensorStorage& other) const;

void discard_and_reallocate(
const std::vector<int64_t>& padded_sizes,
const utils::GPUMemoryLayout gpu_memory_layout,
Expand Down Expand Up @@ -458,6 +463,13 @@ class vTensor final {
* tensor sizes
*/
void reallocate(const std::vector<int64_t>& new_sizes);

/*
 * Returns true if this vTensor is a view of `other`, which is decided by
 * asking whether this tensor's storage is a copy of `other`'s storage.
 */
inline bool is_view_of(const vTensor& other) const {
  const bool storage_is_copy = storage_.is_copy_of(other.storage_);
  return storage_is_copy;
}
};

} // namespace api
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
}

/*
 * Convenience overload: derives the global work group size for the value at
 * `idx` via create_global_wg_size() and forwards it to the utils::uvec3
 * overload of create_local_wg_size().
 *
 * The global work group size is used (rather than the image extents) as the
 * basis for the local work group size.
 */
utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
  return create_local_wg_size(create_global_wg_size(idx));
}

void ComputeGraph::copy_into_staging(
Expand Down
22 changes: 22 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,21 @@ class ComputeGraph final {

std::vector<int64_t> sizes_of(const ValueRef idx) const;

/*
* Returns the size of the tensor at `idx` along the specified dimension.
* Negative indexing is allowed.
*/
template <typename T>
T size_at(const int64_t dim, const ValueRef idx) const {
const Value& val = values_.at(idx);
if (val.isTensor()) {
return static_cast<T>(utils::val_at(dim, val.toConstTensor().sizes()));
} else if (val.isTensorRef()) {
return static_cast<T>(utils::val_at(dim, val.toConstTensorRef().sizes));
}
VK_THROW("Could not get sizes of value with type ", val.type());
}

vkapi::ScalarType dtype_of(const ValueRef idx) const;

inline utils::uvec3 image_extents_of(const ValueRef idx) const {
Expand All @@ -204,6 +219,13 @@ class ComputeGraph final {
return values_.at(idx).toConstTensor().has_buffer_storage();
}

/*
 * Returns whether the tensor stored at `maybe_view` is a view of the tensor
 * stored at `base`. Both values are accessed as const tensors.
 */
inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
    const {
  const auto& candidate = values_.at(maybe_view).toConstTensor();
  const auto& base_tensor = values_.at(base).toConstTensor();
  return candidate.is_view_of(base_tensor);
}

/*
 * Returns the GPU memory layout of the tensor stored at `idx`.
 */
inline utils::GPUMemoryLayout memory_layout_of(const ValueRef idx) const {
  const auto& tensor = values_.at(idx).toConstTensor();
  return tensor.gpu_memory_layout();
}
Expand Down
60 changes: 60 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

// Tensor bindings: the output is declared with access "w" and the two matmul
// operands with access "r"; all three use "buffer" storage.
${layout_declare_tensor(0, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, "buffer")}
${layout_declare_tensor(2, "r", "t_mat2", DTYPE, "buffer")}
// Sizes and strides of each tensor, each declared as an ivec4 UBO. They are
// used by main() to bounds-check the output index and to compute linear
// buffer indices.
${layout_declare_ubo(3, "ivec4", "out_sizes")}
${layout_declare_ubo(4, "ivec4", "out_strides")}
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
${layout_declare_ubo(6, "ivec4", "mat1_strides")}
${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
// NOTE(review): out_numel is declared here but main() in this shader does not
// read it; confirm whether the binding is still needed.
${layout_declare_ubo(9, "int", "out_numel")}

#include "indexing_utils.h"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Each invocation computes one element of the output: the dot product of a
 * line of t_mat1 (walked along its x axis) with a line of t_mat2 (walked
 * along its y axis), addressed through the tensors' strides.
 */
void main() {
  // The global z coordinate carries both the z and w components of the
  // 4-component output index; split it using the output's z size.
  const ivec4 out_pos = ivec4(
      gl_GlobalInvocationID.x,
      gl_GlobalInvocationID.y,
      gl_GlobalInvocationID.z % out_sizes.z,
      gl_GlobalInvocationID.z / out_sizes.z);

  // Invocations that fall outside the output tensor have nothing to do.
  if (any(greaterThanEqual(out_pos, out_sizes))) {
    return;
  }

  // Linear buffer indices of the first element of each operand's line.
  int lhs_id = to_buffer_id(
      ivec4(0, out_pos.y, out_pos.z, out_pos.w), mat1_strides);
  int rhs_id = to_buffer_id(
      ivec4(out_pos.x, 0, out_pos.z, out_pos.w), mat2_strides);

  // Accumulate the dot product, stepping each operand by its own stride.
  T acc = T(0.0);
  for (int k = 0; k < mat1_sizes.x; ++k) {
    acc += t_mat1[lhs_id] * t_mat2[rhs_id];

    lhs_id += mat1_strides.x;
    rhs_id += mat2_strides.y;
  }

  t_out[to_buffer_id(out_pos, out_strides)] = T(acc);
}
16 changes: 16 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

matmul_naive_buffer:
parameter_names_with_default_values:
DTYPE: float
STORAGE: buffer
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: matmul_naive_buffer
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,11 @@ $if MAT2_IS_TRANSPOSED:
#include "indexing_utils.h"
#include "matmul.h"

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out;
layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1;
layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2;

layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits {
ivec3 out_limits;
};

layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
ivec4 in_sizes;
};
${layout_declare_tensor(0, "w", "im_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "im_mat1", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "im_mat2", DTYPE, "texture3d")}
${layout_declare_ubo(3, "ivec3", "out_limits")}
${layout_declare_ubo(4, "ivec4", "in_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

matmul_naive:
matmul_naive_texture3d:
parameter_names_with_default_values:
DTYPE: float
NDIM: 3
STORAGE: texture3d
MAT1_PACKING: W_packed
MAT2_PACKING: H_packed
MAT2_IS_TRANSPOSED: false
Expand All @@ -16,9 +16,9 @@ matmul_naive:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: matmul_naive_W_packed_H_packed
- NAME: matmul_naive_W_packed_W_packed
- NAME: matmul_naive_texture3d_W_packed_H_packed
- NAME: matmul_naive_texture3d_W_packed_W_packed
MAT2_PACKING: W_packed
- NAME: matmul_transposed_naive_W_packed_W_packed
- NAME: matmul_transposed_naive_texture3d_W_packed_W_packed
MAT2_PACKING: W_packed
MAT2_IS_TRANSPOSED: true
54 changes: 50 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,48 @@ void resize_matmul_node(
out->virtual_resize(new_out_sizes);
}

void add_matmul_naive_node(
/*
 * Adds an ExecuteNode to `graph` that dispatches the "matmul_naive_buffer"
 * shader (with a dtype suffix derived from `out`) to compute `out` from
 * `mat1` and `mat2_data`.
 *
 * `mat2_data` is passed through prepack_if_tensor_ref() with a height packed
 * layout before being bound as a shader input.
 *
 * NOTE(review): `mat2_is_transposed` is only forwarded to the resize
 * arguments here; it is not passed to the shader - confirm that transposed
 * inputs are handled for this path.
 */
void add_matmul_naive_buffer_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  const ValueRef mat2 =
      prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);

  std::string shader_name = "matmul_naive_buffer";
  add_dtype_suffix(shader_name, graph.dtype_of(out));

  // One invocation per output element: the last two dims map to x and y, and
  // the third- and fourth-from-last dims are folded together into z.
  const uint32_t out_dim1 = graph.size_at<uint32_t>(-1, out);
  const uint32_t out_dim2 = graph.size_at<uint32_t>(-2, out);
  const uint32_t out_dim3 = graph.size_at<uint32_t>(-3, out);
  const uint32_t out_dim4 = graph.size_at<uint32_t>(-4, out);
  const utils::uvec3 global_wg_size = {out_dim1, out_dim2, out_dim3 * out_dim4};

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      global_wg_size,
      graph.create_local_wg_size(global_wg_size),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers; the order must match the UBO declarations in
      // the shader.
      {
          graph.sizes_ubo(out),
          graph.strides_ubo(out),
          graph.sizes_ubo(mat1),
          graph.strides_ubo(mat1),
          graph.sizes_ubo(mat2),
          graph.strides_ubo(mat2),
          graph.numel_ubo(out),
      },
      // Specialization Constants
      {},
      // Resizing Logic
      resize_matmul_node,
      {mat2_is_transposed}));
}

void add_matmul_naive_texture3d_node(
ComputeGraph& graph,
const ValueRef mat1,
const ValueRef mat2_data,
Expand All @@ -74,6 +115,7 @@ void add_matmul_naive_node(
? "matmul_transposed_naive"
: "matmul_naive";
kernel_name.reserve(kShaderNameReserve);
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1));
add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat2));
add_dtype_suffix(kernel_name, graph.dtype_of(out));
Expand Down Expand Up @@ -174,12 +216,16 @@ void add_matmul_node(
const ValueRef mat2_data,
const ValueRef out,
const ValueRef mat2_is_transposed) {
if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) {
if (graph.is_buffer_storage(out)) {
add_matmul_naive_buffer_node(
graph, mat1, mat2_data, out, mat2_is_transposed);
} else if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) {
add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed);
} else if (graph.memory_layout_of(mat1) == utils::kWidthPacked) {
add_matmul_naive_node(graph, mat1, mat2_data, out, mat2_is_transposed);
add_matmul_naive_texture3d_node(
graph, mat1, mat2_data, out, mat2_is_transposed);
} else {
VK_THROW("Input should be channel packed or width packed.");
VK_THROW("Input texture should be channel packed or width packed.");
}
}

Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/vk_api/memory/Buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ class VulkanBuffer final {
return (handle_ != VK_NULL_HANDLE);
}

/*
 * Returns true when this buffer is flagged as a copy and refers to the same
 * underlying handle as `other`.
 */
inline bool is_copy_of(const VulkanBuffer& other) const {
  if (!is_copy_) {
    return false;
  }
  return handle_ == other.handle_;
}

inline void bind_allocation(const Allocation& memory) {
VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!");
VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_));
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def get_mm_inputs():
test_suite.prepacked_args = ["mat2"]
# ATen matmul doesn't support half
test_suite.dtypes = ["at::kFloat"]
test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
test_suite.layouts = [
"utils::kWidthPacked",
"utils::kChannelsPacked",
Expand Down
Loading