<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/weight_format_hardcode_pass.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
namespace mindspore::opt {
namespace {
// Index, within a convolution CNode's input list, of the input that carries the weight.
constexpr size_t kConvWeightIndex = 2;

// Returns true when `n` is a CNode or ValueNode whose primitive type is one of the four
// convolution variants: Conv2D, DepthwiseConv2D, DeConv2D or DeDepthwiseConv2D.
bool IsConvExtendNode(const BaseRef &n) {
  if (!utils::isa<CNodePtr>(n) && !utils::isa<ValueNodePtr>(n)) {
    return false;
  }
  switch (opt::GetCNodeType(n)) {
    case schema::PrimitiveType_Conv2D:
    case schema::PrimitiveType_DepthwiseConv2D:
    case schema::PrimitiveType_DeConv2D:
    case schema::PrimitiveType_DeDepthwiseConv2D:
      return true;
    default:
      return false;
  }
}
}  // namespace
// Records the quantization type that the HardCode* helpers switch on.
void WeightFormatHardCodePass::SetQuantType(QuantType type) { quant_type = type; }
// Records the source framework type that Process() dispatches on.
void WeightFormatHardCodePass::SetFmkType(FmkType type) { fmk_type = type; }
// Assigns the hard-coded Caffe weight layout to `param_value`.
// For QuantType_WeightQuant and QuantType_QUANT_NONE the format is set to KCHW; every other
// quantization type is logged and rejected.
// Returns lite::RET_OK once the format is set, lite::RET_ERROR otherwise (including when
// either argument is null).
lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node,
                                                     const ParamValueLitePtr &param_value) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(param_value != nullptr);
  // Runtime guard: both pointers are dereferenced below even when assertions are compiled out.
  if (conv_node == nullptr || param_value == nullptr) {
    MS_LOG(ERROR) << "conv node or weight param value is nullptr";
    return lite::RET_ERROR;
  }
  switch (quant_type) {
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE:
      param_value->set_format(schema::Format::Format_KCHW);
      break;
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Assigns the hard-coded ONNX weight layout to `param_value`, chosen from the pass's
// quantization type and the primitive type of `conv_node`:
//   - QuantType_AwareTraining: Conv2D -> KHWC, DepthwiseConv2D -> CHWK, DeConv2D -> CKHW.
//   - QuantType_WeightQuant / QuantType_QUANT_NONE: Conv2D and DepthwiseConv2D -> KCHW,
//     DeConv2D -> CKHW.
// Any other op type or quantization type is logged and reported as lite::RET_ERROR, as is a
// null argument. Returns lite::RET_OK once the format has been set.
lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
                                                    const ParamValueLitePtr &param_value) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(param_value != nullptr);
  // Runtime guard: both pointers are dereferenced below even when assertions are compiled out.
  if (conv_node == nullptr || param_value == nullptr) {
    MS_LOG(ERROR) << "conv node or weight param value is nullptr";
    return lite::RET_ERROR;
  }
  auto op_type = GetCNodeType(conv_node);
  switch (this->quant_type) {
    case QuantType_AwareTraining: {
      // Layouts summarized from the currently known ONNX quantized models.
      if (op_type == schema::PrimitiveType_Conv2D) {
        param_value->set_format(schema::Format::Format_KHWC);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_CHWK);
      } else if (op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      // conv (K x C/group x kH x kW) group = 1
      // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
      // deconv (C x K/group x kH x kW) group = 1
      // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
      if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_KCHW);
      } else if (op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Assigns the hard-coded MindSpore weight layout to `param_value`, chosen from the pass's
// quantization type and the primitive type of `conv_node`:
//   - QuantType_AwareTraining: DepthwiseConv2D -> CKHW, every other op type -> KCHW.
//   - QuantType_WeightQuant / QuantType_QUANT_NONE: Conv2D -> KCHW, DepthwiseConv2D -> CKHW,
//     DeConv2D -> KCHW; other op types are rejected.
// Unsupported quantization types and null arguments are logged and reported as
// lite::RET_ERROR. Returns lite::RET_OK once the format has been set.
lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
                                                  const ParamValueLitePtr &param_value) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(param_value != nullptr);
  // Runtime guard: both pointers are dereferenced below even when assertions are compiled out.
  if (conv_node == nullptr || param_value == nullptr) {
    MS_LOG(ERROR) << "conv node or weight param value is nullptr";
    return lite::RET_ERROR;
  }
  auto op_type = GetCNodeType(conv_node);
  switch (this->quant_type) {
    case QuantType_AwareTraining: {
      // Only depthwise convolution differs here; Conv2D and all remaining op types use KCHW.
      if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else {
        param_value->set_format(schema::Format::Format_KCHW);
      }
      break;
    }
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      // Layouts summarized from the currently known MS models.
      if (op_type == schema::PrimitiveType_Conv2D) {
        param_value->set_format(schema::Format::Format_KCHW);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else if (op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_KCHW);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Assigns the hard-coded TFLite weight layout to `param_value`.
// For the AwareTraining, PostTraining, WeightQuant and QUANT_NONE quantization types:
// Conv2D -> KHWC, DepthwiseConv2D -> CHWK, DeConv2D -> CHWK; other op types are rejected.
// Unsupported quantization types and null arguments are logged and reported as
// lite::RET_ERROR. Returns lite::RET_OK once the format has been set.
lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_node,
                                                      const ParamValueLitePtr &param_value) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(param_value != nullptr);
  // Runtime guard: both pointers are dereferenced below even when assertions are compiled out.
  if (conv_node == nullptr || param_value == nullptr) {
    MS_LOG(ERROR) << "conv node or weight param value is nullptr";
    return lite::RET_ERROR;
  }
  auto op_type = GetCNodeType(conv_node);
  switch (this->quant_type) {
    case QuantType_AwareTraining:
    case QuantType_PostTraining:
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      if (op_type == schema::PrimitiveType_Conv2D) {
        param_value->set_format(schema::Format::Format_KHWC);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_CHWK);
      } else if (op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_CHWK);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    default: {
      // This branch is reached because of the quantization type, so report that value.
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Builds the match pattern for this pass: two slots guarded by IsConvExtendNode, one slot
// guarded by IsParamNode, and a SeqVar for whatever follows.
// NOTE(review): the same conv CondVar is placed in both of the first two slots — confirm
// that matching the second slot with the conv predicate is intended.
const BaseRef WeightFormatHardCodePass::DefinePattern() const {
  auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
  auto weight_var = std::make_shared<CondVar>(IsParamNode);
  auto trailing_inputs = std::make_shared<SeqVar>();
  VectorRef pattern({conv_var, conv_var, weight_var, trailing_inputs});
  return pattern;
}

// Processes `node`, which is expected to be a convolution CNode: fetches the lite parameter
// value of its weight input (index kConvWeightIndex) and stamps it with the hard-coded source
// format for the configured framework (fmk_type) and quantization type.
// Returns `node` on success. Returns nullptr, after logging, when the node is not a CNode,
// has too few inputs, its weight is not a parameter carrying a lite param value, the
// framework is unsupported, or the framework-specific helper fails.
const AnfNodePtr WeightFormatHardCodePass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                   const EquivPtr &) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  if (node == nullptr || !utils::isa<CNode>(node)) {
    MS_LOG(ERROR) << "weight format hardcode pass only support cnode";
    return nullptr;
  }
  auto conv_cnode = node->cast<CNodePtr>();
  // Checked at runtime: an assertion alone disappears when assertions are compiled out.
  if (conv_cnode == nullptr || conv_cnode->inputs().size() <= kConvWeightIndex) {
    MS_LOG(ERROR) << "conv node has no weight input, node: " << node->fullname_with_scope();
    return nullptr;
  }
  auto param_value = GetLiteParamValue(conv_cnode->input(kConvWeightIndex));
  // GetLiteParamValue yields nullptr for non-parameter inputs; the helpers dereference it.
  if (param_value == nullptr) {
    MS_LOG(ERROR) << "weight input has no lite param value, node: " << node->fullname_with_scope();
    return nullptr;
  }
  lite::STATUS status = lite::RET_ERROR;
  switch (fmk_type) {
    case FmkType_CAFFE:
      status = HardCodeCAFFE(node, param_value);
      break;
    case FmkType_TFLITE:
      status = HardCodeTFLITE(node, param_value);
      break;
    case FmkType_ONNX:
      status = HardCodeONNX(node, param_value);
      break;
    case FmkType_MS:
      status = HardCodeMS(node, param_value);
      break;
    default:
      MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
      return nullptr;
  }
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "schema::Format hardCode faild: " << status << ", node: " << node->fullname_with_scope();
    return nullptr;
  }
  return node;
}
}  // namespace mindspore::opt


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_HARDCODE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_HARDCODE_PASS_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "src/param_value_lite.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
// Pattern-process pass registered under the name "weight_format_hardcode_pass".
// The framework type and quantization type set through the setters select which
// framework-specific HardCode* helper assigns a format to a convolution weight.
class WeightFormatHardCodePass : public PatternProcessPass {
 public:
  explicit WeightFormatHardCodePass(bool multigraph = true)
      : PatternProcessPass("weight_format_hardcode_pass", multigraph) {}
  ~WeightFormatHardCodePass() override = default;

  // Configuration setters; the defaults are QuantType_QUANT_NONE and FmkType_TF.
  void SetQuantType(QuantType type);
  void SetFmkType(FmkType fmkType);

  // PatternProcessPass overrides.
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // One helper per supported source framework; each sets the format on `param_value`
  // and returns a lite::STATUS.
  lite::STATUS HardCodeCAFFE(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
  lite::STATUS HardCodeONNX(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
  lite::STATUS HardCodeMS(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
  lite::STATUS HardCodeTFLITE(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;

  QuantType quant_type = schema::QuantType_QUANT_NONE;
  FmkType fmk_type = lite::converter::FmkType_TF;
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_HARDCODE_PASS_H_


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/weight_format_transform_pass.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;

namespace mindspore::opt {
namespace {
// Index, within a convolution CNode's input list, of the weight node.
constexpr size_t kConvWeightIndex = 2;
}  // namespace
// Records the quantization type; Run() derives its quant flag from it.
void WeightFormatTransformPass::SetQuantType(QuantType type) { quant_type = type; }
// Records the source framework type.
void WeightFormatTransformPass::SetFmkType(FmkType type) { fmk_type = type; }
// Records the destination weight format used by ConvWeightFormatTrans().
void WeightFormatTransformPass::SetDstFormat(schema::Format format) { dst_format = format; }
// Walks the graph in topological order and converts the weight of every convolution CNode to
// the destination format (KHWC unless another format was set through SetDstFormat).
// With `quant_flag` set only Conv2D and DepthwiseConv2D are handled; otherwise DeConv2D and
// DeDepthwiseConv2D are handled as well. After a successful conversion the weight's format is
// updated and the weight node receives a fresh AbstractTensor built from the tensor type and
// shape of the converted value.
// Returns lite::RET_OK when every weight was converted, lite::RET_ERROR on the first failure
// (missing weight input, weight without a lite param value, or a failed TransFilterFormat).
lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph, bool quant_flag) {
  MS_ASSERT(graph != nullptr);
  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
  // The destination layout does not depend on the node, so resolve it once.
  const schema::Format weight_dst_format =
      (dst_format != schema::Format::Format_NUM_OF_FORMAT) ? dst_format : schema::Format::Format_KHWC;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto type = opt::GetCNodeType(node);
    const bool is_conv = type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D;
    const bool is_deconv =
        type == schema::PrimitiveType_DeConv2D || type == schema::PrimitiveType_DeDepthwiseConv2D;
    // Deconvolutions are only transformed when quant_flag is not set.
    if (!is_conv && (quant_flag || !is_deconv)) {
      continue;
    }
    auto conv_cnode = node->cast<CNodePtr>();
    if (conv_cnode == nullptr || conv_cnode->inputs().size() <= kConvWeightIndex) {
      MS_LOG(ERROR) << "conv node has no weight input, node: " << node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    auto weight_node = conv_cnode->input(kConvWeightIndex);
    auto weight_value = GetLiteParamValue(weight_node);
    // GetLiteParamValue yields nullptr for non-parameter inputs; it is dereferenced below.
    if (weight_value == nullptr) {
      MS_LOG(ERROR) << "weight input has no lite param value, node: " << node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32
                  || weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
    const lite::STATUS status = TransFilterFormat(weight_value, weight_dst_format);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To"
                    << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope()
                    << "quant type:" << quant_type;
      return lite::RET_ERROR;
    }
    weight_value->set_format(weight_dst_format);
    auto type_id = static_cast<TypeId>(weight_value->tensor_type());
    auto type_ptr = TypeIdToType(type_id);
    auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, weight_value->tensor_shape());
    weight_node->set_abstract(abstract_tensor);
  }
  return lite::RET_OK;
}

bool WeightFormatTransformPass::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  bool quant_flag = false;
  if (this->quant_type == QuantType_AwareTraining || this->quant_type == QuantType_WeightQuant) {
    quant_flag = true;
  }
  auto status = ConvWeightFormatTrans(func_graph, quant_flag);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status;
    return status;
  }
  return false;
}
}  // namespace mindspore::opt


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
// Graph pass registered under the name "weight_format_transform_pass".
// Configured through the setters below and executed through Run().
class WeightFormatTransformPass : public Pass {
 public:
  explicit WeightFormatTransformPass() : Pass("weight_format_transform_pass") {}
  ~WeightFormatTransformPass() override = default;

  // Configuration setters; the defaults are QuantType_QUANT_NONE, FmkType_TF and
  // Format_NUM_OF_FORMAT respectively.
  void SetQuantType(QuantType type);
  void SetFmkType(FmkType fmkType);
  void SetDstFormat(schema::Format format);

  // Pass override.
  bool Run(const FuncGraphPtr &graph) override;

 private:
  // Performs the per-node weight conversion for `graph`; returns a lite::STATUS.
  lite::STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph, bool quant_flag);

  QuantType quant_type = schema::QuantType_QUANT_NONE;
  FmkType fmk_type = lite::converter::FmkType_TF;
  schema::Format dst_format = schema::Format::Format_NUM_OF_FORMAT;
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/common/gllo_utils.h"
#include <vector>
#include <algorithm>
#include <utility>
#include "src/ops/primitive_c.h"
#include "src/common/common.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
// Index of the primitive among a CNode's inputs.
constexpr auto kAnfPrimitiveIndex = 0;

// Returns true when `node` is a CNode whose input at kAnfPrimitiveIndex is the primitive
// `primitive_type`; returns false for any non-CNode. Raises if `node` is null.
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  }
  return false;
}

// Returns whether `node` counts as a real kernel.
// Non-CNodes (parameters and value nodes) are reported as real kernels. A CNode is real unless
// its first input is one of the listed virtual primitives: the summary ops, MakeTuple,
// StateSetItem, Depend, TupleGetItem, ControlDepend, Return or Partial.
// Raises if `node` is null or if the CNode has no inputs at all.
bool IsRealKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // parameter and value node are treated as real kernels too
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    // Stream the debug string into the message; a printf-style "%s" is not expanded here.
    MS_LOG(EXCEPTION) << "Illegal null input of cnode(" << node->DebugString() << ")";
  }
  auto input = cnode->inputs()[0];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                         IsPrimitive(input, prim::kPrimTensorSummary) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  return !is_virtual_node;
}

// Wraps `sexp` in a new value node when it holds an int, float, bool or ValuePtr (tested in
// that order); returns nullptr for anything else.
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  ValueNodePtr created = nullptr;
  if (utils::isa<int>(sexp)) {
    created = NewValueNode(utils::cast<int>(sexp));
  } else if (utils::isa<float>(sexp)) {
    created = NewValueNode(utils::cast<float>(sexp));
  } else if (utils::isa<bool>(sexp)) {
    created = NewValueNode(utils::cast<bool>(sexp));
  } else if (utils::isa<ValuePtr>(sexp)) {
    created = NewValueNode(utils::cast<ValuePtr>(sexp));
  }
  return created;
}

// Creates a CNode over `input_nodes`, owned by `graph` when it holds a FuncGraphPtr or bound
// to it when it holds a VarPtr; returns nullptr when `graph` holds neither.
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  CNodePtr created = nullptr;
  if (utils::isa<FuncGraphPtr>(graph)) {
    created = std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  } else if (utils::isa<VarPtr>(graph)) {
    created = std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  }
  return created;
}

// Creates a VarNode for the Var held by `sexp`.
// When `graph` holds a VarPtr the node is created with a null graph; when it holds a
// FuncGraphPtr the node is created in that graph. Any other `graph` is logged as an error and
// nullptr is returned.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  VarNodePtr var_node = nullptr;
  if (utils::isa<VarPtr>(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    var_node = std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  } else if (utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    var_node = std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  } else {
    MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  }
  return var_node;
}

// Converts the VectorRef held by `sexp` into a CNode whose inputs are the converted elements.
// In multigraph mode with a Var `graph`, every element is converted against a fresh Var named
// "G" and the resulting CNode is bound to the Var in `graph`. Otherwise the elements are
// converted against `graph` itself and the CNode is created through CreateCNodeWithGraph.
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  std::vector<AnfNodePtr> input_nodes;
  const auto &tuple = utils::cast<VectorRef>(sexp);
  const bool bind_to_var = multigraph && utils::isa<VarPtr>(graph);

  if (!bind_to_var) {
    for (auto &element : tuple) {
      input_nodes.push_back(SexpToNode(element, graph, primitive_vars, multigraph));
    }
    return CreateCNodeWithGraph(input_nodes, graph);
  }

  for (auto &element : tuple) {
    input_nodes.push_back(SexpToNode(element, std::make_shared<Var>("G"), primitive_vars, true));
  }
  VarPtr graph_var = utils::cast<VarPtr>(graph);
  return std::make_shared<CNode>(input_nodes, graph_var);
}
}  // namespace

// Equality predicate over two BaseRef values.
//  - Both hold AnfNodePtr and both are Primitive value nodes: equal when the Type() of the two
//    primitives (cast to PrimitiveCPtr) matches.
//  - Both hold AnfNodePtr and both are value nodes: lite::PrimitiveC values are compared as
//    PrimitiveC objects, anything else through the wrapped values' operator==.
//  - Otherwise, when both underlying pointers are lite::PrimitiveC, their Type() is compared.
//  - In every remaining case the result is `a == b`.
// Raises (MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) when a node, value node or value is null.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      // Unwrap node -> value node -> value -> primitive for `a`, checking every step.
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);

      // Same unwrapping for `b`.
      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);

      // NOTE(review): the PrimitiveCPtr casts are dereferenced without a null check —
      // confirm every Primitive reaching this point is a PrimitiveC.
      return a_prim->cast<PrimitiveCPtr>()->Type() == b_prim->cast<PrimitiveCPtr>()->Type();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      // Compare as PrimitiveC objects when both values are PrimitiveC, otherwise compare
      // the wrapped values directly.
      if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
        auto a_obj = (lite::PrimitiveC *)(a_value_ptr.get());
        auto b_obj = (lite::PrimitiveC *)(b_value_ptr.get());
        return (*a_obj) == (*b_obj);
      } else {
        return (*a_value_ptr) == (*b_value_ptr);
      }
    }
    // Two AnfNodes that are not both value nodes fall through to the checks below.
  }
  // NOTE(review): m_ptr is dereferenced without a null check — confirm it is always set here.
  if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
    auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
    auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
    return a_value_node_ptr->Type() == b_value_node_ptr->Type();
  }

  return a == b;
}

// Type-level equality: two CNodes always match each other; any other pair matches when
// their BaseRef type() values are equal.
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  const bool both_cnodes = utils::isa<CNode>(a) && utils::isa<CNode>(b);
  return both_cnodes || a.type() == b.type();
}

// Converts the expression `sexp` into an AnfNode.
//  - VectorRef: delegated to HandleSexpVector.
//  - VarPtr carrying a primitive: the Var is recorded in `primitive_vars` under that primitive
//    and a value node for the primitive is returned.
//  - VarPtr without a primitive: a VarNode created through CreateVarNodeWithSexp.
//  - AnfNodePtr: returned as is.
//  - Anything else: converted through CreateValueNodeWithSexp; raises when that fails.
// Raises when `primitive_vars` is null.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  MS_EXCEPTION_IF_NULL(primitive_vars);
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto var_ptr = utils::cast<VarPtr>(sexp);
    MS_EXCEPTION_IF_NULL(var_ptr);
    if (!var_ptr->primitive()) {
      return CreateVarNodeWithSexp(sexp, graph);
    }
    (*primitive_vars)[var_ptr->primitive()] = var_ptr;
    return NewValueNode(var_ptr->primitive());
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto converted = CreateValueNodeWithSexp(sexp);
  if (converted == nullptr) {
    MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  }
  return converted;
}

// Returns whether `node` is a CNode that counts as a real kernel.
// Non-CNodes are never real CNode kernels; a Return CNode is always considered real; every
// other CNode is decided by IsRealKernel. Raises if `node` is null.
bool IsRealCNodeKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  return CheckPrimitiveType(node, prim::kPrimReturn) || IsRealKernel(node);
}
// Returns whether `node` is a graph kernel: a real CNode kernel whose input at
// kAnfPrimitiveIndex is a FuncGraph value node carrying FUNC_GRAPH_ATTR_GRAPH_KERNEL.
// Raises if `node` is null.
bool IsGraphKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsRealCNodeKernel(node)) {
    return false;
  }

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto first_input = cnode->input(kAnfPrimitiveIndex);
  if (IsValueNode<FuncGraph>(first_input)) {
    auto func_graph = GetValueNode<FuncGraphPtr>(first_input);
    MS_EXCEPTION_IF_NULL(func_graph);
    return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
  }
  // Without a func_graph as that input this cannot be a graph kernel.
  return false;
}

void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
  if (graph == nullptr) {
    MS_LOG(EXCEPTION) << "The graph is null.";
  }
}

void CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(EXCEPTION) << "The AnfNode is null.";
  }
}

void CheckIfCNodeIsNull(const CNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(EXCEPTION) << "The CNode is null.";
  }
}

void CheckIfVarIsNull(const VarPtr &var) {
  if (var == nullptr) {
    MS_LOG(EXCEPTION) << "The Var is null.";
  }
}

void CheckIfNodeIsParam(const AnfNodePtr &node) {
  if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
    MS_LOG(EXCEPTION) << "The Node is not param.";
  }
}

// Raises an exception-level log unless `node` has exactly `size` inputs.
void CheckInputSize(const CNodePtr &node, const int size) {
  const bool matches = static_cast<int>(node->inputs().size()) == size;
  if (!matches) {
    MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  }
}

// Raises an exception-level log when `node` has fewer than `size` inputs.
void CheckLeastInputSize(const CNodePtr &node, const int size) {
  const bool enough = static_cast<int>(node->inputs().size()) >= size;
  if (!enough) {
    MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  }
}

// Adds a new bias parameter to `func_graph` and returns it.
// The parameter gets a one-dimensional shape {kernel_num}; its abstract tensor, tensor type
// and format are taken from `weight_tensor`. `bias_data` is stored as the tensor address and
// the tensor size is computed as kernel_num * sizeof(float).
// NOTE(review): the size is computed with sizeof(float) while the type is copied from
// `weight_tensor` — presumably the weights are float32; confirm against callers.
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                            const ParamValueLitePtr &weight_tensor) {
  auto bias_parameter = func_graph->add_parameter();
  MS_ASSERT(bias_parameter != nullptr);
  std::vector<int> bias_shape = {kernel_num};
  bias_parameter->set_abstract(
      std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), bias_shape));

  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(param_value != nullptr);
  param_value->set_tensor_addr(bias_data);
  param_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t));
  param_value->set_format(weight_tensor->format());
  param_value->set_tensor_type(weight_tensor->tensor_type());
  param_value->set_tensor_shape(bias_shape);

  bias_parameter->set_default_param(param_value);
  return bias_parameter;
}

// Returns the schema primitive type of `n`.
// For a CNode the value node is taken from its first input; a ValueNode is used directly; any
// other kind of `n` is logged as an error and yields PrimitiveType_NONE. When the value is a
// PrimitiveC its Type() is returned; a plain Primitive is logged at INFO level and, like any
// other value, yields PrimitiveType_NONE. Raises when the value node is null.
schema::PrimitiveType GetCNodeType(const BaseRef &n) {
  ValueNodePtr value_node;
  if (utils::isa<CNodePtr>(n)) {
    value_node = utils::cast<CNodePtr>(n)->input(0)->cast<ValueNodePtr>();
  } else if (utils::isa<ValueNodePtr>(n)) {
    value_node = utils::cast<ValueNodePtr>(n);
  } else {
    MS_LOG(ERROR) << "only value node or cnode has type";
    return schema::PrimitiveType_NONE;
  }
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_ASSERT(value != nullptr);
  if (utils::isa<PrimitiveCPtr>(value)) {
    auto primitive = value->cast<PrimitiveCPtr>();
    MS_ASSERT(primitive != nullptr);
    return static_cast<schema::PrimitiveType>(primitive->Type());
  }
  if (utils::isa<Primitive>(value)) {
    auto primitive = value->cast<PrimitivePtr>();
    MS_ASSERT(primitive != nullptr);
    MS_LOG(INFO) << "anf primitive node type:" << primitive->name();
  }
  return schema::PrimitiveType_NONE;
}
// Returns the default parameter of `node` cast to ParamValueLite.
// When `node` is not a Parameter an error is logged and nullptr is returned.
ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
  MS_ASSERT(node != nullptr);
  if (utils::isa<ParameterPtr>(node)) {
    auto param = node->cast<ParameterPtr>();
    MS_ASSERT(param != nullptr);
    auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param->default_param());
    MS_ASSERT(param_value != nullptr);
    return param_value;
  }
  MS_LOG(ERROR) << "cur node not parameter";
  return nullptr;
}
// Returns true when `n` is a Parameter whose default parameter is a ParamValueLite with a
// non-null tensor address.
bool IsParamNode(const BaseRef &n) {
  if (!utils::isa<ParameterPtr>(n)) {
    return false;
  }
  auto default_param = utils::cast<ParameterPtr>(n)->default_param();
  auto lite_value = std::dynamic_pointer_cast<ParamValueLite>(default_param);
  return lite_value != nullptr && lite_value->tensor_addr() != nullptr;
}

// Returns true when `n` is a CNode or ValueNode of primitive type Conv2D or DepthwiseConv2D.
bool IsConvNode(const BaseRef &n) {
  if (!utils::isa<CNodePtr>(n) && !utils::isa<ValueNodePtr>(n)) {
    return false;
  }
  const auto prim_type = opt::GetCNodeType(n);
  return prim_type == schema::PrimitiveType_Conv2D || prim_type == schema::PrimitiveType_DepthwiseConv2D;
}

// True when `n` is a CNode or ValueNode whose primitive type is Pooling.
bool IsPoolingNode(const BaseRef &n) {
  const bool has_primitive = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!has_primitive) {
    return false;
  }
  return opt::GetCNodeType(n) == schema::PrimitiveType_Pooling;
}

// True when `n` is a CNode or ValueNode whose primitive type is QuantDTypeCast.
bool IsQuantNode(const BaseRef &n) {
  if (!(utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n))) {
    return false;
  }
  const auto node_type = opt::GetCNodeType(n);
  return node_type == schema::PrimitiveType_QuantDTypeCast;
}

// True when `node` is a CNode and every input from index 1 onwards is a Parameter.
// Input 0 is not inspected. A CNode with no inputs beyond index 0 therefore yields true;
// a node that is not a CNode yields false.
bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  if (!utils::isa<CNode>(node)) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  size_t idx = 1;
  while (idx < cnode->inputs().size()) {
    if (!utils::isa<Parameter>(cnode->input(idx))) {
      return false;
    }
    ++idx;
  }
  return true;
}

// Number of output tensors implied by the type of `node`:
//   - no type information        -> 1
//   - Tuple                      -> number of tuple elements
//   - TensorType or Number       -> 1
//   - TypeNone                   -> 0
//   - anything else              -> 1
// Raises through MS_EXCEPTION_IF_NULL when `node` (or the tuple cast) is null.
size_t GetOutputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto node_type = node->Type();
  if (node_type == nullptr) {
    return 1;
  }
  if (node_type->isa<Tuple>()) {
    auto as_tuple = node_type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(as_tuple);
    return as_tuple->size();
  }
  if (node_type->isa<TensorType>() || node_type->isa<Number>()) {
    return 1;
  }
  return node_type->isa<TypeNone>() ? 0 : 1;
}

// True when the user list that GetRealNodeUsedList reports for `node` does not hold exactly one
// entry (this includes an empty list). A DEBUG message is logged in that case.
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  const auto users = GetRealNodeUsedList(graph, node);
  const bool single_user = users->size() == 1;
  if (!single_user) {
    MS_LOG(DEBUG) << "fusion node has multi output nodes";
  }
  return !single_user;
}

// Collects every (user node, index) pair that the graph manager records for `node`.
// Raises (MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) when the graph or its manager is null, or when
// the manager has no user entry for `node`.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
                                                                             const AnfNodePtr &node) {
  auto users = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  const auto &user_map = manager->node_users();
  auto found = user_map.find(node);
  if (found == user_map.end()) {
    MS_LOG(EXCEPTION) << "node has no output in manager";
  }
  users->insert(users->end(), found->second.begin(), found->second.end());
  return users;
}
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  MS_ASSERT(tuple_get_item != nullptr);
  if (tuple_get_item->size() != kTupleGetItemInputSize) {
    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
    return -1;
  }
  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
  MS_ASSERT(output_index_value_node != nullptr);
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  MS_ASSERT(value_node != nullptr);
  return IntToSize(GetValue<int>(value_node->value()));
}
// Collects the users of `node` that consume its output number `output_index`.
//
// For each user recorded by the graph manager:
//   - a TupleGetItem user is matched on the index it extracts;
//   - otherwise, if `node` itself is a TupleGetItem, the user is always accepted;
//   - otherwise the scan stops and the users gathered so far are returned (an error is logged
//     first when `output_index` is not 0).
// An empty list is returned, with an error logged, when the manager has no entry for `node`.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto matched_users = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  auto manager = graph->manager();
  MS_ASSERT(manager != nullptr);
  const auto &user_map = manager->node_users();
  auto found = user_map.find(node);
  if (found == user_map.end()) {
    MS_LOG(ERROR) << "node has no output in manager";
    return matched_users;
  }
  for (const auto &user : found->second) {
    const bool user_is_get_item = GetCNodeType(user.first) == schema::PrimitiveType_TupleGetItem;
    // The type of `node` is only queried when the user is not a TupleGetItem.
    if (!user_is_get_item && GetCNodeType(node) != schema::PrimitiveType_TupleGetItem) {
      if (output_index != 0) {
        MS_LOG(ERROR) << "node has no output in manager";
      }
      return matched_users;
    }
    const size_t used_output_index =
      user_is_get_item ? GetTupleGetItemOutIndex(utils::cast<CNodePtr>(user.first)) : output_index;
    if (used_output_index == output_index) {
      matched_users->push_back(user);
    }
  }
  return matched_users;
}
// Extracts the K, C, H and W extents from `oriDims` according to the SOURCE layout encoded in `type`.
//
// oriDims  - filter shape in the source layout (asserted to have 4 entries).
// type     - conversion type; only its source-layout half matters here.
// filterK/filterC/filterH/filterW - outputs, written only for a supported `type`.
// Returns RET_OK, or RET_ERROR (with an error log) when `type` has no branch below.
// For the kNHWC2* types the N axis is read as K.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW) {
  MS_ASSERT(oriDims.size() == 4);
  // Copies the four extents out of oriDims given the axis position of each of them.
  const auto read_dims = [&](size_t k_axis, size_t c_axis, size_t h_axis, size_t w_axis) {
    *filterK = oriDims.at(k_axis);
    *filterC = oriDims.at(c_axis);
    *filterH = oriDims.at(h_axis);
    *filterW = oriDims.at(w_axis);
  };
  switch (type) {
    case kKCHW2HWCK:
    case kKCHW2HWKC:
    case kKCHW2KHWC:
    case kKCHW2CKHW:
      read_dims(lite::KCHW_K, lite::KCHW_C, lite::KCHW_H, lite::KCHW_W);
      break;
    case kCKHW2HWCK:
    case kCKHW2HWKC:
    case kCKHW2KHWC:
      read_dims(lite::CKHW_K, lite::CKHW_C, lite::CKHW_H, lite::CKHW_W);
      break;
    case kHWCK2KCHW:
    case kHWCK2CKHW:
      read_dims(lite::HWCK_K, lite::HWCK_C, lite::HWCK_H, lite::HWCK_W);
      break;
    case kHWKC2KCHW:
    case kHWKC2CKHW:
      read_dims(lite::HWKC_K, lite::HWKC_C, lite::HWKC_H, lite::HWKC_W);
      break;
    case kNHWC2KCHW:
    case kNHWC2HWCK:
    case kNHWC2CKHW:
      read_dims(lite::NHWC_N, lite::NHWC_C, lite::NHWC_H, lite::NHWC_W);
      break;
    case kCHWK2HWCK:
    case kCHWK2KHWC:
      read_dims(lite::CHWK_K, lite::CHWK_C, lite::CHWK_H, lite::CHWK_W);
      break;
    case kKHWC2HWCK:
    case kKHWC2CHWK:
      read_dims(lite::KHWC_K, lite::KHWC_C, lite::KHWC_H, lite::KHWC_W);
      break;
    default:
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
  }
  return RET_OK;
}

// Writes the shape of `tensor` in the DESTINATION layout encoded in `type`, built from the given
// K, C, H and W extents.
// Returns RET_OK, or RET_ERROR (with an error log, shape untouched) when `type` has no branch below.
STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                    int32_t filterH, int32_t filterW) {
  MS_ASSERT(tensor != nullptr);
  switch (type) {
    // destination HWCK
    case kKCHW2HWCK:
    case kCKHW2HWCK:
    case kNHWC2HWCK:
    case kKHWC2HWCK:
    case kCHWK2HWCK:
      tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
      break;
    // destination HWKC
    case kKCHW2HWKC:
    case kCKHW2HWKC:
      tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
      break;
    // destination KCHW
    case kHWCK2KCHW:
    case kHWKC2KCHW:
    case kNHWC2KCHW:
      tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
      break;
    // destination CKHW
    case kHWCK2CKHW:
    case kHWKC2CKHW:
    case kNHWC2CKHW:
    case kKCHW2CKHW:
      tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
      break;
    // destination CHWK
    case kKHWC2CHWK:
      tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
      break;
    // destination KHWC
    case kKCHW2KHWC:
    case kCKHW2KHWC:
    case kCHWK2KHWC:
      tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
      break;
    default:
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
  }
  return RET_OK;
}
// Permutes the filter weights behind tensor->tensor_addr() in place, from the source layout to the
// destination layout encoded in `type`. The tensor shape is not read or written here; the extents
// are taken from the filterK/filterC/filterH/filterW arguments.
//
// T        - element type of the weight buffer.
// Returns RET_OK on success. Returns RET_ERROR (with an error log, buffer untouched) when an extent
// is not positive, the element count would overflow, the weight pointer is null, `type` is not
// handled, or the scratch buffer cannot be allocated; also when the final memcpy_s fails.
//
// Implementation: every layout is described by the stride of each logical axis (k, c, h, w), so one
// loop nest copies element (k, c, h, w) from its source offset to its destination offset.
// Notes:
//   - The element count is accumulated in size_t with an overflow guard instead of multiplying four
//     int32_t values in an int.
//   - The kNHWC2* types read the source as [N=K][H][W][C], which is how GetFilterDim interprets
//     that shape (the earlier code indexed it with the HWKC formula).
template<typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                              int32_t filterH, int32_t filterW) {
  MS_ASSERT(tensor != nullptr);

  // Validate the extents and compute the element count without overflowing.
  const int32_t dims[] = {filterK, filterC, filterH, filterW};
  const size_t max_count = static_cast<size_t>(-1) / sizeof(T);
  size_t count = 1;
  for (const int32_t dim : dims) {
    if (dim <= 0 || count > max_count / static_cast<size_t>(dim)) {
      MS_LOG(ERROR) << "Dim size invalid";
      return RET_ERROR;
    }
    count *= static_cast<size_t>(dim);
  }

  T *weightData = static_cast<T *>(tensor->tensor_addr());
  if (weightData == nullptr) {
    MS_LOG(ERROR) << "weightData is nullptr";
    return RET_ERROR;
  }

  enum class Layout { kKCHW, kCKHW, kKHWC, kHWCK, kHWKC, kCHWK };
  struct LayoutPair {
    Layout src;
    Layout dst;
  };
  LayoutPair route = {Layout::kKCHW, Layout::kKCHW};
  switch (type) {
    case kCHWK2HWCK:
      route = {Layout::kCHWK, Layout::kHWCK};
      break;
    case kCHWK2KHWC:
      route = {Layout::kCHWK, Layout::kKHWC};
      break;
    case kKHWC2HWCK:
      route = {Layout::kKHWC, Layout::kHWCK};
      break;
    case kKHWC2CHWK:
      route = {Layout::kKHWC, Layout::kCHWK};
      break;
    case kKCHW2HWCK:
      route = {Layout::kKCHW, Layout::kHWCK};
      break;
    case kKCHW2CKHW:
      route = {Layout::kKCHW, Layout::kCKHW};
      break;
    case kKCHW2KHWC:
      route = {Layout::kKCHW, Layout::kKHWC};
      break;
    case kKCHW2HWKC:
      route = {Layout::kKCHW, Layout::kHWKC};
      break;
    case kCKHW2HWCK:
      route = {Layout::kCKHW, Layout::kHWCK};
      break;
    case kCKHW2KHWC:
      route = {Layout::kCKHW, Layout::kKHWC};
      break;
    case kCKHW2HWKC:
      route = {Layout::kCKHW, Layout::kHWKC};
      break;
    case kHWCK2KCHW:
      route = {Layout::kHWCK, Layout::kKCHW};
      break;
    case kHWCK2CKHW:
      route = {Layout::kHWCK, Layout::kCKHW};
      break;
    case kHWKC2KCHW:
      route = {Layout::kHWKC, Layout::kKCHW};
      break;
    case kHWKC2CKHW:
      route = {Layout::kHWKC, Layout::kCKHW};
      break;
    // NHWC filters carry K on the N axis, i.e. they are stored as KHWC.
    case kNHWC2HWCK:
      route = {Layout::kKHWC, Layout::kHWCK};
      break;
    case kNHWC2KCHW:
      route = {Layout::kKHWC, Layout::kKCHW};
      break;
    case kNHWC2CKHW:
      route = {Layout::kKHWC, Layout::kCKHW};
      break;
    default: {
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
    }
  }

  const size_t dim_k = static_cast<size_t>(filterK);
  const size_t dim_c = static_cast<size_t>(filterC);
  const size_t dim_h = static_cast<size_t>(filterH);
  const size_t dim_w = static_cast<size_t>(filterW);

  // Element stride of each logical axis inside a row-major buffer of the given layout.
  struct Strides {
    size_t k;
    size_t c;
    size_t h;
    size_t w;
  };
  const auto strides_of = [=](Layout layout) -> Strides {
    switch (layout) {
      case Layout::kKCHW:
        return {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, 1};
      case Layout::kCKHW:
        return {dim_h * dim_w, dim_k * dim_h * dim_w, dim_w, 1};
      case Layout::kKHWC:
        return {dim_h * dim_w * dim_c, 1, dim_w * dim_c, dim_c};
      case Layout::kHWCK:
        return {1, dim_k, dim_w * dim_c * dim_k, dim_c * dim_k};
      case Layout::kHWKC:
        return {dim_c, 1, dim_w * dim_k * dim_c, dim_k * dim_c};
      case Layout::kCHWK:
        break;
    }
    return {1, dim_h * dim_w * dim_k, dim_w * dim_k, dim_k};  // Layout::kCHWK
  };
  const Strides src = strides_of(route.src);
  const Strides dst = strides_of(route.dst);

  // Permute into a scratch buffer first: source and destination offsets overlap arbitrarily.
  std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  if (buf == nullptr) {
    MS_LOG(ERROR) << "new buf failed";
    return RET_ERROR;
  }
  for (size_t k = 0; k < dim_k; ++k) {
    for (size_t c = 0; c < dim_c; ++c) {
      for (size_t h = 0; h < dim_h; ++h) {
        for (size_t w = 0; w < dim_w; ++w) {
          const size_t src_offset = k * src.k + c * src.c + h * src.h + w * src.w;
          const size_t dst_offset = k * dst.k + c * dst.c + h * dst.h + w * dst.w;
          buf[dst_offset] = weightData[src_offset];
        }
      }
    }
  }

  auto ret = ::memcpy_s(weightData, count * sizeof(T), buf.get(), count * sizeof(T));
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

// Converts the filter held by `tensor` (element type T) according to `type`: the weights are
// permuted in place and the tensor shape is rewritten in the destination layout.
//
// Returns lite::RET_OK on success. Returns lite::RET_ERROR when the shape does not have
// lite::DIM_DEFAULT_SIZE dimensions, or the status of the first helper that fails.
//
// The weights are permuted BEFORE the shape is rewritten. TransFilterData only uses the extents
// passed to it, so the order does not affect the result, but a failure while permuting (bad extent,
// null buffer, allocation failure) now leaves the tensor shape consistent with its untouched data.
template<typename T>
static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
  MS_ASSERT(tensor != nullptr);
  auto oriDims = tensor->tensor_shape();
  if (oriDims.size() != static_cast<size_t>(lite::DIM_DEFAULT_SIZE)) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
    return lite::RET_ERROR;
  }

  int32_t filterH = 0;
  int32_t filterW = 0;
  int32_t filterC = 0;
  int32_t filterK = 0;
  auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "GetFilterDim failed: " << status;
    return status;
  }
  status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "SetFilterDim failed: " << status;
    return status;
  }

  return lite::RET_OK;
}

// Converts the filter held by `tensor` from its current format (tensor->format()) to `dst_format`.
//
// Returns:
//   lite::RET_NULL_PTR - `tensor` is null.
//   RET_OK             - the conversion succeeded, or source and destination are the same format
//                        for one of the four supported destinations (KHWC, HWCK, KCHW, CKHW).
//   RET_ERROR          - the shape does not have lite::DIM_DEFAULT_SIZE dimensions, the format pair
//                        is not listed below, or the data type is not float32 / uint8 / int8.
//   otherwise          - the failing status of the typed conversion.
//
// The work is split in two steps: map (source format, destination format) to a kTransFilterType,
// then dispatch once on the element type.
//
// NOTE(review): kKHWC2KCHW, kCKHW2KCHW and kCHWK2KCHW are selected below, but GetFilterDim in this
// file has no branch for them, so those three conversions currently end in RET_ERROR.
STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
  if (tensor == nullptr) {
    return lite::RET_NULL_PTR;
  }
  auto ori_dims = tensor->tensor_shape();
  if (ori_dims.size() != static_cast<size_t>(lite::DIM_DEFAULT_SIZE)) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size();
    return lite::RET_ERROR;
  }
  auto src_format = tensor->format();
  auto data_type = tensor->tensor_type();

  // Step 1: pick the conversion type for this format pair.
  kTransFilterType trans_type = kKCHW2HWCK;
  bool supported = true;
  switch (dst_format) {
    case schema::Format::Format_KHWC:
      switch (src_format) {
        case schema::Format::Format_KCHW:
          trans_type = kKCHW2KHWC;
          break;
        case schema::Format::Format_CKHW:
          trans_type = kCKHW2KHWC;
          break;
        case schema::Format::Format_CHWK:
          trans_type = kCHWK2KHWC;
          break;
        case schema::Format::Format_KHWC:
          return RET_OK;
        default:
          supported = false;
          break;
      }
      break;
    case schema::Format::Format_HWCK:
      switch (src_format) {
        case schema::Format::Format_KCHW:
          trans_type = kKCHW2HWCK;
          break;
        case schema::Format::Format_KHWC:
          trans_type = kKHWC2HWCK;
          break;
        case schema::Format::Format_CKHW:
          trans_type = kCKHW2HWCK;
          break;
        case schema::Format::Format_CHWK:
          trans_type = kCHWK2HWCK;
          break;
        case schema::Format::Format_HWCK:
          return RET_OK;
        default:
          supported = false;
          break;
      }
      break;
    case schema::Format::Format_KCHW:
      switch (src_format) {
        case schema::Format::Format_KCHW:
          return RET_OK;
        case schema::Format::Format_HWCK:
          trans_type = kHWCK2KCHW;
          break;
        case schema::Format::Format_HWKC:
          trans_type = kHWKC2KCHW;
          break;
        case schema::Format::Format_KHWC:
          trans_type = kKHWC2KCHW;
          break;
        case schema::Format::Format_CKHW:
          trans_type = kCKHW2KCHW;
          break;
        case schema::Format::Format_CHWK:
          trans_type = kCHWK2KCHW;
          break;
        default:
          supported = false;
          break;
      }
      break;
    case schema::Format::Format_CKHW:
      switch (src_format) {
        case schema::Format::Format_HWCK:
          trans_type = kHWCK2CKHW;
          break;
        case schema::Format::Format_HWKC:
          trans_type = kHWKC2CKHW;
          break;
        case schema::Format::Format_KCHW:
          trans_type = kKCHW2CKHW;
          break;
        case schema::Format::Format_CKHW:
          return RET_OK;
        default:
          supported = false;
          break;
      }
      break;
    default:
      supported = false;
      break;
  }
  if (!supported) {
    MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to "
                  << EnumNameFormat(dst_format);
    return RET_ERROR;
  }

  // Step 2: run the typed conversion for the element type of the tensor.
  lite::STATUS status = RET_ERROR;
  if (data_type == kNumberTypeFloat32) {
    status = TransFilterFormat<float>(tensor, trans_type);
  } else if (data_type == kNumberTypeUInt8) {
    status = TransFilterFormat<uint8_t>(tensor, trans_type);
  } else if (data_type == kNumberTypeInt8) {
    status = TransFilterFormat<int8_t>(tensor, trans_type);
  } else {
    MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
    return RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  return RET_OK;
}
}  // namespace opt
}  // namespace mindspore


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_

#include <memory>
#include <vector>
#include "src/ops//primitive_c.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"

using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
// Node classification and argument-checking helpers.
// NOTE(review): the definitions of the functions from IsRealCNodeKernel through CheckLeastInputSize
// are not next to this header in the reviewed excerpt; the names suggest predicates and
// fail-fast checks, but confirm the exact behavior against the source file.
bool IsRealCNodeKernel(const AnfNodePtr &node);

bool IsGraphKernel(const AnfNodePtr &node);

void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);

void CheckIfAnfNodeIsNull(const AnfNodePtr &node);

void CheckIfCNodeIsNull(const CNodePtr &node);

void CheckIfVarIsNull(const VarPtr &var);

void CheckInputSize(const CNodePtr &node, int size);

void CheckIfNodeIsParam(const AnfNodePtr &node);

void CheckLeastInputSize(const CNodePtr &node, int size);

// NOTE(review): presumably creates a bias Parameter in `func_graph` from `bias_data` with
// `kernel_num` elements - confirm against the definition.
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                            const ParamValueLitePtr &weight_tensor);

// Returns the schema primitive type held by a CNode (taken from its input 0) or a ValueNode.
// Returns PrimitiveType_NONE when the reference is neither, or does not hold a PrimitiveC.
schema::PrimitiveType GetCNodeType(const BaseRef &node);

// True when `n` is a Parameter whose default value is a ParamValueLite with a non-null tensor address.
bool IsParamNode(const BaseRef &n);

// True when the primitive type of `n` is Conv2D or DepthwiseConv2D.
bool IsConvNode(const BaseRef &n);

// True when the primitive type of `n` is Pooling.
bool IsPoolingNode(const BaseRef &n);

// True when the primitive type of `n` is QuantDTypeCast.
bool IsQuantNode(const BaseRef &n);

// True when `node` is a CNode and all of its inputs from index 1 onwards are Parameters.
bool CheckIsAllInputsParam(const AnfNodePtr &node);

// Number of output tensors implied by the node type: tuple size for a Tuple, 0 for TypeNone,
// 1 otherwise (including when the node has no type).
size_t GetOutputTensorNum(const AnfNodePtr &node);

// True when the graph manager does not report exactly one user for `node`.
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);

// Returns the output index stored in the index input of a TupleGetItem node; returns
// static_cast<size_t>(-1) when the node does not have kTupleGetItemInputSize inputs.
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);

// Returns the default value of a Parameter node as ParamValueLite; nullptr when `node` is not a
// Parameter.
ParamValueLitePtr  GetLiteParamValue(const AnfNodePtr &node);

// Filter layout conversions. Each enumerator is named k<SOURCE>2<DESTINATION>, where the four
// letters give the axis order; K, C, H and W correspond to the filterK / filterC / filterH /
// filterW arguments of GetFilterDim and SetFilterDim. For the NHWC sources the N axis is read as K.
// The trailing numbers are the enumerator values, given every tenth entry for orientation.
enum kTransFilterType {
  kKCHW2HWCK,  // 0
  kKCHW2KHWC,
  kCKHW2KHWC,
  kCKHW2HWCK,
  kKCHW2HWKC,
  kCKHW2HWKC,
  kHWCK2KCHW,
  kHWCK2CKHW,
  kHWKC2KCHW,
  kHWKC2CKHW,
  kNHWC2KCHW,  // 10
  kNHWC2CKHW,
  kNHWC2HWCK,
  kKHWC2HWCK,
  kCHWK2HWCK,
  kKHWC2CHWK,
  kCHWK2KHWC,
  kKHWC2KCHW,
  kCKHW2KCHW,
  kCHWK2KCHW,
  kKCHW2CKHW  // 20
};

// Reads the K, C, H and W extents out of `oriDims` according to the source layout of `type`.
// Returns RET_ERROR for a conversion type it does not handle.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW);

// Sets the shape of `tensor` from the given extents, ordered by the destination layout of `type`.
// Returns RET_ERROR for a conversion type it does not handle.
STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                    int32_t filterH, int32_t filterW);

// Permutes the weight buffer of `tensor` (elements of type T) in place according to `type`.
// NOTE(review): declared `static` in a header while the definition lives in the source file, so
// only that translation unit can use it - confirm this is intended.
template<typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                              int32_t filterH, int32_t filterW);

// Typed conversion: derives the extents, rewrites the shape and permutes the data for `type`.
// NOTE(review): same `static`-in-header remark as for TransFilterData.
template<typename T>
static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type);

// Converts the filter in `tensor` from its current format to `dst_format`, dispatching on the
// tensor data type (float32, uint8 or int8).
STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format);
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_


# Data processing