Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions backends/nxp/neutron_node_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dataclasses import dataclass

import numpy as np

from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
BuiltinOperator,
)
Expand All @@ -15,6 +14,10 @@

@dataclass
class NeutronNodeArtifacts:
input_names: list[str]
input_indices: list[int]
output_names: list[str]
output_indices: list[int]
microcode: np.ndarray
weights: np.ndarray
kernels: np.ndarray
Expand Down Expand Up @@ -99,4 +102,42 @@ def extract_artifacts_from_neutron_node(
microcode.dtype == weights.dtype == kernels.dtype == np.dtype("uint8")
), "The Neutron Node uses unexpected data types."

return NeutronNodeArtifacts(microcode, weights, kernels)
input_names = []
input_indices = []
graph_inputs = sub_graph.InputsAsNumpy()
node_inputs = neutron_node.InputsAsNumpy()[:-3]
for tensor_idx in node_inputs:
which_graph_input = np.where(graph_inputs == tensor_idx)[0]
assert (
which_graph_input.size == 1
), "Mismatch between Neutron Node inputs and graph inputs."
input_indices.append(which_graph_input[0])
input_names.append(sub_graph.Tensors(graph_inputs[which_graph_input[0]]).Name())

# The node must expose at least two outputs; per the message below, one of
# them is the scratch buffer, which is excluded from the graph outputs.
# The message uses `OutputsLength()` (the same accessor as the condition), so
# that a failing check reports the count instead of raising `AttributeError`.
assert (
    neutron_node.OutputsLength() >= 2
), f"The Neutron Node only has `{neutron_node.OutputsLength()}` outputs. Expected at least `2` including the scratch buffer."

output_names = []
output_indices = []
graph_outputs = sub_graph.OutputsAsNumpy()
node_outputs = neutron_node.OutputsAsNumpy()[:-1]
for tensor_idx in node_outputs:
which_graph_output = np.where(graph_outputs == tensor_idx)[0]
assert (
which_graph_output.size == 1
), "Mismatch between Neutron Node outputs and graph outputs."
output_indices.append(which_graph_output[0])
output_names.append(
sub_graph.Tensors(graph_outputs[which_graph_output[0]]).Name()
)

return NeutronNodeArtifacts(
input_names,
input_indices,
output_names,
output_indices,
microcode,
weights,
kernels,
)
70 changes: 49 additions & 21 deletions backends/nxp/nxp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,39 +245,67 @@ def _format_string_for_array(self, array: np.ndarray) -> str:

return f"{array.size}s{self._padding_format_string_for_array(array)}"

def _create_payload_header(self, io_formats) -> np.ndarray:
def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray:
"""
Create bytes header for returned payload. It contains information about
input and output tensor formats. Tensors are ordered based on graph signature
of ExportedProgram. Header schema:

+----------------------------------+-----------------------------------+
| Input TensorFormats length (1B) | Output TensorFormats length (1B) |
+----------------------------------+-----------------------------------+
| 1st input tensor format (1B) | [nth* input tensor format (1B)] |
+----------------------------------+-----------------------------------+
| 1st output tensor format (1B) | [nth* output tensor format (1B)] |
+----------------------------------+-----------------------------------+
+----------------------------+-----------------------------+------------------------+
| Neutron inputs length (1B) | Neutron outputs length (1B) | Input args length (1B) |
+----------------------------+-----------+-----------------+------------------------+
| 1st input tensor format (1B) | [nth* input tensor format (1B)] |
+----------------------------------------+------------------------------------------+
| 1st output tensor format (1B) | [nth* output tensor format (1B)] |
+----------------------------------------+------------------------------------------+
| 1st input map (1B) | [nth* input map (1B)] |
+----------------------------------------+------------------------------------------+
| 1st output map (1B) | [nth* output map (1B)] |
+----------------------------------------+------------------------------------------+

:param io_formats: IO tensors formats.
:return: Bytes representation of payload header.
"""
inputs = io_formats["inputs"]
outputs = io_formats["outputs"]

assert len(inputs) < 256, "Models with more than 255 inputs are not supported."
assert (
len(outputs) < 256
len(neutron_artifacts.input_indices) < 256
), "Models with more than 255 inputs are not supported."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: This is actually a valid case, not a "programming error". We just use an 8-bit field to encode input indices. Consider using ValueError instead of assert.

assert (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: This is actually a valid case, not a "programming error". We just use an 8-bit field to encode input indices. Consider using ValueError instead of assert.

len(neutron_artifacts.output_indices) < 256
), "Models with more than 255 outputs are not supported."

header_data = [len(inputs)]
header_data.append(len(outputs))
header_data = [len(neutron_artifacts.input_indices)]
header_data.append(len(neutron_artifacts.output_indices))
header_data.append(len(inputs))

for _tensor, tensor_format in inputs.items():
header_data.append(1 if tensor_format == TensorFormat.CHANNELS_LAST else 0)
for input_name in neutron_artifacts.input_names:
try:
header_data.append(
1
if inputs[input_name.decode()] == TensorFormat.CHANNELS_LAST
else 0
)
except KeyError:
raise AssertionError(
f"Input tensor `{input_name.decode()}` not found in the converted model."
)

for _tensor, tensor_format in outputs.items():
header_data.append(1 if tensor_format == TensorFormat.CHANNELS_LAST else 0)
for output_name in neutron_artifacts.output_names:
try:
header_data.append(
1
if outputs[output_name.decode()] == TensorFormat.CHANNELS_LAST
else 0
)
except KeyError:
raise AssertionError(
f"Output tensor `{output_name.decode()}` not found in the converted model."
)

header_data.extend(neutron_artifacts.input_indices)
header_data.extend(neutron_artifacts.output_indices)

# noinspection PyTypeChecker
return np.array(header_data, dtype=np.uint8)
Expand Down Expand Up @@ -314,9 +342,9 @@ def get_binary_payload(self, io_formats, neutron_model) -> bytes:

+----------------------------------------------------------------------------------------------------------------+
| 16 bytes aligned blocks |
+===========================+===========================+============================+===========================+
| Input formats length (1B) | Output formats length (1B) | [nth* input format (1B)] | [nth* output format (1B)] |
+---------------------------+--------------------------- +---------------------------+---------------------------+
+================================================================================================================+
| Header |
+----------------------------------------------------------------------------------------------------------------+
| Neutron microcode |
+----------------------------------------------------------------------------------------------------------------+
| Neutron weights |
Expand All @@ -331,9 +359,9 @@ def get_binary_payload(self, io_formats, neutron_model) -> bytes:
:param neutron_model: Neutron model with single NeutronGraph node.
:return: 16 bytes aligned binary payload.
"""
header = self._create_payload_header(io_formats)

# Extract the Neutron microcode, weights and kernels from the Neutron Node in the `neutron_model`.
neutron_artifacts = extract_artifacts_from_neutron_node(neutron_model)

header = self._create_payload_header(io_formats, neutron_artifacts)

return self._pack_with_alignment(header, neutron_artifacts)
121 changes: 74 additions & 47 deletions backends/nxp/runtime/NeutronBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,53 @@ namespace neutron {
#define ALIGN_SIZE(size) \
((size + BUFFER_ALIGNMENT - 1) & (~(BUFFER_ALIGNMENT - 1)))

// clang-format off
/* Header schema:
+----------------------------------+-----------------------------------+
| Input TensorFormats length (1B) | Output TensorFormats length (1B) |
+----------------------------------+-----------------------------------+
| 1st input tensor format (1B) | [nth* input tensor format (1B)] |
+----------------------------------+-----------------------------------+
| 1st output tensor format (1B) | [nth* output tensor format (1B)] |
+----------------------------------+-----------------------------------+
+----------------------------+-----------------------------+------------------------+
| Neutron inputs length (1B) | Neutron outputs length (1B) | Input args length (1B) |
+----------------------------+-----------+-----------------+------------------------+
| 1st input tensor format (1B) | [nth* input tensor format (1B)] |
+----------------------------------------+------------------------------------------+
| 1st output tensor format (1B) | [nth* output tensor format (1B)] |
+----------------------------------------+------------------------------------------+
| 1st input map (1B) | [nth* input map (1B)] |
+----------------------------------------+------------------------------------------+
| 1st output map (1B) | [nth* output map (1B)] |
+----------------------------------------+------------------------------------------+
*/
// clang-format on
#define ITEM_SIZE 1 // 1 Byte
#define INPUT_TENSOR_FORMAT_LEN_POS 0
#define OUTPUT_TENSOR_FORMAT_LEN_POS 1
#define INPUT_TENSOR_FORMAT_ARRAY_ADDR(base) (base + 2 * ITEM_SIZE)
#define INPUT_ARGS_LEN_POS 2
#define INPUT_TENSOR_FORMAT_ARRAY_ADDR(base) (base + 3 * ITEM_SIZE)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is OK for now, but in the future consider forward/backward compatibility (FC/BC) issues when updating serialization formats.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We hope that the format is now general enough to cover the future scenarios.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Said like a true developer, until a second before breaking BC :-p

#define OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(base) \
(base + 2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
#define PAYLOAD_ADDR(base) \
(base + \
ALIGN_SIZE( \
2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS] + \
base[OUTPUT_TENSOR_FORMAT_LEN_POS]))
(base + 3 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
#define INPUT_TENSOR_MAP_ARRAY_ADDR(base) \
(base + 3 * ITEM_SIZE + 1 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
#define OUTPUT_TENSOR_MAP_ARRAY_ADDR(base) \
(base + 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
#define PAYLOAD_ADDR(base) \
(base + \
ALIGN_SIZE( \
3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
2 * base[OUTPUT_TENSOR_FORMAT_LEN_POS]))

// Aggregate neutron model handle and data structures into one.
typedef struct {
int numInputs = 0;
int numOutputs = 0;
int numInputArgs = 0;
uint32_t scratchSize = 0;
NeutronModelConfig mcfg;
NeutronDataConfig dcfg;
NeutronModelHandle nmh = NULL;
const uint8_t* inputTranspositionFlags;
const uint8_t* outputTranspositionFlags;
const uint8_t* inputMap;
const uint8_t* outputMap;
} NeutronConfig;

// Applied on outputs.
Expand Down Expand Up @@ -210,6 +226,15 @@ void transposeOutput(
}
}

// Tells whether the dimension at position `rank - 3` of `sizes` (treated as
// the channel dimension) differs from 1.
//
// Shapes with fewer than 3 dimensions have no such dimension; `true` is
// returned for them, so a caller that combines this predicate with a
// transposition flag still reaches its own rank check for such tensors.
bool multipleChannelsPresent(const ArrayRef<exec_aten::SizesType>& sizes) {
  const size_t rank = sizes.size();
  if (rank >= 3) {
    const size_t channels = sizes[rank - 3];
    return channels != 1;
  }
  return true;
}

class NeutronBackend final : public PyTorchBackendInterface {
public:
NeutronBackend() {}
Expand All @@ -234,17 +259,19 @@ class NeutronBackend final : public PyTorchBackendInterface {
// cfg->mcfg.microcode
// cfg->mcfg.weights
// cfg->mcfg.kernels
const uint8_t* transpositionFlags =
const uint8_t* payloadFlags =
static_cast<const uint8_t*>(processed->data());
int numInputs = transpositionFlags[INPUT_TENSOR_FORMAT_LEN_POS];
int numOutputs = transpositionFlags[OUTPUT_TENSOR_FORMAT_LEN_POS];
cfg->inputTranspositionFlags =
INPUT_TENSOR_FORMAT_ARRAY_ADDR(transpositionFlags);
uint32_t numInputs = payloadFlags[INPUT_TENSOR_FORMAT_LEN_POS];
uint32_t numOutputs = payloadFlags[OUTPUT_TENSOR_FORMAT_LEN_POS];
cfg->numInputArgs = payloadFlags[INPUT_ARGS_LEN_POS];
cfg->inputTranspositionFlags = INPUT_TENSOR_FORMAT_ARRAY_ADDR(payloadFlags);
cfg->outputTranspositionFlags =
OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(transpositionFlags);
OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(payloadFlags);
cfg->inputMap = INPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags);
cfg->outputMap = OUTPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags);

const uint32_t* buffer = static_cast<const uint32_t*>(
static_cast<const void*> PAYLOAD_ADDR(transpositionFlags));
static_cast<const void*> PAYLOAD_ADDR(payloadFlags));
uint32_t magicWord = buffer[0];
// Check valid microcode.
if (magicWord != 0x64434D6E) {
Expand Down Expand Up @@ -314,39 +341,37 @@ class NeutronBackend final : public PyTorchBackendInterface {
cfg->dcfg.outputs[cfg->numOutputs] =
static_cast<void*>(context.allocate(cfg->scratchSize, 16));

// Set inputs and outputs from args.
// Set inputs from args.
// Transpose inputs if needed.
for (int i = 0; i < cfg->numInputs; i++) {
cfg->dcfg.inputs[i] = args[i]->toTensor().const_data_ptr();
}
for (int i = 0; i < cfg->numOutputs; i++) {
cfg->dcfg.outputs[i] =
args[cfg->numInputs + i]->toTensor().mutable_data_ptr();
}

// Transpose inputs.
for (int i = 0; i < cfg->numInputs; i++) {
if (cfg->inputTranspositionFlags[i]) {
if (args[i]->toTensor().sizes().size() < 3) {
auto arg = args[cfg->inputMap[i]]->toTensor();
if (cfg->inputTranspositionFlags[i] &&
multipleChannelsPresent(arg.sizes())) {
if (arg.sizes().size() < 3) {
ET_LOG(Error, "Unable to transpose 1D and 2D input to channel last");
return Error::InvalidProgram;
}
// Allocate buffer, the allocator is reset after each PTE instruction.
void* buffer = context.allocate(args[i]->toTensor().nbytes(), 16);
void* buffer = context.allocate(arg.nbytes());
transposeInput(
args[i]->toTensor().const_data_ptr(),
buffer,
args[i]->toTensor().sizes(),
args[i]->toTensor().element_size());
arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
cfg->dcfg.inputs[i] = buffer;
} else {
cfg->dcfg.inputs[i] = arg.const_data_ptr();
}
}
// Redirect outputs.

// Set outputs from args.
// Redirect outputs if needed before transposition.
for (int i = 0; i < cfg->numOutputs; i++) {
if (cfg->outputTranspositionFlags[i]) {
auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
if (cfg->outputTranspositionFlags[i] &&
multipleChannelsPresent(arg.sizes())) {
// Allocate buffer, the allocator is reset after each PTE instruction.
void* buffer =
context.allocate(args[cfg->numInputs + i]->toTensor().nbytes(), 16);
void* buffer = context.allocate(arg.nbytes());
cfg->dcfg.outputs[i] = buffer;
} else {
cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
}
}

Expand All @@ -368,17 +393,19 @@ class NeutronBackend final : public PyTorchBackendInterface {

// Transpose outputs.
for (int i = 0; i < cfg->numOutputs; i++) {
if (cfg->outputTranspositionFlags[i]) {
if (args[cfg->numInputs + i]->toTensor().sizes().size() < 3) {
auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
if (cfg->outputTranspositionFlags[i] &&
multipleChannelsPresent(arg.sizes())) {
if (arg.sizes().size() < 3) {
ET_LOG(
Error, "Unable to transpose 1D and 2D output to channel first");
return Error::InvalidProgram;
}
transposeOutput(
cfg->dcfg.outputs[i],
args[cfg->numInputs + i]->toTensor().mutable_data_ptr(),
args[cfg->numInputs + i]->toTensor().sizes(),
args[cfg->numInputs + i]->toTensor().element_size());
arg.mutable_data_ptr(),
arg.sizes(),
arg.element_size());
}
}

Expand Down
Loading
Loading