<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings.
- Select TPU from the Hardware Accelerator drop-down.

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"

namespace mindspore::opt {
namespace {
// Pattern predicate: true when `n` is a CNode or ValueNode whose primitive type is Stack.
// GetCNodeType is only consulted once the node kind has been confirmed.
bool IsStackNode(const BaseRef &n) {
  const bool is_graph_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  return is_graph_node && opt::GetCNodeType(n) == schema::PrimitiveType_Stack;
}
// Pattern predicate: true when `n` is a CNode or ValueNode whose primitive type is FullConnection.
// GetCNodeType is only consulted once the node kind has been confirmed.
bool IsFullConnectNode(const BaseRef &n) {
  const bool is_graph_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  return is_graph_node && opt::GetCNodeType(n) == schema::PrimitiveType_FullConnection;
}
void *GetInputAddr(const AnfNodePtr &node, int input_index) {
  MS_ASSERT(node != nullptr);
  if (!node->isa<CNode>()) {
    MS_LOG(ERROR) << "GetInputAddr not cnode";
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (input_index >= cnode->inputs().size()) {
    MS_LOG(ERROR) << "input index error";
    return nullptr;
  }
  if (cnode->input(input_index)->isa<Parameter>()) {
    auto param_input = cnode->input(input_index)->cast<ParameterPtr>();
    auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param_input->default_param());;
    if (param_value == nullptr) {
      MS_LOG(ERROR) << "param not paramValueLite";
      return nullptr;
    }
    return param_value->tensor_addr();
  }
  MS_LOG(ERROR) << "input not paramter";
  return nullptr;
}
// Builds the constant right-hand operand of the fused MatMul by concatenating the weight tensors
// (input 2) of every FullConnection input of `stack_node`.
//
// With N stack inputs and a first FullConnection weight of shape S and byte size B, the result has
// shape [N, S...] and a freshly allocated buffer of N * B bytes. Weight i (1-based) is copied to
// bytes [(i - 1) * B, i * B).
// NOTE(review): assumes every FullConnection weight has the same byte size as the first one;
// this is not verified here.
//
// @param stack_node     Stack CNode whose inputs 1..N are expected to be FullConnection CNodes
// @param rmatmul_input  Parameter whose abstract, name and default value are populated
// @return RET_OK on success, RET_ERROR otherwise (the temporary buffer is released on failure)
STATUS GetRightMatmulInputParamter(CNodePtr &stack_node, ParameterPtr &rmatmul_input) {
  MS_ASSERT(stack_node != nullptr);
  MS_ASSERT(rmatmul_input != nullptr);
  if (stack_node->inputs().size() < 2) {
    MS_LOG(ERROR) << "stack node has no inputs";
    return RET_ERROR;
  }
  const size_t joint_fullconnect_size = stack_node->inputs().size() - 1;
  auto fc = stack_node->input(1)->cast<CNodePtr>();
  if (fc == nullptr || fc->inputs().size() < 3) {
    MS_LOG(ERROR) << "stack input is not a fullconnect cnode with weight";
    return RET_ERROR;
  }
  auto fc_weight = fc->input(2)->cast<ParameterPtr>();
  if (fc_weight == nullptr) {
    MS_LOG(ERROR) << "fullconnect weight is not a parameter";
    return RET_ERROR;
  }
  auto fc_weight_param = std::dynamic_pointer_cast<ParamValueLite>(fc_weight->default_param());
  if (fc_weight_param == nullptr) {
    MS_LOG(ERROR) << "fullconnect weight param not paramValueLite";
    return RET_ERROR;
  }
  auto tensor_size = fc_weight_param->tensor_size();
  auto rmatmul_input_shape = fc_weight_param->tensor_shape();
  auto new_tensor_data = new (std::nothrow) int8_t[joint_fullconnect_size * tensor_size];
  if (new_tensor_data == nullptr) {
    MS_LOG(ERROR) << "tensor_data is nullptr";
    return RET_ERROR;
  }
  for (size_t i = 1; i <= joint_fullconnect_size; i++) {
    auto tensor_addr = GetInputAddr(stack_node->input(i), 2);
    if (tensor_addr == nullptr) {
      MS_LOG(ERROR) << "input tensor addr nullptr";
      delete[] new_tensor_data;  // not yet handed to a ParamValueLite, so release it here
      return RET_ERROR;
    }
    if (EOK != memcpy_s(new_tensor_data + (i - 1) * tensor_size, tensor_size, tensor_addr, tensor_size)) {
      MS_LOG(ERROR) << "memcpy_s data failed";
      delete[] new_tensor_data;
      return RET_ERROR;
    }
  }
  // Stacked weights gain a new leading dimension of size N.
  rmatmul_input_shape.insert(rmatmul_input_shape.begin(), joint_fullconnect_size);
  auto type_ptr = TypeIdToType(fc_weight_param->tensor_type());
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, rmatmul_input_shape);
  rmatmul_input->set_abstract(abstract_tensor);
  rmatmul_input->set_name(stack_node->fullname_with_scope() + "right_parameter");
  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(param_value != nullptr);
  param_value->set_tensor_shape(rmatmul_input_shape);
  param_value->set_tensor_type(fc_weight_param->tensor_type());
  param_value->set_format(fc_weight_param->format());
  param_value->set_tensor_addr(new_tensor_data);
  param_value->set_tensor_size(joint_fullconnect_size * tensor_size);
  rmatmul_input->set_default_param(param_value);
  return RET_OK;
}
}  // namespace
// Pattern to match: a node whose primitive satisfies IsStackNode, with two inputs that satisfy
// IsFullConnectNode (the same CondVar instance is used for both positions) followed by a SeqVar
// that absorbs any remaining inputs.
const BaseRef BatchMatMulFusion::DefinePattern() const {
  auto stack_matcher = std::make_shared<CondVar>(IsStackNode);
  auto fc_matcher = std::make_shared<CondVar>(IsFullConnectNode);
  auto remaining_inputs = std::make_shared<SeqVar>();
  return VectorRef({stack_matcher, fc_matcher, fc_matcher, remaining_inputs});
}

// slice + fullconnect -> batchmatmul
//
// Replaces a Stack whose inputs are all FullConnection nodes with a single MatMul node.
//  - Left operand: input 1 of the CNode feeding input 1 of the first FullConnection.
//  - Right operand, when the first FullConnection's weight (input 2) is a Parameter: a new constant
//    Parameter stacking every FullConnection weight, with transposeB set on the MatMul.
//  - Right operand otherwise: input 1 at the end of the chain weight -> input(1) -> input(1).
//  - Input quant params: the first FullConnection's input quant params, with its last two entries
//    removed and one entry appended that collects the first weight quant param of every
//    FullConnection.
//  - Output quant params: copied from the first FullConnection.
// Returns the new MatMul CNode, nullptr when the matched subgraph does not have the expected
// structure, or `node` itself when building the stacked constant weight fails.
const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto stack_cnode = node->cast<CNodePtr>();
  if (stack_cnode == nullptr || stack_cnode->inputs().size() < 2) {
    MS_LOG(WARNING) << "batchmatmulfusion stack node must be a cnode with inputs";
    return nullptr;
  }
  const size_t stack_input_size = stack_cnode->inputs().size();
  // check stack node all inputs must fullconnect
  for (size_t i = 1; i < stack_input_size; i++) {
    if (!IsFullConnectNode(stack_cnode->input(i))) {
      MS_LOG(WARNING) << "batchmatmulfusion stack node all inputs must fullconnect type";
      return nullptr;
    }
  }
  auto fullconnect_cnode = stack_cnode->input(1)->cast<CNodePtr>();
  if (fullconnect_cnode == nullptr || fullconnect_cnode->inputs().size() < 3) {
    MS_LOG(WARNING) << "batchmatmulfusion fullconnect node must be a cnode with input and weight";
    return nullptr;
  }
  auto left_slice_cnode = fullconnect_cnode->input(1)->cast<CNodePtr>();
  if (left_slice_cnode == nullptr || left_slice_cnode->inputs().size() < 2) {
    MS_LOG(WARNING) << "batchmatmulfusion fullconnect input must be a cnode";
    return nullptr;
  }
  auto left_matmul_input = left_slice_cnode->input(1);
  auto right_reshape_node = fullconnect_cnode->input(2);

  // Resolve a non-constant right operand up front so a mismatch is detected before anything is built.
  AnfNodePtr right_matmul_input = nullptr;
  if (!right_reshape_node->isa<Parameter>()) {
    auto right_reshape_cnode = right_reshape_node->cast<CNodePtr>();
    auto right_transpose_cnode = (right_reshape_cnode != nullptr && right_reshape_cnode->inputs().size() > 1)
                                   ? right_reshape_cnode->input(1)->cast<CNodePtr>()
                                   : nullptr;
    auto right_slice_cnode = (right_transpose_cnode != nullptr && right_transpose_cnode->inputs().size() > 1)
                               ? right_transpose_cnode->input(1)->cast<CNodePtr>()
                               : nullptr;
    if (right_slice_cnode == nullptr || right_slice_cnode->inputs().size() < 2) {
      MS_LOG(WARNING) << "batchmatmulfusion right input chain is not reshape->transpose->slice";
      return nullptr;
    }
    right_matmul_input = right_slice_cnode->input(1);
  }

  // get matmul quantParams: gather over every stack input (not a fixed count) so any stack width is safe
  std::vector<schema::QuantParamT> jointed_quant_params;
  for (size_t i = 1; i < stack_input_size; i++) {
    auto fc_cnode = stack_cnode->input(i)->cast<CNodePtr>();
    if (fc_cnode == nullptr) {
      MS_LOG(WARNING) << "batchmatmulfusion stack input is not a cnode";
      return nullptr;
    }
    auto fc_input_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(fc_cnode->input(0));
    if (fc_input_prim == nullptr) {
      MS_LOG(WARNING) << "batchmatmulfusion fullconnect primitive is null";
      return nullptr;
    }
    auto fc_input_quant_params = fc_input_prim->GetInputQuantParams();
    if (fc_input_quant_params.size() > 1 && !fc_input_quant_params[1].empty()) {
      jointed_quant_params.push_back(fc_input_quant_params[1][0]);
    }
  }
  // Non-null: the same node's primitive was checked at i == 1 above.
  auto fc_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(fullconnect_cnode->input(0));
  auto rmatmul_quant_params = fc_prim->GetInputQuantParams();
  if (rmatmul_quant_params.size() < 2) {
    // pop_back() on a too-short vector is undefined behavior; leave the graph untouched instead.
    MS_LOG(WARNING) << "batchmatmulfusion fullconnect input quant params are incomplete";
    return nullptr;
  }
  rmatmul_quant_params.pop_back();
  rmatmul_quant_params.pop_back();
  // no bias quantParams
  rmatmul_quant_params.emplace_back(jointed_quant_params);

  auto matmul_primitive = std::make_unique<schema::PrimitiveT>();
  auto attr = std::make_unique<schema::MatMulT>();
  matmul_primitive->value.type = schema::PrimitiveType_MatMul;
  matmul_primitive->value.value = attr.release();
  // Own the created primitive immediately so no early return can leak it.
  std::shared_ptr<lite::PrimitiveC> matmul_cvalue(lite::PrimitiveC::Create(matmul_primitive.release()));
  if (matmul_cvalue == nullptr) {
    MS_LOG(ERROR) << "create matmul primitive failed";
    return nullptr;
  }
  matmul_cvalue->SetInputQuantParam(rmatmul_quant_params);
  matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams());
  auto matmul_value_node = NewValueNode(matmul_cvalue);
  std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};

  // batchmatmul right node may be const
  if (right_matmul_input == nullptr) {
    auto rmatmul_paramter = func_graph->add_parameter();
    if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
      MS_LOG(ERROR) << "GetRightMatmulInputParamter failed";
      return node;
    }
    matmul_cvalue->GetPrimitiveT()->value.AsMatMul()->transposeB = true;
    matmul_inputs.push_back(rmatmul_paramter);
  } else {
    matmul_inputs.push_back(right_matmul_input);
  }
  auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
  matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
  MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
  return matmul_cnode;
}
}  // namespace mindspore::opt


In [None]:
 if (input_tensor->shape().size() == 3
      && input_tensor->GetQuantParams().size() == input_tensor->shape()[0]) { // per batch matmul
    auto per_batch_size = input_tensor->shape()[0];
    auto quant_param = input_tensor->GetQuantParams();
    for (size_t i = 0; i < per_batch_size; i++) {
      auto param = quant_param.at(i);
      auto scale = param.scale;
      auto zero_point = param.zeroPoint;
      auto matrix_size = input_tensor->ElementsNum() / per_batch_size;
      for (int64_t j = 0; j < matrix_size; j++) {
        dequant_datas[i * matrix_size + j] =
            static_cast<float>((quant_datas[i * matrix_size + j] - zero_point) * scale);
      }
    }
    return dequant_datas;
  }

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/matmul.h"
#include "nnacl/fp32/matmul.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INPUT_TENSOR_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {
// Releases the packed A/B buffers and the bias buffer via FreeTmpBuffer().
MatmulCPUKernel::~MatmulCPUKernel() {
  FreeTmpBuffer();
}

// Frees the packed A buffer, packed B buffer and bias buffer, then resets each pointer.
// free(nullptr) is a no-op, so the calls need no guard; the resets make repeated calls safe.
void MatmulCPUKernel::FreeTmpBuffer() {
  free(a_c12_ptr_);
  a_c12_ptr_ = nullptr;
  free(b_r8_ptr_);
  b_r8_ptr_ = nullptr;
  free(bias_ptr_);
  bias_ptr_ = nullptr;
}

int MatmulCPUKernel::ReSize() {
  FreeTmpBuffer();
  auto a_shape = in_tensors_[0]->shape();
  auto b_shape = in_tensors_[1]->shape();
  auto c_shape = out_tensors_[0]->shape();
  if (in_tensors_.size() == 3) {
    auto bias_shape = in_tensors_[2]->shape();
    if (bias_shape[bias_shape.size() - 1] != c_shape[c_shape.size() - 1]) {
      MS_LOG(ERROR) << "The bias' dimension is not equal with column";
      return RET_INPUT_TENSOR_ERROR;
    }
  }
  bool a_broadcast = false;
  bool b_broadcast = false;
  if (a_shape.size() == 2 && b_shape.size() == 2) {
    a_broadcast = false;
    b_broadcast = false;
  } else if (a_shape.size() == 2 && (b_shape.size() != 2 && b_shape[b_shape.size() - 3] != 1)) {
    a_broadcast = true;
  } else if ((a_shape.size() != 2 && a_shape[a_shape.size() - 3] != 1) && b_shape.size() == 2) {
    b_broadcast = true;
  } else if (a_shape[a_shape.size() - 3] == 1 && b_shape.size() > 2 && b_shape[b_shape.size() - 3] != 1) {
    a_broadcast = true;
  } else if (a_shape[a_shape.size() - 3] != 1 && a_shape.size() > 2 && b_shape[b_shape.size() - 3] == 1) {
    b_broadcast = true;
  }
  params_->a_broadcast_ = a_broadcast;
  params_->b_broadcast_ = b_broadcast;
  int a_batch = 1;
  int b_batch = 1;
  for (size_t i = 0; i < a_shape.size() - 2; ++i) {
    a_batch *= a_shape[i];
  }
  for (size_t i = 0; i < b_shape.size() - 2; ++i) {
    b_batch *= b_shape[i];
  }
  params_->a_batch = a_batch;
  params_->b_batch = b_batch;
  params_->row_ = c_shape[c_shape.size() - 2];
  params_->col_ = c_shape[c_shape.size() - 1];
  params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
  params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
  params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
  params_->col_8_ = UP_ROUND(params_->col_, 8);
  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);

#ifdef ENABLE_ARM32
  a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->a_batch * params_->row_4_ * params_->deep_ * sizeof(float)));
  if (a_c12_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_MEMORY_FAILED;
  }
  memset(a_c12_ptr_, 0, params_->row_4_ * params_->deep_ * sizeof(float));
#else
  a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->a_batch * params_->row_12_ * params_->deep_ * sizeof(float)));
  if (a_c12_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_MEMORY_FAILED;
  }
  memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float));
#endif

  b_r8_ptr_ = reinterpret_cast<float *>(malloc(params_->b_batch * params_->col_8_ * params_->deep_ * sizeof(float)));
  if (b_r8_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_MEMORY_FAILED;
  }
  memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(float));

  params_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
  params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
  if (params_->a_const_ == true) {
    InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()), a_c12_ptr_);
  }
  if (params_->b_const_ == true) {
    InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_r8_ptr_);
  }

  bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
  if (bias_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_MEMORY_FAILED;
  }
  memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
  if (in_tensors_.size() == 3) {
    memcpy(bias_ptr_, in_tensors_[2]->data_c(), params_->col_ * sizeof(float));
  }

  return RET_OK;
}

// Packs every A batch from `src_ptr` into the blocked layout expected by MatMulOpt, writing into
// `dst_ptr`. Each destination batch is deep_ * row_4_ floats on ARM32 and deep_ * row_12_
// elsewhere. The row-major vs col-major packing helper is chosen by a_transpose_.
void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
  const int src_stride = params_->deep_ * params_->row_;
#ifdef ENABLE_ARM32
  const int dst_stride = params_->deep_ * params_->row_4_;
#else
  const int dst_stride = params_->deep_ * params_->row_12_;
#endif
  for (int batch = 0; batch < params_->a_batch; ++batch) {
    float *src = src_ptr + batch * src_stride;
    float *dst = dst_ptr + batch * dst_stride;
#ifdef ENABLE_ARM32
    if (!params_->a_transpose_) {
      RowMajor2Col4Major(src, dst, params_->row_, params_->deep_);
      continue;
    }
    RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
#else
    if (!params_->a_transpose_) {
      RowMajor2Col12Major(src, dst, params_->row_, params_->deep_);
      continue;
    }
    RowMajor2Row12Major(src, dst, params_->deep_, params_->row_);
#endif
  }
}

// Packs every B batch from `src_ptr` into 8-column blocks in `dst_ptr`. Each destination batch is
// deep_ * col_8_ floats. The packing helper used is chosen by b_transpose_.
void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
  const int src_stride = params_->deep_ * params_->col_;
  const int dst_stride = params_->deep_ * params_->col_8_;
  for (int batch = 0; batch < params_->b_batch; ++batch) {
    float *src = src_ptr + batch * src_stride;
    float *dst = dst_ptr + batch * dst_stride;
    if (!params_->b_transpose_) {
      RowMajor2Row8Major(src, dst, params_->deep_, params_->col_);
      continue;
    }
    RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
  }
}

// Sets up buffers via ReSize() once shapes are inferred. Until then it reports success without
// allocating; presumably the framework calls ReSize() later.
int MatmulCPUKernel::Init() {
  return InferShapeDone() ? ReSize() : RET_OK;
}

// Computes the output columns assigned to `task_id`: up to thread_stride_ * C8NUM columns starting
// at task_id * thread_stride_ * C8NUM. It uses the operand pointers currently stored in a_ptr_,
// b_ptr_ and c_ptr_. Tasks whose slice starts past col_ do nothing.
int MatmulCPUKernel::RunImpl(int task_id) {
  const int cols_per_task = thread_stride_ * C8NUM;
  const int col_offset = task_id * cols_per_task;
  const int cur_oc = MSMIN(cols_per_task, params_->col_ - col_offset);
  if (cur_oc > 0) {
    float *task_b = b_ptr_ + col_offset * params_->deep_;
    float *task_c = c_ptr_ + col_offset;
    float *task_bias = bias_ptr_ + col_offset;
    MatMulOpt(a_ptr_, task_b, task_c, task_bias, ActType_No, params_->deep_, params_->row_, cur_oc, params_->col_,
              OutType_Nhwc);
  }
  return RET_OK;
}

// Thread-pool entry point: forwards `task_id` to MatmulCPUKernel::RunImpl on the kernel passed as
// `cdata`. Any failure is logged and reported as RET_ERROR.
int MatmulFloatRun(void *cdata, int task_id) {
  auto *kernel = reinterpret_cast<MatmulCPUKernel *>(cdata);
  const int error_code = kernel->RunImpl(task_id);
  if (error_code == RET_OK) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "MatmulFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]";
  return RET_ERROR;
}

int MatmulCPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
    return prepare_ret;
  }
  auto a_src = reinterpret_cast<float *>(in_tensors_[0]->data_c());
  auto b_src = reinterpret_cast<float *>(in_tensors_[1]->data_c());
  auto c_src = reinterpret_cast<float *>(out_tensors_[0]->data_c());
  auto a_shape = in_tensors_[0]->shape();
  auto b_shape = in_tensors_[1]->shape();

  if (params_->a_const_ == false || is_train()) {
    InitMatrixA(a_src, a_c12_ptr_);
  }
  if (params_->b_const_ == false || is_train()) {
    InitMatrixB(b_src, b_r8_ptr_);
  }
  if (!params_->a_broadcast_ && !params_->b_broadcast_) {
    for (int i = 0; i < params_->a_batch; ++i) {
      a_ptr_ = a_c12_ptr_ + i * params_->row_12_ * params_->deep_;
      b_ptr_ = b_r8_ptr_ + i * params_->deep_ * params_->col_8_;
      c_ptr_ = c_src + i * params_->row_ * params_->col_;
      ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);
    }
    return RET_OK;
  } else if (params_->a_broadcast_) {
    for (int i = 0; i < params_->a_batch; i++) {
      a_ptr_ = a_c12_ptr_ + i * params_->row_12_ * params_->deep_;
      for (int j = 0; j < params_->b_batch; j++) {
        b_ptr_ = b_r8_ptr_ + j * params_->deep_ * params_->col_8_;
        c_ptr_ = c_src + (i * params_->a_batch + j) * params_->row_ * params_->col_;
        ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);
      }
    }
    return RET_OK;
  } else if (params_->b_broadcast_) {
    for (int i = 0; i < params_->b_batch; i++) {
      b_ptr_ = b_r8_ptr_ + i * params_->deep_ * params_->col_8_;
      for (int j = 0; j < params_->a_batch; j++) {
        a_ptr_ = a_c12_ptr_ + j * params_->row_12_ * params_->deep_;
        c_ptr_ = c_src + (i * params_->a_batch + j) * params_->row_ * params_->col_;
        ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);
      }
    }
    return RET_OK;
  } else {
    MS_LOG(ERROR) << "Matmul op input shape error ,cannot broadcast";
    return RET_ERROR;
  }
}

// Switches the kernel to evaluation mode via LiteKernel::eval(). It then repacks any constant
// operand so the packed buffers reflect the current (e.g. freshly trained) tensor data.
void MatmulCPUKernel::eval() {
  LiteKernel::eval();
  if (params_->a_const_) {
    auto *a_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
    InitMatrixA(a_data, a_c12_ptr_);
  }
  if (params_->b_const_) {
    auto *b_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
    InitMatrixB(b_data, b_r8_ptr_);
  }
}

}  // namespace mindspore::kernel


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_MATMUL_H_
#define MINDSPORE_LITE_NNACL_MATMUL_H_

#include "nnacl/op_base.h"

/* Function-pointer type for an int8 x int8 matmul kernel that writes int32 results. The parameter
 * names suggest operands padded to row_4 / col_4 / deep_16 and precomputed input sums; this is
 * not verified here. */
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
                                   const int *input_sum, const int *bias);

/* Function-pointer type for an int8 x int8 matmul kernel that writes int8 results. The
 * shift / multiplier / output_zp / mini / maxi arguments presumably drive requantization, and
 * per_channel presumably selects per-channel parameters; confirm against the implementations. */
typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                                  size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                                  int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
                                  int32_t maxi, size_t per_channel);

/* Output layout selector for matmul kernels; the fp32 MatmulCPUKernel passes OutType_Nhwc. */
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType;

/* Parameters shared by the matmul kernels.
 * Fields suffixed _N hold a size rounded up to a multiple of N. This is established for row_4_,
 * row_12_ and col_8_ by the fp32 kernel's ReSize(); the other suffixed fields presumably follow
 * the same convention. */
typedef struct MatMulParameter {
  OpParameter op_parameter_;
  int row_;   /* rows of each output matrix */
  int col_;   /* columns of each output matrix */
  int row_4_;
  int row_8_;
  int row_12_;
  int row_16_;
  int col_2_;
  int col_4_;
  int col_8_;
  int deep_;  /* reduction length (A's inner dim) */
  int deep_4_;
  int deep_16_;
  bool has_bias_;
  int batch;
  int a_batch; /* product of A's leading (non-matrix) dims */
  int b_batch; /* product of B's leading (non-matrix) dims */
  bool a_transpose_; /* false :  row-major  */
  bool b_transpose_; /* true  :  col-major  */
  bool a_const_;     /* A had data at resize time, so it is packed once up front */
  bool b_const_;     /* B had data at resize time, so it is packed once up front */
  bool a_broadcast_; /* a single A batch is reused across B batches */
  bool b_broadcast_; /* a single B batch is reused across A batches */
  ActType act_type_;
} MatMulParameter;

#endif  // MINDSPORE_LITE_NNACL_MATMUL_H_


In [None]:
// Batched A {3, 3, 5} times a single 5x5 B, giving output {3, 3, 5}. With a 2-D B,
// MatmulCPUKernel::ReSize() selects the b_broadcast_ path, so one packed B is reused per A batch.
TEST_F(TestMatMulFp32, broadcast_adims_large) {
  std::vector<lite::Tensor *> inputs_;
  std::vector<lite::Tensor *> outputs_;
  auto matmul_param = new MatMulParameter();
  matmul_param->a_transpose_ = false;
  matmul_param->b_transpose_ = false;
  matmul_param->has_bias_ = false;
  float a[] = {-0.65235930681229, -0.83840399980545, -1.78766810894012,
               0.70013177394867, 1.59338378906250, -0.27830731868744,
               -0.21799579262733, 0.43158695101738, -0.55378085374832,
               0.36990931630135, -1.17206823825836, -1.11876296997070,
               -0.71741801500320, 3.30769562721252, -0.90681165456772,
               0.91257780790329, -0.95812422037125, -1.72401988506317,
               0.38048243522644, -0.81240177154541, 0.01727002300322,
               -0.42814204096794, 1.09676754474640, -0.09681072831154,
               0.45356941223145, -1.34337413311005, 1.40391993522644,
               0.89548885822296, -0.21618857979774, 1.02820229530334,
               1.03970003128052, -0.71132820844650, -0.70775115489960,
               -0.72431832551956, 0.33484363555908, -0.30216658115387,
               1.33590710163116, 0.95974528789520, 0.30546423792839,
               0.68026691675186, 0.26281154155731, 1.30050456523895,
               -0.27555844187737, 0.86635190248489, -0.34620115160942};
  float b[] = {-2.10949733853340e-01, -1.75398275256157e-01, 6.04069411754608e-01,
               9.28257405757904e-01, -4.25255388021469e-01, -6.50140106678009e-01,
               4.82616513967514e-01, 2.19438716769218e-01, 5.27316093444824e-01,
               5.60069322586060e-01, 1.32913506031036e+00, -1.19743621349335e+00,
               -7.01747953891754e-01, 4.76675212383270e-01, 9.75661203265190e-02,
               -6.11941099166870e-01, 1.29469335079193e+00, -6.02801263332367e-01,
               -6.93235933780670e-01, 2.75707364082336e-01, 3.32885072566569e-04,
               -8.51491987705231e-01, 1.40798020362854e+00, 6.13633811473846e-01,
               -4.24172759056091e-01};
  std::vector<int> a_shape = {3, 3, 5};
  std::vector<int> b_shape = {5, 5};
  std::vector<int> c_shape = {3, 3, 5};
  int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
  auto ctx = new lite::InnerContext;
  ctx->thread_num_ = 1;
  ASSERT_EQ(lite::RET_OK, ctx->Init());
  auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx, nullptr);
  // Fail loudly if setup or execution fails instead of comparing an unwritten output buffer.
  ASSERT_EQ(lite::RET_OK, mm->Init());
  ASSERT_EQ(lite::RET_OK, mm->Run());
  float correct[] = {-2.12126612663269, 1.40011596679688, 2.49785637855530,
                     -1.40740084648132, -0.84939938783646, 1.11307847499847,
                     -1.60514271259308, 0.33582586050034, 0.44332292675972,
                     -0.27121970057487, -2.00336194038391, 5.57930183410645,
                     -3.72071981430054, -4.86936187744141, 1.09844863414764,
                     -2.09415149688721, 2.62629437446594, 0.17763727903366,
                     -1.24220609664917, -0.64340001344681, 1.79185485839844,
                     -2.03451776504517, -0.15619860589504, 0.65850496292114,
                     -0.35920926928520, 0.69350230693817, -1.31451427936554,
                     0.44618299603462, 0.70097935199738, 0.94919872283936,
                     -0.25420671701431, -0.90106028318405, 1.87669420242310,
                     0.96024185419083, -1.25131511688232, 0.28414657711983,
                     -0.63526278734207, 0.21078896522522, 1.08711969852448,
                     0.76600521802902, -1.79747617244720, 2.32795929908752,
                     -0.37217232584953, -0.01464513223618, 0.97543424367905};
  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->MutableData()), correct, total_size, 0.0001);
  delete mm;
  delete ctx;  // the kernel only borrows the context; release it after the kernel
  for (auto t : inputs_) delete t;
  for (auto t : outputs_) delete t;
}
// A {1, 5, 3} times batched B {3, 3, 5}, giving output {3, 5, 5}. A's batch dim is 1 while B's is 3,
// so MatmulCPUKernel::ReSize() selects the a_broadcast_ path and the single A is reused per B batch.
TEST_F(TestMatMulFp32, broadcast_bdims_large) {
  std::vector<lite::Tensor *> inputs_;
  std::vector<lite::Tensor *> outputs_;
  auto matmul_param = new MatMulParameter();
  matmul_param->a_transpose_ = false;
  matmul_param->b_transpose_ = false;
  matmul_param->has_bias_ = false;
  float b[] = {-0.65235930681229, -0.83840399980545, -1.78766810894012,
               0.70013177394867, 1.59338378906250, -0.27830731868744,
               -0.21799579262733, 0.43158695101738, -0.55378085374832,
               0.36990931630135, -1.17206823825836, -1.11876296997070,
               -0.71741801500320, 3.30769562721252, -0.90681165456772,
               0.91257780790329, -0.95812422037125, -1.72401988506317,
               0.38048243522644, -0.81240177154541, 0.01727002300322,
               -0.42814204096794, 1.09676754474640, -0.09681072831154,
               0.45356941223145, -1.34337413311005, 1.40391993522644,
               0.89548885822296, -0.21618857979774, 1.02820229530334,
               1.03970003128052, -0.71132820844650, -0.70775115489960,
               -0.72431832551956, 0.33484363555908, -0.30216658115387,
               1.33590710163116, 0.95974528789520, 0.30546423792839,
               0.68026691675186, 0.26281154155731, 1.30050456523895,
               -0.27555844187737, 0.86635190248489, -0.34620115160942};
  float a[] = {0.60076111555099,  1.64438819885254,  1.48806667327881,
               2.10185551643372,  0.82506781816483,  0.03017991222441,
               -0.08507440239191,  0.55047357082367,  1.09753358364105,
               0.74297022819519,  0.25211459398270,  1.22624528408051,
               -0.50102031230927,  2.38538837432861, -0.02676716446877};
  std::vector<int> a_shape = {1, 5, 3};
  std::vector<int> b_shape = {3, 3, 5};
  std::vector<int> c_shape = {3, 5, 5};
  int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
  auto ctx = new lite::InnerContext;
  ctx->thread_num_ = 1;
  ASSERT_EQ(lite::RET_OK, ctx->Init());
  auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx, nullptr);
  // Fail loudly if setup or execution fails instead of comparing an unwritten output buffer.
  ASSERT_EQ(lite::RET_OK, mm->Init());
  ASSERT_EQ(lite::RET_OK, mm->Run());
  float correct[] = {-2.59367299079895e+00, -2.52694416046143e+00, -1.43183088302612e+00,
                     4.43205308914185e+00,  2.16121345758438e-01, -1.63616037368774e+00,
                     -1.97582948207855e+00, -3.42298316955566e+00,  1.11449503898621e+00,
                     3.62689518928528e+00, -1.38408601284027e+00, -1.27655410766602e+00,
                     -3.97728353738785e-01,  3.26590204238892e+00, -9.27187144756317e-01,
                     -1.99209201335907e+00, -2.04974699020386e+00, -2.09910511970520e+00,
                     4.43660688400269e+00,  1.65122762322426e-01, -3.05652856826782e-01,
                     -7.00010731816292e-02,  1.94436383247375e+00, -1.76030027866364e+00,
                     1.08332492411137e-01, -1.42239034175873e+00,  8.09490919113159e-01,
                     2.10033464431763e+00, -2.52318382263184e-01,  1.78781831264496e+00,
                     1.89181268215179e+00, -2.32471489906311e+00, -2.69170737266541e+00,
                     7.13319122791290e-01, -1.30229461193085e+00, -1.54252851009369e+00,
                     1.38668024539948e+00,  1.73324060440063e+00, -3.22935283184052e-01,
                     1.44727909564972e+00, -9.64934051036835e-01,  9.01751577854156e-01,
                     9.37045887112617e-02, -6.82048546150327e-03,  7.71589398384094e-01,
                     -3.80065977573395e-01, -5.78824281692505e-01,  3.45601582527161e+00,
                     -4.15773868560791e-01,  1.46144688129425e+00,  5.18813312053680e-01,
                     3.70464897155762e+00,  7.42955088615417e-01,  1.35634887218475e+00,
                     8.04613530635834e-01,  1.94392287731171e+00, -3.53646010160446e-01,
                     -7.04052031040192e-01, -1.24423730373383e+00,  1.25461089611053e+00,
                     3.36579121649265e-02,  2.22324490547180e+00,  2.86091297864914e-01,
                     1.18062126636505e+00, -3.39850485324860e-02,  1.01855707168579e+00,
                     1.40304362773895e+00, -6.21774435043335e-01,  6.01224958896637e-01,
                     -4.24346700310707e-03, -1.24873030185699e+00,  3.50823640823364e+00,
                     2.65133881568909e+00,  1.06835925579071e+00,  1.46420419216156e+00};
  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->MutableData()), correct, total_size, 0.0001);
  delete mm;
  delete ctx;  // the kernel only borrows the context; release it after the kernel
  for (auto t : inputs_) delete t;
  for (auto t : outputs_) delete t;
}