18 changes: 18 additions & 0 deletions backends/arm/arm_backend.py
@@ -45,6 +45,7 @@ def __init__(self):
self.output_format = None
self.path_for_intermediates = None
self.permute_nhwc = False
self.quantize_io = False

def ethosu_compile_spec(
self,
@@ -101,9 +102,21 @@ def dump_intermediate_tosa(self, output_path: str):
return self

def set_permute_memory_format(self, set_nhwc_permutation: bool = True):
"""
Permute to channel last in compiler and runtime. Compilation and
runtime will convert rank 4 inputs to channel last for each sub-graph.
"""
self.permute_nhwc = set_nhwc_permutation
return self

def set_quantize_io(self, quantize_io: bool = False):
"""
Quantization of inputs and dequantization of outputs for cases where
whole graph is quantized and method signature is not of quantized type.
"""
self.quantize_io = quantize_io
return self

def build(self):
"""
Generate a list of compile spec objects from the builder
@@ -126,6 +139,9 @@ def build(self):
CompileSpec("permute_memory_format", "nhwc".encode())
)

if self.quantize_io:
self.compile_spec.append(CompileSpec("quantize_io", "True".encode()))

return self.compile_spec
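A minimal sketch of the entries build() produces when both flags are set, and of how a consumer can read them back; it assumes CompileSpec is the key/value dataclass from executorch.exir.backend.compile_spec_schema, and the helper name is illustrative (the partitioner below performs the same check):

from executorch.exir.backend.compile_spec_schema import CompileSpec

# The entries appended when permute_memory_format and quantize_io are both set.
compile_spec = [
    CompileSpec("permute_memory_format", "nhwc".encode()),
    CompileSpec("quantize_io", "True".encode()),
]

def has_quantize_io(specs: list) -> bool:
    # Mirrors the key/value check ArmPartitioner.partition() does further down.
    return any(s.key == "quantize_io" and s.value.decode() == "True" for s in specs)

assert has_quantize_io(compile_spec)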


@@ -153,6 +169,7 @@ def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
def generate_ethosu_compile_spec(
config: str,
permute_memory_to_nhwc: Optional[bool] = None,
quantize_io: Optional[bool] = None,
system_config: Optional[str] = None,
memory_mode: Optional[str] = None,
extra_flags: Optional[str] = None,
@@ -168,6 +185,7 @@ def generate_ethosu_compile_spec(
config_ini=config_ini,
)
.set_permute_memory_format(permute_memory_to_nhwc)
.set_quantize_io(quantize_io)
.build()
)
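A usage sketch for the helper above; the accelerator config string and the flag values are illustrative assumptions, and the remaining keyword arguments keep their defaults:

compile_spec = generate_ethosu_compile_spec(
    "ethos-u55-128",              # hypothetical Vela accelerator config
    permute_memory_to_nhwc=True,  # treat rank-4 I/O as channel last (NHWC)
    quantize_io=True,             # keep the I/O Q/DQ nodes out of the delegate partition
)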

63 changes: 63 additions & 0 deletions backends/arm/arm_partitioner.py
@@ -53,6 +53,11 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:

supported &= self.is_node_supported_custom(node)

# Allow pre-partition passes to override the partitioning decision via node.meta
if supported and "arm_partition" in node.meta:
supported = supported & node.meta["arm_partition"]
node.meta.pop("arm_partition")

return supported

def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
@@ -64,6 +69,54 @@ def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
return True
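The node.meta["arm_partition"] hook above lets any pass that runs before partitioning veto delegation of individual nodes; TagIOQuant below is the pass this change adds, but a hypothetical pass using the same hook could look like the following sketch (the softmax filter is purely illustrative):

import torch
from executorch.exir.pass_base import ExportPass, PassResult

class KeepSoftmaxOnCpu(ExportPass):
    # Illustrative only: tag softmax nodes so is_node_supported() above
    # rejects them and they remain on the CPU.
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and "softmax" in str(node.target):
                node.meta["arm_partition"] = False
        return PassResult(graph_module, True)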


from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import PassManager


class TagIOQuant(ExportPass):
"""
Pass run before partitioning to tag Q/DQ on any placeholder and output
to ensure we don't greedily partition them for device. Float conversion
has to happen outside a TOSA base inference profile.
"""

def __init__(self, edge_program: torch.export.ExportedProgram):
super(TagIOQuant, self).__init__()
self.edge_program = edge_program

def is_quant_node(self, node: torch.fx.node.Node):
return node.target in {
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
}

def is_dequant_node(self, node: torch.fx.node.Node):
return node.target in {
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
# tag q of input
if node.op == "placeholder":
for user in node.users.keys():
# if we have an input going into a quantize
if self.is_quant_node(user):
user.meta["arm_partition"] = False

# tag dq of outputs
if node.op == "output":
quant, *_ = node.args[0]
if self.is_dequant_node(quant):
quant.meta["arm_partition"] = False

graph_module.recompile()
return PassResult(graph_module, True)
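A short sketch of running the pass on its own and inspecting what it tagged; it assumes an already-exported, fully quantized program and adds only a hypothetical reporting helper around the class above:

def tag_io_quant_and_report(edge_program: torch.export.ExportedProgram) -> PassResult:
    # Run TagIOQuant directly and list the Q/DQ nodes excluded from the Arm
    # partition; ArmPartitioner.partition() below does the same via PassManager.
    result = TagIOQuant(edge_program).call(edge_program.graph_module)
    for node in result.graph_module.graph.nodes:
        if node.meta.get("arm_partition") is False:
            print(f"kept outside the delegate: {node.name} -> {node.target}")
    return result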


@final
class ArmPartitioner(Partitioner):
def __init__(self, compile_spec: List[CompileSpec]) -> None:
@@ -75,6 +128,16 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
logger.info("ArmPartitioner::partition")
partition_tags = {}

for spec in self.delegation_spec.compile_specs:
if spec.key == "quantize_io" and spec.value.decode() == "True":
# Exclude IO quantization from the partition
passes = PassManager(
passes=[
TagIOQuant(exported_program),
]
)
passes(exported_program.graph_module)

capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
TOSASupportedOperators(),
4 changes: 2 additions & 2 deletions backends/arm/operators/op_placeholder.py
@@ -157,8 +157,8 @@ def process_placeholder(
buffer_values.shape, inputs[0].dtype, buffer_values, name=out
)
else:
# Cases for all the input tensors
if permute_memory_to_nhwc:
# Cases for all the input tensors of rank 4
if permute_memory_to_nhwc and len(inputs[0].shape) == 4:
NHWC_Order = [0, 2, 3, 1]
input_shape = [inputs[0].shape[i] for i in NHWC_Order]
else:
163 changes: 127 additions & 36 deletions backends/arm/runtime/ArmBackendEthosU.cpp
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Arm Limited and/or its affiliates.
* Copyright 2023-2024 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
@@ -26,12 +26,10 @@ using namespace std;
namespace torch {
namespace executor {

// TODO: we should be in 0x31, to access a full 2MB SRAM
// region and enable maximum program performance up to
// 2MB, rather than 1.
// SRAM (rwx) : ORIGIN = 0x31000000, LENGTH = 0x00200000
#define CS300_SRAM_LOW ((void*)0x11000000)
#define CS300_SRAM_HIGH ((void*)0x110FFFFF)
typedef struct {
FreeableBuffer* processed;
bool permuted_io_flag;
} ExecutionHandle;

class ArmBackend final : public PyTorchBackendInterface {
public:
@@ -60,40 +58,37 @@ class ArmBackend final : public PyTorchBackendInterface {
return Error::InvalidProgram;
}

    // Verify the address range is accessible; the current expectation is that
    // the program is wholly stored in SRAM
// TODO: expect to improve capabilities here by supporting DRAM storage
// and only moving required data into SRAM.
if (!(data > CS300_SRAM_LOW || foot < CS300_SRAM_HIGH)) {
ET_LOG(Error, "ArmBackend::init: Expected program binary to be in SRAM");
ET_LOG(
Error,
"ArmBackend::init: program binary range %p:%p",
data,
foot + 16);
return Error::InvalidProgram;
MemoryAllocator* allocator = context.get_runtime_allocator();
ExecutionHandle* handle =
ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
handle->processed = processed;

for (auto& compile_spec : compile_specs) {
if (0 == std::strcmp(compile_spec.key, "permute_memory_format") &&
0 == std::memcmp(compile_spec.value.buffer, "nhwc", 4)) {
handle->permuted_io_flag = true;
}
}

// Return the same buffer we were passed - this data will be
// executed directly
return processed;
return handle;
}

Error execute(
BackendExecutionContext& context,
DelegateHandle* input_handle,
EValue** args) const override {
FreeableBuffer* processed = (FreeableBuffer*)input_handle;

ET_LOG(Info, "ArmBackend::execute %p", processed->data());

ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle;
VelaHandles handles;

// Command stream - we know at this point it's aligned
char* data = (char*)processed->data();
char* data = (char*)execution_handle->processed->data();
ET_LOG(Info, "ArmBackend::execute %p", data);

// Read key sections from the vela_bin_stream
if (vela_bin_read(data, &handles, processed->size()) == false) {
if (vela_bin_read(data, &handles, execution_handle->processed->size()) ==
false) {
ET_LOG(Error, "ArmBackend::vela_read: error, invalid binary layout");
return Error::InvalidProgram;
}
@@ -108,18 +103,63 @@ class ArmBackend final : public PyTorchBackendInterface {
handles.scratch_data,
handles.scratch_data_size);

// Write inputs into SRAM scratch area defined by Vela
// Write argument values (from EValue tensor) into Ethos-U scratch
// TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM
// or DRAM output for compatible data layouts.
for (int i = 0; i < handles.inputs->count; i++) {
const char* input_addr =
handles.scratch_data + handles.inputs->io[i].offset;
// Process input EValue into scratch
// TODO: Optimise into direct write from Vela into the SRAM or DRAM output
// for compatible data layouts.
int* input_address = (int*)input_addr;
auto tensor_in = args[i]->toTensor();
for (int j = 0; j < tensor_in.numel(); j++) {
// TODO: extend beyond tensors with 4 byte elements
input_address[j] = tensor_in.mutable_data_ptr<int>()[j];
VelaIO* scratch_in = &handles.inputs->io[i];
char* scratch_addr = handles.scratch_data + handles.inputs->io[i].offset;

// We accept:
      bool supported = false;
// 32 bit int (simple non-quantised test cases)
supported |=
(tensor_in.scalar_type() == ScalarType::Int and
handles.inputs->io[i].elem_size == 4);
// 8 bit int (IOQDQ pass prepared networks)
supported |=
(tensor_in.scalar_type() == ScalarType::Char and
handles.inputs->io[i].elem_size == 1);
if (!supported) {
ET_LOG(
Error,
"Input %d expected Integer (4 byte) or Char (1 byte) integer inputs",
i);
return Error::InvalidProgram;
}

// Select a compatible copy routine including checking for input layouts
// which require permutation.
bool permuted_input_shape;
ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
i,
tensor_in,
&handles.inputs->io[i],
execution_handle->permuted_io_flag,
&permuted_input_shape));
bool both_char = tensor_in.scalar_type() == ScalarType::Char and
handles.inputs->io[i].elem_size == 1;
bool both_int = tensor_in.scalar_type() == ScalarType::Int and
handles.inputs->io[i].elem_size == 4;

// Select a compatible copy routine
if (both_char and permuted_input_shape) {
// permuted byte copy CHW to HWC
permute_CHW_to_HWC(
scratch_addr,
tensor_in.mutable_data_ptr<char>(),
tensor_in.size(2),
tensor_in.size(3));
} else if (both_char or both_int) {
// Sizes match and elt size matches so memcpy
memcpy(
scratch_addr,
tensor_in.mutable_data_ptr<char>(),
tensor_in.nbytes());
} else {
ET_LOG(Error, "No matching input copy routine");
return Error::InvalidProgram;
}
}
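A compact restatement of the dispatch above in Python, for clarity only; the function name is an assumption, while the accepted dtype/element-size pairs and the two copy routines come from the code:

def select_copy_path(scalar_type: str, elem_size: int, permuted: bool) -> str:
    # Which copy routine the input loop above picks for a single EValue input.
    both_char = scalar_type == "Char" and elem_size == 1  # int8, IOQDQ-prepared networks
    both_int = scalar_type == "Int" and elem_size == 4    # int32, simple non-quantized tests
    if not (both_char or both_int):
        raise ValueError("expected Int (4 byte) or Char (1 byte) input")
    if both_char and permuted:
        return "permute_CHW_to_HWC"  # byte-wise CHW -> HWC copy into the scratch area
    return "memcpy"  # element size and layout already agree, straight copy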

@@ -173,6 +213,57 @@ class ArmBackend final : public PyTorchBackendInterface {
void destroy(DelegateHandle* handle) const override {
return;
}

private:
Error check_requires_permute(
int index,
const exec_aten::Tensor tensor_in,
VelaIO* input,
bool permuted_io_flag,
bool* is_permuted) const {
bool permuted_input_shape = false;
if (tensor_in.dim() == 4) {
// special case for NHWC workaround in AOT; as the compilation has
// permuted to channel last in an undetectable way, we assume here
// that the application has similarly permuted any input tensors.
permuted_input_shape = tensor_in.size(0) == input->shape[0] &&
tensor_in.size(1) == input->shape[3] &&
tensor_in.size(2) == input->shape[1] &&
tensor_in.size(3) == input->shape[2];
if (permuted_input_shape) {
ET_LOG(Info, "Tensor input %d will be permuted", index);
}
if (permuted_io_flag != permuted_input_shape) {
ET_LOG(Error, "Permute compile flag and permuted input don't agree");
return Error::InvalidProgram;
}
}
if (!permuted_input_shape) {
// Error check matching shapes in the general case
for (int i = 0; i < tensor_in.dim(); i++) {
if (tensor_in.size(i) != input->shape[i]) {
ET_LOG(Error, "Tensor input %d mismatched shape", index);
ET_LOG(
Error,
"dimension %d mismatch, %d != %d",
index,
tensor_in.size(i),
input->shape[i]);
return Error::InvalidProgram;
}
}
}
*is_permuted = permuted_input_shape;
return Error::Ok;
}

void permute_CHW_to_HWC(char* input, char* output, int H, int W) const {
for (int i = 0; i != H * W; ++i) {
output[i * 3 + 0] = input[i + 0 * W * H];
output[i * 3 + 1] = input[i + 1 * W * H];
output[i * 3 + 2] = input[i + 2 * W * H];
}
}
};
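The byte-permute above assumes exactly three channels (the stride arithmetic hard-codes C == 3). A purely illustrative PyTorch equivalent of the same transform:

import torch

def permute_chw_to_hwc(chw: torch.Tensor) -> torch.Tensor:
    # Equivalent of the C++ routine above: (3, H, W) channel-first data
    # rearranged into the (H, W, 3) channel-last layout Vela expects.
    assert chw.dim() == 3 and chw.size(0) == 3
    return chw.permute(1, 2, 0).contiguous()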

namespace {
4 changes: 2 additions & 2 deletions backends/arm/runtime/VelaBinStream.h
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Arm Limited and/or its affiliates.
* Copyright 2023-2024 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
@@ -49,7 +49,7 @@ typedef struct {
size_t cmd_data_size;
const char* weight_data;
size_t weight_data_size;
const char* scratch_data;
char* scratch_data;
size_t scratch_data_size;
VelaIOs* inputs;
VelaIOs* outputs;