<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings.
- Select TPU from the Hardware Accelerator drop-down.

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/fusion/custom_multi_fullconnect_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/ops/conv2d.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/ops/primitive_c.h"

namespace mindspore::opt {
namespace {
// Pattern predicate: returns true when `n` is a CNode or ValueNode whose primitive type is
// Reshape or Add, and false for everything else.
// NOTE(review): Add is accepted in addition to Reshape despite the function name -- confirm
// that this is intentional.
bool IsReshapeNode(const BaseRef &n) {
  const bool is_graph_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_graph_node) {
    return false;
  }
  const auto node_type = opt::GetCNodeType(n);
  const bool is_reshape = (node_type == schema::PrimitiveType_Reshape);
  const bool is_add = (node_type == schema::PrimitiveType_Add);
  return is_reshape || is_add;
}
// Pattern predicate: returns true when `n` is a CNode or ValueNode whose primitive type is
// Conv2D or Add, and false for everything else.
bool IsConvORAddNode(const BaseRef &n) {
  const bool is_graph_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_graph_node) {
    return false;
  }
  const auto node_type = opt::GetCNodeType(n);
  const bool is_conv = (node_type == schema::PrimitiveType_Conv2D);
  const bool is_add = (node_type == schema::PrimitiveType_Add);
  return is_conv || is_add;
}
// Builds a value node wrapping a Slice primitive.
//   begin / size / axes: moved into the Slice attribute fields of the same name.
// The attribute format is set to NHWC. The raw attribute and primitive objects are handed over
// with release() to the schema primitive and to lite::PrimitiveC::Create() respectively.
ValueNodePtr CreateSliceValueNode(std::vector<int> begin, std::vector<int> size, std::vector<int> axes) {
  auto slice_attr = std::make_unique<schema::SliceT>();
  slice_attr->format = schema::Format::Format_NHWC;
  slice_attr->begin = std::move(begin);
  slice_attr->axes = std::move(axes);
  slice_attr->size = std::move(size);

  auto schema_primitive = std::make_unique<schema::PrimitiveT>();
  schema_primitive->value.type = schema::PrimitiveType_Slice;
  schema_primitive->value.value = slice_attr.release();

  auto *lite_primitive = lite::PrimitiveC::Create(schema_primitive.release());
  return NewValueNode(std::shared_ptr<lite::PrimitiveC>(lite_primitive));
}
// Concatenates the weight tensors of several fullconnect nodes into one tensor and installs it
// as the default value (and abstract) of `jointed_new_paramter`.
//
//   joint_fc_weights:     weight parameters to join; at least two are required and all must have
//                         the same tensor shape and byte size.
//   jointed_new_paramter: parameter that receives the joined tensor. Its shape is the common
//                         weight shape with the number of joined weights inserted as a new
//                         leading dimension.
//
// Returns RET_OK on success, RET_FAILED when fewer than two weights are given or the weights do
// not match, and RET_ERROR on null inputs, allocation failure or copy failure.
STATUS JointMultiFCWeights(std::vector<ParameterPtr> joint_fc_weights, ParameterPtr jointed_new_paramter) {
  const auto joint_size = joint_fc_weights.size();
  if (joint_size < 2) {
    MS_LOG(WARNING) << "joint fc weight size at least 2";
    return RET_FAILED;
  }
  if (jointed_new_paramter == nullptr) {
    MS_LOG(ERROR) << "jointed new parameter is nullptr";
    return RET_ERROR;
  }

  // Resolve and validate every weight before allocating, so that no error path below has a
  // half-filled buffer to clean up.
  std::vector<ParamValueLitePtr> weight_params;
  weight_params.reserve(joint_size);
  for (const auto &weight : joint_fc_weights) {
    if (weight == nullptr) {
      MS_LOG(ERROR) << "joint fc weight parameter is nullptr";
      return RET_ERROR;
    }
    auto weight_param = std::dynamic_pointer_cast<ParamValueLite>(weight->default_param());
    if (weight_param == nullptr) {
      MS_LOG(ERROR) << "joint fc weight default param is not ParamValueLite";
      return RET_ERROR;
    }
    if (weight_param->tensor_addr() == nullptr) {
      MS_LOG(ERROR) << "input tensor addr nullptr";
      return RET_ERROR;
    }
    weight_params.push_back(weight_param);
  }

  const auto &fc_weight_param = weight_params.front();
  const auto fc_weight_size = fc_weight_param->tensor_size();
  const auto fc_weight_shape = fc_weight_param->tensor_shape();
  for (const auto &weight_param : weight_params) {
    // Compare each weight against the first one (the shape of the weight being visited, not the
    // first weight's own shape again).
    if (weight_param->tensor_shape() != fc_weight_shape || weight_param->tensor_size() != fc_weight_size) {
      MS_LOG(WARNING) << "joint fc weight shape must same";
      return RET_FAILED;
    }
  }

  // Held by unique_ptr until it is handed to the new ParamValueLite, so it is released
  // automatically if a copy fails.
  std::unique_ptr<char[]> new_tensor_data(new (std::nothrow) char[joint_size * fc_weight_size]);
  if (new_tensor_data == nullptr) {
    MS_LOG(ERROR) << "tensor_data is nullptr";
    return RET_ERROR;
  }
  for (size_t i = 0; i < joint_size; i++) {
    auto *dst = new_tensor_data.get() + i * fc_weight_size;
    if (EOK != memcpy_s(dst, fc_weight_size, weight_params[i]->tensor_addr(), fc_weight_size)) {
      MS_LOG(ERROR) << "memcpy_s data failed";
      return RET_ERROR;
    }
  }

  auto type_ptr = TypeIdToType(fc_weight_param->tensor_type());
  auto jointed_param_shape = fc_weight_shape;
  jointed_param_shape.insert(jointed_param_shape.begin(), joint_size);
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, jointed_param_shape);
  jointed_new_paramter->set_abstract(abstract_tensor);

  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(param_value != nullptr);
  param_value->set_tensor_shape(jointed_param_shape);
  param_value->set_tensor_type(fc_weight_param->tensor_type());
  param_value->set_format(fc_weight_param->format());
  // The raw buffer is handed to param_value from here on.
  param_value->set_tensor_addr(new_tensor_data.release());
  param_value->set_tensor_size(joint_size * fc_weight_size);
  jointed_new_paramter->set_default_param(param_value);
  return RET_OK;
}
}
// Describes the sub-graph this pass matches: a fullconnect node whose first input satisfies
// IsReshapeNode and whose second and third inputs are parameter nodes (weight and bias).
const BaseRef CustomMultiFCFusion::DefinePattern() const {
  auto fullconnect = std::make_shared<CondVar>(IsFullConnectNode);
  auto fullconnect_input = std::make_shared<CondVar>(IsReshapeNode);
  auto weight_param = std::make_shared<CondVar>(IsParamNode);
  auto bias_param = std::make_shared<CondVar>(IsParamNode);
  return VectorRef({fullconnect, fullconnect_input, weight_param, bias_param});
}

// Fuses all fullconnect nodes that consume the same input node into one MatMul followed by one
// Slice per original fullconnect node.
//
// Steps:
//   1. Collect every user of the matched fullconnect's first input; all of them must be
//      fullconnect nodes with a Parameter weight, and there must be at least two.
//   2. Join their weights into a new graph parameter (JointMultiFCWeights).
//   3. Create a MatMul node taking the shared input and the joined weight, carrying over the
//      quantization parameters.
//   4. Replace each original fullconnect node with a Slice of the MatMul output through the
//      graph manager.
//
// Always returns nullptr: the rewrite is applied via manager->Replace(), not via the return value.
//
// NOTE(review): the matched pattern includes a bias input, but the fused MatMul is created
// without any bias -- confirm the biases are handled elsewhere or are expected to be absent.
const AnfNodePtr CustomMultiFCFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "func_graph or node is nullptr";
    return nullptr;
  }
  auto fc_cnode = node->cast<CNodePtr>();
  if (fc_cnode == nullptr || fc_cnode->inputs().size() < 2) {
    MS_LOG(WARNING) << "matched node is not a valid fullconnect cnode";
    return nullptr;
  }
  // Needed for the final replacement; check it before the graph is modified.
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "func_graph manager is nullptr";
    return nullptr;
  }
  auto reshape_node = fc_cnode->input(1);
  // reshape cnode all outputs must fullconnect node
  auto reshape_output_nodes = GetRealNodeUsedListByOutputIdx(func_graph, reshape_node, 0);
  if (reshape_output_nodes == nullptr || reshape_output_nodes->size() < 2) {
    // Nothing to fuse with fewer than two consumers; bail out before touching the graph.
    return nullptr;
  }

  std::vector<ParameterPtr> jointed_fc_weights;
  std::vector<schema::QuantParamT> jointed_quant_params;
  std::vector<std::vector<schema::QuantParamT>> input_quant_params;
  std::vector<std::vector<schema::QuantParamT>> output_quant_params;
  for (const auto &node_pair : *reshape_output_nodes) {
    if (!IsFullConnectNode(node_pair.first) || node_pair.first->cast<CNodePtr>() == nullptr) {
      MS_LOG(WARNING) << "reshape all output nodes must be fullconnect node";
      return nullptr;
    }
    auto fc_node = node_pair.first->cast<CNodePtr>();
    if (fc_node->inputs().size() < 3) {
      MS_LOG(WARNING) << "fullconnect node weight must paramter";
      return nullptr;
    }
    auto fc_weight = fc_node->input(2);
    if (fc_weight == nullptr || !fc_weight->isa<Parameter>()) {
      MS_LOG(WARNING) << "fullconnect node weight must paramter";
      return nullptr;
    }
    jointed_fc_weights.push_back(fc_weight->cast<ParameterPtr>());
    auto fc_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(fc_node->input(0));
    if (fc_prim == nullptr) {
      MS_LOG(WARNING) << "fullconnect node primitive is nullptr";
      return nullptr;
    }
    auto fc_input_quantParams = fc_prim->GetInputQuantParams();
    // Entry 1 holds the weight quant params; keep the first one of each fullconnect node.
    if (fc_input_quantParams.size() > 1 && !fc_input_quantParams[1].empty()) {
      jointed_quant_params.push_back(fc_input_quantParams[1][0]);
    }
    // The quant params of the last visited fullconnect node are the ones carried over.
    input_quant_params = fc_input_quantParams;
    output_quant_params = fc_prim->GetOutputQuantParams();
  }

  auto jointed_new_paramter = func_graph->add_parameter();
  if (JointMultiFCWeights(jointed_fc_weights, jointed_new_paramter) != RET_OK) {
    MS_LOG(WARNING) << "fullconnect node weight joint new paramter failed";
    return nullptr;
  }
  // Non-null: every entry was verified to be a CNode in the loop above.
  auto fc_node = reshape_output_nodes->at(0).first->cast<CNodePtr>();

  // create batchmatmul node replace multi fullconnect
  auto matmul_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
  matmul_primitive->value.type = schema::PrimitiveType_MatMul;
  matmul_primitive->value.value = attr.release();
  auto matmul_cvalue = lite::PrimitiveC::Create(matmul_primitive.release());
  if (matmul_cvalue == nullptr) {
    MS_LOG(ERROR) << "create matmul primitive failed";
    return nullptr;
  }
  // Drop the last two input quant entries and append the joined weight quant params instead.
  // Guarded so that a primitive with fewer than two entries does not pop from an empty vector.
  for (int dropped = 0; dropped < 2 && !input_quant_params.empty(); ++dropped) {
    input_quant_params.pop_back();
  }
  input_quant_params.emplace_back(jointed_quant_params);
  matmul_cvalue->SetInputQuantParam(input_quant_params);
  matmul_cvalue->SetOutputQuantParam(output_quant_params);
  auto matmul_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(matmul_cvalue));
  std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, reshape_node, jointed_new_paramter};
  auto new_matmul_node = func_graph->NewCNode(matmul_inputs);
  new_matmul_node->set_fullname_with_scope("matmul_" + fc_node->fullname_with_scope());

  // create same size slice node
  std::vector<int> slice_begin = {0, 0, 0};
  std::vector<int> slice_size = {1, -1, -1};
  std::vector<int> slice_axes = {0, 1, 2};
  for (size_t i = 0; i < jointed_fc_weights.size(); i++) {
    // Slice i selects index i along the leading axis, matching the order in which the weights
    // were joined.
    slice_begin[0] = static_cast<int>(i);
    std::vector<AnfNodePtr> op_inputs = {CreateSliceValueNode(slice_begin, slice_size, slice_axes), new_matmul_node};
    auto slice_cnode = func_graph->NewCNode(op_inputs);
    manager->Replace(reshape_output_nodes->at(i).first, slice_cnode);
  }
  return nullptr;
}
}  // namespace mindspore::opt


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/matmul.h"
#include "include/errorcode.h"
#include "nnacl/fp32/matmul.h"
#include "src/runtime/runtime_api.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INPUT_TENSOR_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {
// Releases the packed A/B buffers and the bias buffer when the kernel is destroyed.
MatmulCPUKernel::~MatmulCPUKernel() {
  this->FreeTmpBuffer();
}

// Frees a_pack_ptr_, b_pack_ptr_ and bias_ptr_ (when set) and resets each of them to nullptr,
// so the function can safely be called more than once.
void MatmulCPUKernel::FreeTmpBuffer() {
  auto release = [](auto *&buffer) {
    if (buffer == nullptr) {
      return;
    }
    free(buffer);
    buffer = nullptr;
  };
  release(a_pack_ptr_);
  release(b_pack_ptr_);
  release(bias_ptr_);
}

int MatmulCPUKernel::MallocMatrixABuffer() {
  auto a_shape = in_tensors_[0]->shape();
  int batch = 1;
  for (size_t i = 0; i < a_shape.size() - 2; ++i) {
    batch *= a_shape[i];
  }
  params_->a_batch = batch;
  params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2];
#ifdef ENABLE_ARM64
  if (params_->row_ == 1) {
    is_vector_a_ = true;
  }
#endif
  params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
  params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
  params_->row_12_ = UP_ROUND(params_->row_, C12NUM);

#ifdef ENABLE_ARM32
  a_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->a_batch * params_->row_4_ * params_->deep_ * sizeof(float)));
  if (a_pack_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_MEMORY_FAILED;
  }
  memset(a_pack_ptr_, 0, params_->row_4_ * params_->deep_ * sizeof(float));
#else
  int row_tmp = is_vector_a_ ? 1 : params_->row_12_;
  a_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->a_batch * row_tmp * params_->deep_ * sizeof(float)));
  if (a_pack_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_MEMORY_FAILED;
  }
  memset(a_pack_ptr_, 0, params_->a_batch * row_tmp * params_->deep_ * sizeof(float));
#endif
  return RET_OK;
}

int MatmulCPUKernel::MallocMatrixBBuffer() {
  auto b_shape = in_tensors_[1]->shape();
  if (b_shape.empty()) {
    return RET_OK;
  }
  int batch = 1;
  for (size_t i = 0; i < b_shape.size() - 2; ++i) {
    batch *= b_shape[i];
  }
  params_->b_batch = batch;
  params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1];
  params_->col_8_ = UP_ROUND(params_->col_, 8);
  params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];

  int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
  b_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->b_batch * col_tmp * params_->deep_ * sizeof(float)));
  if (b_pack_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_MEMORY_FAILED;
  }
  memset(b_pack_ptr_, 0, params_->b_batch * col_tmp * params_->deep_ * sizeof(float));

  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
  return RET_OK;
}

int MatmulCPUKernel::InitBias() {
  if (in_tensors_.size() == 3) {
    auto c_shape = out_tensors_[0]->shape();
    auto bias_shape = in_tensors_[1]->shape();
    if (bias_shape[bias_shape.size() - 1] != c_shape[c_shape.size() - 1]) {
      MS_LOG(ERROR) << "The bias'dimension is not equal with colum";
      FreeTmpBuffer();
      return RET_INPUT_TENSOR_ERROR;
    }
    auto col = c_shape[c_shape.size() - 1];
    auto col_8 = UP_ROUND(col, 8);
    auto col_tmp = is_vector_a_ ? col : col_8;
    bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));
    if (bias_ptr_ == nullptr) {
      FreeTmpBuffer();
      return RET_MEMORY_FAILED;
    }
    memcpy(bias_ptr_, in_tensors_[2]->data_c(), in_tensors_[2]->ElementsNum() * sizeof(float));
  }
  return RET_OK;
}

// Re-creates the buffers that depend on tensor shapes.
// The A (resp. B) pack buffer is reallocated when that input is not constant or had no shape
// when Init() ran; the bias buffer is always rebuilt.
// Returns RET_OK on success and RET_ERROR if any of the three steps fails.
int MatmulCPUKernel::ReSize() {
  const bool realloc_a = !params_->a_const_ || !params_->a_has_shape_;
  if (realloc_a) {
    if (a_pack_ptr_ != nullptr) {
      free(a_pack_ptr_);
      a_pack_ptr_ = nullptr;
    }
    if (MallocMatrixABuffer() != RET_OK) {
      MS_LOG(ERROR) << "Matmul fp32 malloc matrix a buffer failed";
      return RET_ERROR;
    }
  }

  const bool realloc_b = !params_->b_const_ || !params_->b_has_shape_;
  if (realloc_b) {
    if (b_pack_ptr_ != nullptr) {
      free(b_pack_ptr_);
      b_pack_ptr_ = nullptr;
    }
    if (MallocMatrixBBuffer() != RET_OK) {
      MS_LOG(ERROR) << "Matmul fp32 malloc matrix b buffer failed";
      return RET_ERROR;
    }
  }

  if (bias_ptr_ != nullptr) {
    free(bias_ptr_);
    bias_ptr_ = nullptr;
  }
  if (InitBias() != RET_OK) {
    MS_LOG(ERROR) << "Matmul fp32 init bias failed";
    return RET_ERROR;
  }
  return RET_OK;
}

// Packs matrix A from `src_ptr` into `dst_ptr`, one batch at a time.
// In vector mode the data is copied unchanged. Otherwise each batch is repacked with the
// 4-row routines on ARM32 and the 12-row routines elsewhere; the Row*/Col* variant is chosen by
// params_->a_transpose_.
void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
  if (is_vector_a_) {
    memcpy(dst_ptr, src_ptr, params_->a_batch * params_->deep_ * sizeof(float));
    return;
  }

  const bool transposed = params_->a_transpose_;
  for (int batch = 0; batch < params_->a_batch; ++batch) {
    float *batch_src = src_ptr + batch * params_->deep_ * params_->row_;
#ifdef ENABLE_ARM32
    float *batch_dst = dst_ptr + batch * params_->deep_ * params_->row_4_;
    if (!transposed) {
      RowMajor2Col4Major(batch_src, batch_dst, params_->row_, params_->deep_);
    } else {
      RowMajor2Row4Major(batch_src, batch_dst, params_->deep_, params_->row_);
    }
#else
    float *batch_dst = dst_ptr + batch * params_->deep_ * params_->row_12_;
    if (!transposed) {
      RowMajor2Col12Major(batch_src, batch_dst, params_->row_, params_->deep_);
    } else {
      RowMajor2Row12Major(batch_src, batch_dst, params_->deep_, params_->row_);
    }
#endif
  }
}

// Packs matrix B from `src_ptr` into `dst_ptr`, one batch at a time.
// Vector mode: a transposed B is copied unchanged, otherwise every batch goes through
// RowMajor2ColMajor. Normal mode: every batch is repacked with the 8-column routines; the
// Row*/Col* variant is chosen by params_->b_transpose_.
void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
  const bool transposed = params_->b_transpose_;

  if (is_vector_a_) {
    if (transposed) {
      memcpy(dst_ptr, src_ptr, params_->b_batch * params_->col_ * params_->deep_ * sizeof(float));
      return;
    }
    for (int batch = 0; batch < params_->b_batch; ++batch) {
      float *batch_src = src_ptr + batch * params_->deep_ * params_->col_;
      float *batch_dst = dst_ptr + batch * params_->deep_ * params_->col_;
      RowMajor2ColMajor(batch_src, batch_dst, params_->deep_, params_->col_);
    }
    return;
  }

  for (int batch = 0; batch < params_->b_batch; ++batch) {
    float *batch_src = src_ptr + batch * params_->deep_ * params_->col_;
    float *batch_dst = dst_ptr + batch * params_->deep_ * params_->col_8_;
    if (!transposed) {
      RowMajor2Row8Major(batch_src, batch_dst, params_->deep_, params_->col_);
    } else {
      RowMajor2Col8Major(batch_src, batch_dst, params_->col_, params_->deep_);
    }
  }
}

// One-time kernel setup.
//   1. Allocates the A/B pack buffers for inputs whose shape is already known.
//   2. Decides whether A or B has to be broadcast over the other operand's batch dimension.
//   3. Packs inputs that already carry data (treated as constant) into their pack buffers.
//   4. Builds the bias buffer once shape inference is done.
// Returns RET_OK on success and RET_ERROR when a buffer allocation or the bias setup fails.
int MatmulCPUKernel::Init() {
  auto a_shape = in_tensors_[0]->shape();
  auto b_shape = in_tensors_[1]->shape();
  params_->a_has_shape_ = (a_shape.size() != 0);
  params_->b_has_shape_ = (b_shape.size() != 0);
  if (params_->a_has_shape_) {
    auto ret = MallocMatrixABuffer();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Matmul fp32 malloc matrix a buffer failed";
      return RET_ERROR;
    }
  }
  if (params_->b_has_shape_) {
    auto ret = MallocMatrixBBuffer();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Matmul fp32 malloc matrix b buffer failed";
      return RET_ERROR;
    }
  }

  // Third-from-last dimension of a shape, or 1 when the shape has fewer than three dimensions.
  // Reading shape[size() - 3] directly is out of bounds for such shapes (including shapes that
  // are still empty because inference has not run yet).
  auto batch_dim = [](const auto &shape) { return shape.size() > 2 ? shape[shape.size() - 3] : 1; };
  const auto a_batch_dim = batch_dim(a_shape);
  const auto b_batch_dim = batch_dim(b_shape);
  // A is broadcast when it has no batch (2-D, or batch dim 1) while B has a batch dim != 1;
  // B is broadcast in the mirrored situation. Two 2-D inputs need no broadcast.
  params_->a_broadcast_ = (a_batch_dim == 1) && (b_shape.size() > 2) && (b_batch_dim != 1);
  params_->b_broadcast_ = (a_shape.size() > 2) && (a_batch_dim != 1) && (b_batch_dim == 1);

  params_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
  params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
  if (params_->a_const_) {
    InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()), a_pack_ptr_);
    a_ptr_ = a_pack_ptr_;
  }
  if (params_->b_const_) {
    InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
    b_ptr_ = b_pack_ptr_;
  }
  if (!InferShapeDone()) {
    // The bias buffer needs the output shape; it is built later once shapes are known.
    return RET_OK;
  }
  auto ret = InitBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Matmul fp32 init bias failed";
    return RET_ERROR;
  }
  return RET_OK;
}

// Computes the output columns assigned to task `task_id` for the current batch.
// Each task covers thread_stride_ * C8NUM columns starting at task_id * thread_stride_ * C8NUM;
// a task whose range starts at or beyond params_->col_ does nothing.
// Uses MatVecMul in vector mode and MatMulOpt otherwise. Always returns RET_OK.
int MatmulCPUKernel::RunImpl(int task_id) {
  const int col_offset = task_id * thread_stride_ * C8NUM;
  const int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - col_offset);
  if (cur_oc <= 0) {
    return RET_OK;
  }

  auto task_b = cur_b_ptr_ + col_offset * params_->deep_;
  auto task_c = cur_c_ptr_ + col_offset;
  auto task_bias = (bias_ptr_ != nullptr) ? bias_ptr_ + col_offset : nullptr;

  if (!is_vector_a_) {
    MatMulOpt(cur_a_ptr_, task_b, task_c, task_bias, ActType_No, params_->deep_, params_->row_, cur_oc, params_->col_,
              OutType_Nhwc);
    return RET_OK;
  }
  MatVecMul(cur_a_ptr_, task_b, task_c, task_bias, ActType_No, params_->deep_, cur_oc);
  return RET_OK;
}

// Thread-pool entry point: `cdata` is the MatmulCPUKernel instance, `task_id` the task index.
// Forwards to MatmulCPUKernel::RunImpl and maps any failure to RET_ERROR after logging it.
int MatmulFloatRun(void *cdata, int task_id) {
  auto *kernel = static_cast<MatmulCPUKernel *>(cdata);
  const auto error_code = kernel->RunImpl(task_id);
  if (error_code == RET_OK) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "MatmulFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]";
  return RET_ERROR;
}

int MatmulCPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
    return prepare_ret;
  }
  auto a_src = reinterpret_cast<float *>(in_tensors_[0]->data_c());
  auto b_src = reinterpret_cast<float *>(in_tensors_[1]->data_c());
  auto c_src = reinterpret_cast<float *>(out_tensors_[0]->data_c());

  if (params_->a_const_ == false || is_train()) {
    if (is_vector_a_) {
      a_ptr_ = a_src;
    } else {
      InitMatrixA(a_src, a_pack_ptr_);
      a_ptr_ = a_pack_ptr_;
    }
  }
  if (params_->b_const_ == false || is_train()) {
    if (is_vector_a_) {
      b_ptr_ = b_src;
    } else {
      InitMatrixB(b_src, b_pack_ptr_);
      b_ptr_ = b_pack_ptr_;
    }
  }
  if (!params_->a_broadcast_ && !params_->b_broadcast_) {
    for (int i = 0; i < params_->a_batch; ++i) {
      if (is_vector_a_) {
        cur_a_ptr_ = a_ptr_ + i * params_->deep_;
        cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_;
        cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
      } else {
        cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_;
        cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_8_;
        cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
      }
      ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);
    }
  } else if (params_->a_broadcast_) {
    for (int i = 0; i < params_->a_batch; i++) {
      cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_;
      for (int j = 0; j < params_->b_batch; j++) {
        cur_b_ptr_ = b_ptr_ + j * params_->deep_ * params_->col_8_;
        cur_c_ptr_ = c_src + (i * params_->a_batch + j) * params_->row_ * params_->col_;
        ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);
      }
    }
    MS_LOG(ERROR) << "Matmul op input shape error ,cannot broadcast";
  } else if (params_->b_broadcast_) {
    for (int i = 0; i < params_->b_batch; i++) {
      cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_8_;
      for (int j = 0; j < params_->a_batch; j++) {
        cur_a_ptr_ = a_ptr_ + j * params_->row_12_ * params_->deep_;
        cur_c_ptr_ = c_src + (i * params_->a_batch + j) * params_->row_ * params_->col_;
        ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);
      }
    }
  } else {
    MS_LOG(ERROR) << "Matmul op input shape error ,cannot broadcast";
    return RET_ERROR;
  }
  return RET_OK;
  }

  // Switches the kernel to evaluation mode via LiteKernel::eval() and then re-packs every
  // constant input into its pack buffer (copy weights after training).
  void MatmulCPUKernel::eval() {
    LiteKernel::eval();
    if (params_->a_const_) {
      auto a_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
      InitMatrixA(a_data, a_pack_ptr_);
    }
    if (params_->b_const_) {
      auto b_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
      InitMatrixB(b_data, b_pack_ptr_);
    }
  }

}  // namespace mindspore::kernel


In [None]:
 if (input_tensor->shape().size() == 3
        && input_tensor->GetQuantParams().size() == input_tensor->shape()[0]) { // per batch matmul
      auto per_batch_size = input_tensor->shape()[0];
      auto quant_param = input_tensor->GetQuantParams();
      for (size_t i = 0; i < per_batch_size; i++) {
        auto param = quant_param.at(i);
        auto scale = param.scale;
        auto zero_point = param.zeroPoint;
        auto matrix_size = input_tensor->ElementsNum() / per_batch_size;
        for (int64_t j = 0; j < matrix_size; j++) {
          dequant_datas[i * matrix_size + j] =
              static_cast<float>((quant_datas[i * matrix_size + j] - zero_point) * scale);
        }
      }
      return dequant_datas;
    } else if (input_tensor->GetQuantParams().size() != kPerTensor) {

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CUSTOM_MULTI_FULLCONNECT_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CUSTOM_MULTI_FULLCONNECT_FUSION_H_

#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace opt {
// Optimizer pass built on PatternProcessPass: DefinePattern() supplies the sub-graph pattern to
// match and Process() handles each match. Both are defined out of line.
class CustomMultiFCFusion : public PatternProcessPass {
 public:
  // Registers the pass under the name "custom_multiFCFusion"; `multigraph` is forwarded to the
  // base class unchanged.
  explicit CustomMultiFCFusion(bool multigraph = true)
      : PatternProcessPass("custom_multiFCFusion", multigraph) {}

  ~CustomMultiFCFusion() override = default;

  // Returns the pattern this pass looks for.
  const BaseRef DefinePattern() const override;

  // Invoked for a node matching the pattern.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CUSTOM_MULTI_FULLCONNECT_FUSION_H_
