49 changes: 24 additions & 25 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -15,13 +15,13 @@ namespace api {

std::vector<int64_t> calculate_dim_order(
const size_t ndim,
const int32_t packed_dim_whcn_idx) {
const int32_t packed_dim) {
// Special case for zero dim tensors
if (ndim == 0) {
return {0};
}
std::vector<int64_t> dim_order(ndim);
int64_t last_dim = ndim - 1 - packed_dim_whcn_idx;
int64_t last_dim = ndim - 1 - packed_dim;

int64_t cur_dim = 0;
for (int d = 0; d < ndim; ++d) {
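For reference, a hedged illustration of what this helper yields (not part of the diff, and assuming the remainder of the function simply places the other dims in order with `last_dim` moved to the back):

```cpp
// Illustration only: 4-D tensor, packed_dim = 2 (WHCN channels).
// last_dim = 4 - 1 - 2 = 1 (NCHW channels), so the channel dim is placed
// last in the dim order and therefore gets a stride of 1.
std::vector<int64_t> dim_order = calculate_dim_order(/*ndim=*/4, /*packed_dim=*/2);
// dim_order == {0, 2, 3, 1}
```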
@@ -131,7 +131,7 @@ std::vector<int64_t> unsqueeze_strides(

std::vector<int64_t> calculate_padded_sizes(
const std::vector<int64_t>& sizes,
const int32_t packed_dim_whcn_idx) {
const int32_t packed_dim) {
int64_t ndim = sizes.size();
if (ndim == 0) {
ndim = 1;
@@ -145,7 +145,7 @@ std::vector<int64_t> calculate_padded_sizes(
}

// Pad the packed dim to the next multiple of 4.
const int64_t dim_offset = packed_dim_whcn_idx + 1;
const int64_t dim_offset = packed_dim + 1;
const int64_t padded_dim_size = utils::val_at(-dim_offset, sizes);
padded_sizes.at(ndim_up4 - dim_offset) = utils::align_up_4(padded_dim_size);
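A hedged example of the padding rule above (illustration only, not part of the diff): with NCHW sizes and `packed_dim = 0` (WHCN width), `dim_offset` is 1, so only the width entry is rounded up to the next multiple of 4.

```cpp
// Illustration only: width-packed tensor with NCHW sizes.
std::vector<int64_t> sizes = {2, 3, 4, 5};
std::vector<int64_t> padded = calculate_padded_sizes(sizes, /*packed_dim=*/0);
// padded == {2, 3, 4, 8}  -- only the width (last) entry is aligned up to 4
```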

@@ -155,7 +155,7 @@ std::vector<int64_t> calculate_padded_sizes(
utils::uvec3 calculate_image_extents(
const std::vector<int64_t>& padded_sizes,
const std::vector<int64_t>& axis_map,
const int32_t packed_dim_whcn_idx) {
const int32_t packed_dim) {
VK_CHECK_COND(padded_sizes.size() == 4);
VK_CHECK_COND(axis_map.size() == 4);

@@ -176,8 +176,8 @@ utils::uvec3 calculate_image_extents(
// Multiply the extents of the batch axis by the batch size.
extents[batch_axis] *= padded_sizes.at(0);

VK_CHECK_COND(extents[axis_map.at(packed_dim_whcn_idx)] % 4 == 0);
extents[axis_map.at(packed_dim_whcn_idx)] /= 4;
VK_CHECK_COND(extents[axis_map.at(packed_dim)] % 4 == 0);
extents[axis_map.at(packed_dim)] /= 4;
return extents;
}
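A hedged example of the extent calculation (illustration only; the axis map value here is an assumption about the default W/H/C-identity map with the batch folded into the channels axis, not taken from this diff):

```cpp
// Illustration only: NCHW padded sizes for a width-packed tensor.
std::vector<int64_t> padded_sizes = {1, 4, 3, 8};  // N, C, H, W
std::vector<int64_t> axis_map = {0, 1, 2, 2};      // assumed default axis map
utils::uvec3 extents =
    calculate_image_extents(padded_sizes, axis_map, /*packed_dim=*/0);
// extents == {2, 3, 4}, i.e. {W / 4, H, C * N}: four width elements per texel
```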

@@ -254,14 +254,14 @@ vTensorStorage::vTensorStorage(
Context* const context,
const utils::StorageType storage_type,
const std::vector<int64_t>& axis_map,
const int32_t packed_dim_whcn_idx,
const int32_t packed_dim,
const std::vector<int64_t>& padded_sizes,
const vkapi::ScalarType dtype,
const bool allocate_memory)
: context_(context),
storage_type_{storage_type},
image_extents_(
calculate_image_extents(padded_sizes, axis_map, packed_dim_whcn_idx)),
calculate_image_extents(padded_sizes, axis_map, packed_dim)),
buffer_length_{utils::multiply_integers(padded_sizes)},
buffer_offset_{0},
image_(allocate_image(
@@ -378,13 +378,12 @@ vTensor::vTensor(
: dtype_(dtype),
// Calculate tensor metadata
sizes_(sizes.begin(), sizes.end()),
packed_dim_whcn_idx_(
utils::to_packed_dim_whcn_idx<int32_t>(memory_layout)),
dim_order_(calculate_dim_order(sizes_.size(), packed_dim_whcn_idx_)),
packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)),
dim_order_(calculate_dim_order(sizes_.size(), packed_dim_)),
axis_map_(default_axis_map()),
strides_(calculate_strides(sizes, dim_order_)),
numel_(utils::multiply_integers(sizes_)),
padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)},
padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)},
unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
padded_numel_(utils::multiply_integers(padded_sizes_)),
logical_limits_{{0, 0, 0}},
@@ -399,7 +398,7 @@ vTensor::vTensor(
context,
storage_type,
axis_map_,
packed_dim_whcn_idx_,
packed_dim_,
padded_sizes_,
dtype_,
allocate_memory) {
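For context on the renamed `utils::to_packed_dim` helper used in the constructor above, a sketch of the assumed mapping (illustration only; it should be the inverse of `estimate_memory_layout()` shown further down in this diff):

```cpp
// Assumed mapping (not part of this PR):
//   utils::kWidthPacked    -> WHCN::kWidthDim    (0)
//   utils::kHeightPacked   -> WHCN::kHeightDim   (1)
//   utils::kChannelsPacked -> WHCN::kChannelsDim (2)
const int32_t packed_dim = utils::to_packed_dim<int32_t>(utils::kChannelsPacked);
// packed_dim == WHCN::kChannelsDim
```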
@@ -422,7 +421,7 @@ vTensor::vTensor(const vTensor& other)
: dtype_(other.dtype_),
// Copy tensor size metadata
sizes_(other.sizes_.begin(), other.sizes_.end()),
packed_dim_whcn_idx_{other.packed_dim_whcn_idx_},
packed_dim_{other.packed_dim_},
dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
axis_map_(other.axis_map_.begin(), other.axis_map_.end()),
strides_(other.strides_.begin(), other.strides_.end()),
@@ -450,12 +449,12 @@ vTensor::vTensor(
: dtype_(other.dtype_),
// Copy tensor size metadata
sizes_(sizes.begin(), sizes.end()),
packed_dim_whcn_idx_(other.packed_dim_whcn_idx_),
packed_dim_(other.packed_dim_),
dim_order_(dim_order.begin(), dim_order.end()),
axis_map_(default_axis_map()),
strides_(calculate_strides(sizes_, dim_order_)),
numel_(utils::multiply_integers(sizes_)),
padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)},
padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)},
unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
padded_numel_(utils::multiply_integers(padded_sizes_)),
logical_limits_(other.logical_limits_),
@@ -512,7 +511,7 @@ void vTensor::set_logical_limits(const utils::uvec3& image_extents) {
}

utils::GPUMemoryLayout vTensor::estimate_memory_layout() const {
switch (packed_dim_whcn_idx_) {
switch (packed_dim_) {
case WHCN::kWidthDim:
return utils::kWidthPacked;
case WHCN::kHeightDim:
@@ -602,14 +601,14 @@ void vTensor::update_metadata() {
strides_ = calculate_strides(sizes_, dim_order_);
numel_ = utils::multiply_integers(sizes_);

padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_whcn_idx_);
padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_);
unsqueezed_strides_ = unsqueeze_strides(strides_, numel_);
padded_numel_ = utils::multiply_integers(padded_sizes_);

// Calculate the image extents that would have been used to allocate a texture
// with the current sizes, and use that to set the logical limits.
set_logical_limits(
calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_));
calculate_image_extents(padded_sizes_, axis_map_, packed_dim_));

if (sizes_uniform_.buffer()) {
sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
@@ -633,7 +632,7 @@ void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
// For texture storage check that the current texture is large enough for
// the new sizes of the tensor.
utils::uvec3 virtual_extents =
calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_);
calculate_image_extents(padded_sizes_, axis_map_, packed_dim_);

bool valid_resize = virtual_extents[0] <= storage_.image_extents_[0];
valid_resize =
@@ -705,11 +704,11 @@ void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) {

const int dim0_whcn = sizes_.size() - 1 - dim0;
const int dim1_whcn = sizes_.size() - 1 - dim1;
if (packed_dim_whcn_idx_ == dim0_whcn) {
packed_dim_whcn_idx_ = dim1_whcn;
if (packed_dim_ == dim0_whcn) {
packed_dim_ = dim1_whcn;
}
if (packed_dim_whcn_idx_ == dim1_whcn) {
packed_dim_whcn_idx_ = dim0_whcn;
if (packed_dim_ == dim1_whcn) {
packed_dim_ = dim0_whcn;
}

if (storage_type() == utils::kBuffer) {
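A hedged walk-through of the packed-dim update in `virtual_transpose()` above (illustration only; the constructor arguments are an assumption about the `vTensor` API, not taken from this diff):

```cpp
// Illustration only: transposing H and W of a 4-D, width-packed tensor.
// dim0 = 2 (H), dim1 = 3 (W) in NCHW  ->  dim0_whcn = 1, dim1_whcn = 0.
// packed_dim_ (0) matches dim1_whcn, so it becomes dim0_whcn, i.e. 1.
vTensor t(context, /*sizes=*/{2, 3, 4, 8}, vkapi::kFloat,
          utils::kTexture3D, utils::kWidthPacked);
t.virtual_transpose(/*dim0=*/2, /*dim1=*/3);
// t.packed_dim() == WHCN::kHeightDim
```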
16 changes: 8 additions & 8 deletions backends/vulkan/runtime/api/containers/Tensor.h
@@ -26,7 +26,7 @@ namespace api {
*/
std::vector<int64_t> calculate_dim_order(
const size_t ndim,
const int32_t packed_dim_whcn_idx);
const int32_t packed_dim);

/*
* Given the sizes of a tensor and the dim order of the tensor (both in NCHW)
@@ -57,15 +57,15 @@ std::vector<int64_t> unsqueeze_strides(
*/
std::vector<int64_t> calculate_padded_sizes(
const std::vector<int64_t>& sizes,
const int32_t packed_dim_whcn_idx);
const int32_t packed_dim);

/*
* Calculate the image extents required of a texture backed tensor.
*/
utils::uvec3 calculate_image_extents(
const std::vector<int64_t>& padded_sizes,
const std::vector<int64_t>& axis_map,
const int32_t packed_dim_whcn_idx);
const int32_t packed_dim);

struct LastAccess {
vkapi::PipelineStageFlags stage;
@@ -90,7 +90,7 @@ class vTensorStorage final {
Context* context,
const utils::StorageType storage_type,
const std::vector<int64_t>& axis_map,
const int32_t packed_dim_whcn_idx,
const int32_t packed_dim,
const std::vector<int64_t>& padded_sizes,
const vkapi::ScalarType dtype,
const bool allocate_memory = true);
@@ -228,7 +228,7 @@ class vTensor final {
// which dimension is packed along a texel. For buffer backed tensors, this
// describes which dimension has a stride of 1 (i.e. is last in the dim
// order).
int32_t packed_dim_whcn_idx_;
int32_t packed_dim_;

/*
* "Layout" metadata. These describe with further detail how tensor data is
@@ -378,12 +378,12 @@ class vTensor final {
* tensor. In some scenarios, the exact layout of the tensor may not be able
* to be replicated due to calling `virtual_*()` functions after construction;
* however, this function will provide a memory layout that will produce the
* same `packed_dim_whcn_idx` as this tensor.
* same `packed_dim_` as this tensor.
*/
utils::GPUMemoryLayout estimate_memory_layout() const;

inline int32_t packed_dim_whcn_idx() const {
return packed_dim_whcn_idx_;
inline int32_t packed_dim() const {
return packed_dim_;
}

inline const std::vector<int64_t>& sizes() const {
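A brief hedged example tying the renamed field to the buffer-storage comment above: the packed dim is the one that ends up last in the dim order and therefore has stride 1 (illustration only, not part of the diff).

```cpp
// Illustration only: width-packed buffer tensor with NCHW sizes {2, 3, 4, 5}.
std::vector<int64_t> dim_order =
    calculate_dim_order(/*ndim=*/4, /*packed_dim=*/0);   // == {0, 1, 2, 3}
std::vector<int64_t> strides = calculate_strides({2, 3, 4, 5}, dim_order);
// strides == {60, 20, 5, 1}: the width dim (packed_dim() == 0) has stride 1
```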
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -312,8 +312,8 @@ class ComputeGraph final {
return values_.at(idx).toConstTensor().estimate_memory_layout();
}

inline int32_t packed_dim_whcn_idx_of(const ValueRef idx) const {
return values_.at(idx).toConstTensor().packed_dim_whcn_idx();
inline int32_t packed_dim_of(const ValueRef idx) const {
return values_.at(idx).toConstTensor().packed_dim();
}

inline vkapi::BufferBindInfo sizes_ubo(const ValueRef idx) {
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
@@ -93,7 +93,7 @@ void add_binary_op_node(
graph.create_params_buffer(broadcast_params),
graph.create_params_buffer(alpha_val)},
// Specialization Constants
{SV(t_out->packed_dim_whcn_idx())},
{SV(t_out->packed_dim())},
// Resizing Logic
resize_binary_op_node,
{}));
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -108,7 +108,7 @@ ValueRef prepack_biases(
v,
{t->sizes_ubo(), t->axis_map_ubo()},
// Specialization constants
{SV(t->packed_dim_whcn_idx())}));
{SV(t->packed_dim())}));

return v;
}
@@ -216,7 +216,7 @@ ValueRef prepack_weights(
graph.create_params_buffer(
utils::make_ivec4(original_sizes, /*reverse = */ true))},
// Specialization constants
{SV(t->packed_dim_whcn_idx())}));
{SV(t->packed_dim())}));

return v;
}
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Full.cpp
@@ -54,7 +54,7 @@ void add_full_node(
// Shader params buffers
{t_out->sizes_ubo(), graph.create_params_buffer(fill_value_val)},
// Specialization Constants
{SV(t_out->packed_dim_whcn_idx())},
{SV(t_out->packed_dim())},
// Resizing Logic
resize_full_node,
{size_or_in}));
17 changes: 8 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -36,8 +36,7 @@ void check_addmm_args(
VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());

VK_CHECK_COND(
graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out));
VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes));

@@ -127,10 +126,10 @@ void add_addmm_naive_node(
graph.create_params_buffer(params),
},
// Specialization Constants
{graph.packed_dim_whcn_idx_of(out),
graph.packed_dim_whcn_idx_of(mat1),
graph.packed_dim_whcn_idx_of(mat2),
graph.packed_dim_whcn_idx_of(self)},
{graph.packed_dim_of(out),
graph.packed_dim_of(mat1),
graph.packed_dim_of(mat2),
graph.packed_dim_of(self)},
// Resizing Logic
resize_addmm_node,
{mat2_is_transposed}));
@@ -221,7 +220,7 @@ void add_addmm_optimized_node(
graph.create_params_buffer(params),
},
// Specialization Constants
{graph.packed_dim_whcn_idx_of(out)},
{graph.packed_dim_of(out)},
// Resizing Logic
resize_addmm_node,
{mat2_is_transposed}));
@@ -247,10 +246,10 @@ void add_addmm_node(
}

Params params = {alpha_val, beta_val};
if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kChannelsDim) {
if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
add_addmm_optimized_node(
graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
} else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kWidthDim) {
} else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) {
add_addmm_naive_node(
graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
} else {
15 changes: 7 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
@@ -29,8 +29,7 @@ void check_matmul_args(
VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());

VK_CHECK_COND(
graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out));
VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes));
}
@@ -139,9 +138,9 @@ void add_matmul_naive_texture3d_node(
graph.axis_map_ubo(mat2),
},
// Specialization Constants
{graph.packed_dim_whcn_idx_of(out),
graph.packed_dim_whcn_idx_of(mat1),
graph.packed_dim_whcn_idx_of(mat2)},
{graph.packed_dim_of(out),
graph.packed_dim_of(mat1),
graph.packed_dim_of(mat2)},
// Resizing Logic
resize_matmul_node,
{mat2_is_transposed}));
@@ -223,7 +222,7 @@ void add_matmul_optimized_node(
graph.axis_map_ubo(mat2_packed),
},
// Specialization Constants
{graph.packed_dim_whcn_idx_of(out)},
{graph.packed_dim_of(out)},
// Resizing Logic
resize_matmul_node,
{mat2_is_transposed}));
@@ -238,9 +237,9 @@ void add_matmul_node(
if (graph.is_buffer_storage(out)) {
add_matmul_naive_buffer_node(
graph, mat1, mat2_data, out, mat2_is_transposed);
} else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kChannelsDim) {
} else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed);
} else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kWidthDim) {
} else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) {
add_matmul_naive_texture3d_node(
graph, mat1, mat2_data, out, mat2_is_transposed);
} else {
7 changes: 3 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -30,8 +30,7 @@ void check_qlinear_args(
VK_CHECK_COND(qmat2_sizes.size() == 2);
VK_CHECK_COND(scales_sizes.size() == 1);

VK_CHECK_COND(
graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out));
VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

VK_CHECK_COND(
utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes));
@@ -79,8 +78,8 @@ void add_q_8w_linear_node(

std::string kernel_name = "q_8w_linear";
kernel_name.reserve(kShaderNameReserve);
add_packed_dim_suffix(kernel_name, graph.packed_dim_whcn_idx_of(mat1));
add_packed_dim_suffix(kernel_name, graph.packed_dim_whcn_idx_of(q_mat2));
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1));
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
add_dtype_suffix(kernel_name, graph.dtype_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
