@@ -89,67 +89,101 @@ void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s(

   WORD32 scratch_size = 0;
 
-  if (groups == 1) {
-    WORD32 out_data_format = 1;
-
-    WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory(
-        ctx,
-        ((batches * input_channels * input_height * input_width) + 8) *
-            sizeof(WORD8));
-
-    WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory(
-        ctx,
-        ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) *
-            sizeof(WORD8));
-
-    WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8);
-    WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8);
-
-    WORD32 p_inp_shape[kNnlibMaxDim];
-    p_inp_shape[0] = input.size(0);
-    p_inp_shape[1] = input_channels;
-    p_inp_shape[2] = input_height;
-    p_inp_shape[3] = input_width;
-
-    WORD32 p_out_shape[kNnlibMaxDim];
-    p_out_shape[0] = input.size(0);
-    p_out_shape[1] = input_height;
-    p_out_shape[2] = input_width;
-    p_out_shape[3] = input_channels;
-
-    WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1};
-
-    xa_nn_transpose_8_8(
-        pin,
-        p_out_shape,
-        p_inp,
-        p_inp_shape,
-        p_permute_vec,
-        kNnlibMaxDim,
-        kNnlibMaxDim);
-
-    WORD32 p_inp_shape1[kNnlibMaxDim];
-    p_inp_shape1[0] = out_channels;
-    p_inp_shape1[1] = kernel_channels;
-    p_inp_shape1[2] = kernel_height;
-    p_inp_shape1[3] = kernel_width;
-
-    WORD32 p_out_shape1[kNnlibMaxDim];
-    p_out_shape1[0] = out_channels;
-    p_out_shape1[1] = kernel_height;
-    p_out_shape1[2] = kernel_width;
-    p_out_shape1[3] = kernel_channels;
-
-    xa_nn_transpose_8_8(
-        pkernel,
-        p_out_shape1,
-        p_kernel,
-        p_inp_shape1,
-        p_permute_vec,
-        kNnlibMaxDim,
-        kNnlibMaxDim);
-
-    scratch_size = xa_nn_conv2d_getsize(
-        input_height,
-        input_width,
-        input_channels,
+  ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution");
+  WORD32 out_data_format = 1;
+
+  WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory(
+      ctx,
+      ((batches * input_channels * input_height * input_width) + 8) *
+          sizeof(WORD8));
+
+  WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory(
+      ctx,
+      ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) *
+          sizeof(WORD8));
+
+  WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8);
+  WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8);
+
+  WORD32 p_inp_shape[kNnlibMaxDim];
+  p_inp_shape[0] = input.size(0);
+  p_inp_shape[1] = input_channels;
+  p_inp_shape[2] = input_height;
+  p_inp_shape[3] = input_width;
+
+  WORD32 p_out_shape[kNnlibMaxDim];
+  p_out_shape[0] = input.size(0);
+  p_out_shape[1] = input_height;
+  p_out_shape[2] = input_width;
+  p_out_shape[3] = input_channels;
+
+  WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1};
+
+  xa_nn_transpose_8_8(
+      pin,
+      p_out_shape,
+      p_inp,
+      p_inp_shape,
+      p_permute_vec,
+      kNnlibMaxDim,
+      kNnlibMaxDim);
+
+  WORD32 p_inp_shape1[kNnlibMaxDim];
+  p_inp_shape1[0] = out_channels;
+  p_inp_shape1[1] = kernel_channels;
+  p_inp_shape1[2] = kernel_height;
+  p_inp_shape1[3] = kernel_width;
+
+  WORD32 p_out_shape1[kNnlibMaxDim];
+  p_out_shape1[0] = out_channels;
+  p_out_shape1[1] = kernel_height;
+  p_out_shape1[2] = kernel_width;
+  p_out_shape1[3] = kernel_channels;
+
+  xa_nn_transpose_8_8(
+      pkernel,
+      p_out_shape1,
+      p_kernel,
+      p_inp_shape1,
+      p_permute_vec,
+      kNnlibMaxDim,
+      kNnlibMaxDim);
+
+  scratch_size = xa_nn_conv2d_getsize(
+      input_height,
+      input_width,
+      input_channels,
+      kernel_height,
+      kernel_width,
+      kernel_channels,
+      dilation_height,
+      dilation_width,
+      y_stride,
+      y_padding,
+      x_stride,
+      x_padding,
+      out_height,
+      out_width,
+      out_channels,
+      inp_precision,
+      kernel_precision,
+      out_data_format);
+
+  scratch_size = scratch_size < 0 ? 0 : scratch_size;
+
+  ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);
+
+  p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);
+
+  for (int _n = 0; _n < batches; _n++) {
+    WORD8* in_batch = pin + _n * input_channels * input_height * input_width;
+    WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;
+
+    xa_nn_conv2d_per_chan_sym8sxasym8s(
+        out_batch,
+        in_batch,
+        pkernel,
+        p_bias,
+        input_height,
+        input_width,
+        input_channels,
@@ -158,59 +192,20 @@ void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s(
-        kernel_channels,
-        dilation_height,
-        dilation_width,
-        y_stride,
-        y_padding,
-        x_stride,
-        x_padding,
-        out_height,
-        out_width,
-        out_channels,
-        inp_precision,
-        kernel_precision,
-        out_data_format);
-
-    scratch_size = scratch_size < 0 ? 0 : scratch_size;
-
-    ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);
-
-    p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);
-
-    for (int _n = 0; _n < batches; _n++) {
-      WORD8* in_batch = pin + _n * input_channels * input_height * input_width;
-      WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;
-
-      xa_nn_conv2d_per_chan_sym8sxasym8s(
-          out_batch,
-          in_batch,
-          pkernel,
-          p_bias,
-          input_height,
-          input_width,
-          input_channels,
-          kernel_height,
-          kernel_width,
-          kernel_channels,
-          dilation_height,
-          dilation_width,
-          out_channels,
-          x_stride,
-          y_stride,
-          x_padding,
-          y_padding,
-          out_height,
-          out_width,
-          input_zero_bias,
-          out_multiplier32,
-          out_shift32,
-          out_zero_bias,
-          out_data_format,
-          p_scratch);
-    }
-    return;
-  }
-
-  // Depthwise convolutions are now handled by specialized operators
-  ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution");
+        kernel_channels,
+        dilation_height,
+        dilation_width,
+        out_channels,
+        x_stride,
+        y_stride,
+        x_padding,
+        y_padding,
+        out_height,
+        out_width,
+        input_zero_bias,
+        out_multiplier32,
+        out_shift32,
+        out_zero_bias,
+        out_data_format,
+        p_scratch);
+  }
 }
 
 void quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out(
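For context on the relayout this kernel performs: `p_permute_vec = {0, 2, 3, 1}` means both the activation (`pin`) and the kernel (`pkernel`) are transposed from NCHW to NHWC with `xa_nn_transpose_8_8` before `xa_nn_conv2d_per_chan_sym8sxasym8s` runs. The sketch below is a minimal, standalone C++ model of that permutation only, not the backend code; the helper name `nchw_to_nhwc` and the tiny tensor shape are illustrative assumptions.

```cpp
// Illustrative only: models the NCHW -> NHWC relayout implied by
// p_permute_vec = {0, 2, 3, 1}; the real kernel uses xa_nn_transpose_8_8.
#include <cstdint>
#include <cstdio>
#include <vector>

// Copy an NCHW int8 tensor into NHWC order (permutation {0, 2, 3, 1}).
static void nchw_to_nhwc(
    const int8_t* src,
    int8_t* dst,
    int batches,
    int channels,
    int height,
    int width) {
  for (int n = 0; n < batches; ++n) {
    for (int c = 0; c < channels; ++c) {
      for (int h = 0; h < height; ++h) {
        for (int w = 0; w < width; ++w) {
          int src_idx = ((n * channels + c) * height + h) * width + w;
          int dst_idx = ((n * height + h) * width + w) * channels + c;
          dst[dst_idx] = src[src_idx];
        }
      }
    }
  }
}

int main() {
  // Tiny 1x2x2x3 example so the relayout is easy to inspect by hand.
  const int n = 1, c = 2, h = 2, w = 3;
  std::vector<int8_t> nchw(n * c * h * w);
  for (size_t i = 0; i < nchw.size(); ++i) {
    nchw[i] = static_cast<int8_t>(i);
  }
  std::vector<int8_t> nhwc(nchw.size());
  nchw_to_nhwc(nchw.data(), nhwc.data(), n, c, h, w);
  // In NHWC, the channel values of each pixel become adjacent.
  for (size_t i = 0; i < nhwc.size(); ++i) {
    std::printf("%d ", nhwc[i]);
  }
  std::printf("\n");
  return 0;
}
```

Compiled standalone, this should print `0 6 1 7 2 8 3 9 4 10 5 11` for the 1x2x2x3 example, i.e. the two channel values of each pixel end up adjacent, which is the channel-last order the NNLib convolution consumes here.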