31 changes: 22 additions & 9 deletions backends/arm/runtime/VGFSetup.cpp
@@ -24,6 +24,13 @@ namespace vgf {
/* static function to map format to byte count */
static uint32_t get_format_size(VkFormat format);

+// SPV_ARM_tensor does not support rank-0 representations according to the spec.
+// Use an unsqueezed dimension when the resource table contains an empty
+// shape. Tensors are output as rank 0 when copied back from the vgf backend.
+namespace {
+constexpr int64_t kScalarSentinelDimension = 1;
+}
+
// Debug function to inspect memory properties
static string memory_flags_to_string(VkMemoryPropertyFlags flags) {
if (flags == 0)
@@ -264,7 +271,11 @@ static void debug_print_resources(
the_shape.size(),
the_stride.size());
for (int j = 0; j < the_shape.size(); j++) {
-ET_LOG(Info, " %d: dim %ld", j, the_shape[j]);
+ET_LOG(
+    Info,
+    " %d: dim %lld",
+    j,
+    static_cast<long long>(the_shape[j]));
}
// Allocate a tensor with bound memory
break;
Expand Down Expand Up @@ -387,6 +398,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
// Get tensor shape and strides
auto shape = resource_decoder->getTensorShape(i);
auto stride = resource_decoder->getTensorStride(i);
+const auto shape_size = shape.size();

switch (resource_decoder->getCategory(i)) {
case vgflib::ResourceCategory::INPUT:
@@ -409,9 +421,9 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
result = allocate_tensor(
vk_physical,
vk_device,
-vgflib::ToVkFormat(resource_decoder->getVkFormat(i)),
-static_cast<uint32_t>(shape.size()),
-shape.begin(),
+resource_format,
+shape_size == 0 ? 1 : static_cast<uint32_t>(shape_size),
+shape_size == 0 ? &kScalarSentinelDimension : shape.begin(),
static_cast<uint32_t>(stride.size()),
stride.begin(),
&tensor_description,
@@ -422,8 +434,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
ET_LOG(Error, "Failed to allocate tensor for VGF resource %d", i);
return false;
}
-size_t e_size = get_format_size(
-    vgflib::ToVkFormat(resource_decoder->getVkFormat(i)));
+size_t e_size = get_format_size(resource_format);
if (0 == e_size) {
ET_LOG(Error, "failed to get element size of VkFormat");
return false;
@@ -449,9 +460,11 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
.sType = VK_STRUCTURE_TYPE_TENSOR_DESCRIPTION_ARM,
.pNext = nullptr,
.tiling = VK_TENSOR_TILING_LINEAR_ARM,
-.format = vgflib::ToVkFormat(resource_decoder->getVkFormat(i)),
-.dimensionCount = static_cast<uint32_t>(shape.size()),
-.pDimensions = shape.begin(),
+.format = resource_format,
+.dimensionCount =
+    shape_size == 0 ? 1 : static_cast<uint32_t>(shape_size),
+.pDimensions =
+    shape_size == 0 ? &kScalarSentinelDimension : shape.begin(),
// Note: stride_data of 0's causes size==0, null means stride==size
.pStrides = (0 == stride.size() ? nullptr : stride.begin()),
.usage = VK_TENSOR_USAGE_DATA_GRAPH_BIT_ARM,
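The rank-0 handling in the changes above can be summarized outside the backend as follows. This is a minimal, self-contained sketch of the same idea, not the backend's actual code: the helper name describe_dims, the TensorDims struct, and the std::vector<int64_t> shape are illustrative stand-ins for the vgflib shape view and the Vulkan tensor description fields. When the resource table reports an empty shape, the tensor is described as rank 1 with a single sentinel dimension of 1, because SPV_ARM_tensor has no rank-0 representation.

#include <cstdint>
#include <vector>

namespace {
// SPV_ARM_tensor cannot describe rank-0 tensors, so a scalar is exposed as a
// rank-1 tensor holding a single element.
constexpr int64_t kScalarSentinelDimension = 1;
} // namespace

// Illustrative pair of fields mirroring the dimensionCount / pDimensions
// members of a Vulkan tensor description.
struct TensorDims {
  uint32_t dimension_count;
  const int64_t* dimensions;
};

// Map a (possibly empty) shape to the values passed to the tensor description.
TensorDims describe_dims(const std::vector<int64_t>& shape) {
  if (shape.empty()) {
    // Empty shape in the resource table: treat the scalar as rank 1, size 1.
    return {1u, &kScalarSentinelDimension};
  }
  return {static_cast<uint32_t>(shape.size()), shape.data()};
}

With this mapping, describe_dims({}) yields a rank of 1 pointing at the sentinel dimension, while describe_dims({2, 3}) passes the shape through unchanged.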
1 change: 0 additions & 1 deletion backends/arm/test/ops/test_mean_dim.py
@@ -4,7 +4,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
1 change: 0 additions & 1 deletion backends/arm/test/ops/test_scalar_tensor.py
@@ -2,7 +2,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-
import torch
from executorch.backends.arm.test import common
