<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/common/graph_util.h"
#include <stdlib.h>
#include <time.h>
#include <utility>
#include <set>
#include "schema/inner/model_generated.h"
#include "tools/common/tensor_util.h"
#include "tools/common/node_util.h"
#include "utils/log_adapter.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
// Returns a copier that clones only the identifying parts of a node: its name,
// quantType and primitive type. Tensor indexes and primitive attributes are not
// copied. The copier yields nullptr when the source node or its primitive is
// null, which callers in this file already treat as a copy failure.
OpDefCopyer GetSimpleOpCopyer() {
  return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
    // Guard the dereferences below instead of crashing on an incomplete node.
    if (inCNode == nullptr || inCNode->primitive == nullptr) {
      return nullptr;
    }
    auto newCNode = std::make_unique<CNodeT>();

    newCNode->name = inCNode->name;
    newCNode->quantType = inCNode->quantType;
    newCNode->primitive = std::make_unique<schema::PrimitiveT>();
    newCNode->primitive->value.type = inCNode->primitive->value.type;
    return newCNode;
  };
}

// Index-based convenience overload: looks up the node at nodeIdx (bounds-checked
// by at()) and forwards to the CNodeT overload.
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
  const auto &nodePtr = graphT.nodes.at(nodeIdx);
  return GetInputNodeIdx(graphT, *nodePtr, inputIndexIdx);
}

// Collects the indexes of all nodes that produce an input tensor of `node`.
// inputIndexIdx == -1 inspects every input tensor; any other value selects the
// single input slot at that position. The result is de-duplicated and sorted
// ascending (it is built from a std::set).
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
  std::vector<uint32_t> tensorIdxes;
  if (inputIndexIdx != -1) {
    MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
    tensorIdxes.emplace_back(node.inputIndex.at(inputIndexIdx));
  } else {
    tensorIdxes = node.inputIndex;
  }

  std::set<size_t> producers;
  for (uint32_t tensorIdx : tensorIdxes) {
    auto linked = GetLinkedPreIdx(graphT, tensorIdx);
    producers.insert(linked.begin(), linked.end());
  }
  return std::vector<size_t>(producers.begin(), producers.end());
}

// Index-based convenience overload: looks up the node at nodeIdx (bounds-checked
// by at()) and forwards to the CNodeT overload.
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
                                     const int outputIndexIdx) {
  const auto &nodePtr = graphT.nodes.at(nodeIdx);
  return GetOutputNodeIdx(graphT, *nodePtr, outputIndexIdx);
}

// Collects the indexes of all nodes that consume an output tensor of `node`.
// outputIndexIdx == -1 inspects every output tensor; any other value selects the
// single output slot at that position. The result is de-duplicated and sorted
// ascending (it is built from a std::set).
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
  std::vector<uint32_t> tensorIdxes;
  if (outputIndexIdx != -1) {
    MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
    tensorIdxes.emplace_back(node.outputIndex.at(outputIndexIdx));
  } else {
    tensorIdxes = node.outputIndex;
  }

  std::set<size_t> consumers;
  for (uint32_t tensorIdx : tensorIdxes) {
    auto linked = GetLinkedPostIdx(graphT, tensorIdx);
    consumers.insert(linked.begin(), linked.end());
  }
  return std::vector<size_t>(consumers.begin(), consumers.end());
}

// Returns the indexes (in graph order) of every node whose outputIndex contains
// tensorIdx, i.e. the producers of that tensor. Null node slots are skipped.
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> preNodeIdx;
  for (size_t i = 0; i < graphT.nodes.size(); i++) {
    auto &candidate = graphT.nodes.at(i);
    if (candidate == nullptr) {
      continue;
    }
    // Search the node's own vector in place; copying it per node was pure overhead.
    if (IsContain<uint32_t>(candidate->outputIndex, tensorIdx)) {
      preNodeIdx.emplace_back(i);
    }
  }
  return preNodeIdx;
}

// Returns the indexes (in graph order) of every node whose inputIndex contains
// tensorIdx, i.e. the consumers of that tensor. Null node slots are skipped.
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> postNodeIdx;
  for (size_t i = 0; i < graphT.nodes.size(); i++) {
    auto &candidate = graphT.nodes.at(i);
    if (candidate == nullptr) {
      continue;
    }
    // Search the node's own vector in place; copying it per node was pure overhead.
    if (IsContain<uint32_t>(candidate->inputIndex, tensorIdx)) {
      postNodeIdx.emplace_back(i);
    }
  }
  return postNodeIdx;
}

// Detaches `node` from the graph by rewiring every use of its single output
// tensor (graph outputs and consumer-node inputs) to its first input tensor,
// then removing the output tensor and clearing the node's index lists.
//
// Returns RET_ERROR when the node is not found in graphT (matched by name), has
// no inputs, does not have exactly one output, or tensor removal fails;
// RET_OK otherwise.
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(node != nullptr);
  // Locate the node by name. Previously a miss left nodeIdx at 0 and the
  // consumer lookup below silently operated on the first node of the graph.
  bool found = false;
  size_t nodeIdx = 0;
  for (size_t i = 0; i < graphT->nodes.size(); i++) {
    auto &inNode = graphT->nodes.at(i);
    if (inNode != nullptr && inNode->name == node->name) {
      found = true;
      nodeIdx = i;
      break;
    }
  }
  if (!found) {
    MS_LOG(ERROR) << "Node " << node->name.c_str() << " is not in graph";
    return RET_ERROR;
  }
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  if (inputTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Node " << node->name.c_str() << " has no inputs";
    return RET_ERROR;
  }
  if (outputTensorIdxes.size() != 1) {
    MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
                  << "should has 1 output, in fact: " << outputTensorIdxes.size();
    return RET_ERROR;
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  auto outDataTensorIdx = outputTensorIdxes.front();

  MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
  // If the node's output is a graph output, expose its input tensor instead
  // (only the first matching entry is rewritten).
  for (auto &graphOutIdx : graphT->outputIndex) {
    if (graphOutIdx == outDataTensorIdx) {
      graphOutIdx = inDataTensorIdx;
      break;
    }
  }

  // Re-point each consumer's first reference to the output tensor at the input tensor.
  auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
  for (auto postNodeIdx : postNodeIdxes) {
    MS_ASSERT(graphT->nodes.size() > postNodeIdx);
    auto &postNode = graphT->nodes.at(postNodeIdx);
    MS_ASSERT(postNode != nullptr);
    for (auto &consumerInIdx : postNode->inputIndex) {
      if (consumerInIdx == outDataTensorIdx) {
        consumerInIdx = inDataTensorIdx;
        break;
      }
    }
  }

  // The node's output tensors are now unreferenced by consumers; remove them
  // and surface a failure instead of ignoring the status.
  auto status = RemoveTensor(graphT, outputTensorIdxes);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << " failed";
    return RET_ERROR;
  }
  node->inputIndex.clear();
  node->outputIndex.clear();

  return RET_OK;
}

// Overload that accepts a sub-graph index. subGraphIdx is not consulted here;
// the call is forwarded unchanged to the (graph, nodeIdx, removeTensor) overload.
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graph != nullptr);
  const STATUS status = IsolateOneWayNode(graph, nodeIdx, removeTensor);
  return status;
}

// Bypasses the node at nodeIdx: every use of its (at most one) output tensor is
// rewired to its first input tensor, optionally the output tensor is removed,
// and the node's index lists are cleared.
//
// Returns RET_PARAM_INVALID when nodeIdx is out of range or the slot is null,
// RET_ERROR when the node has more than one producer node or output tensor, has
// no input tensor, or tensor removal fails; RET_OK otherwise.
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graphT != nullptr);
  if (graphT->nodes.size() <= nodeIdx) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }

  CNodeT *node = graphT->nodes.at(nodeIdx).get();
  // Other helpers in this file tolerate null node slots; fail cleanly here
  // rather than dereferencing one.
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is nullptr, nodeIdx: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
  if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
    MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
    return RET_ERROR;
  }
  if (inputTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
    return RET_ERROR;
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  if (!outputTensorIdxes.empty()) {
    auto outDataTensorIdx = outputTensorIdxes.front();
    MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
    MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
    // If the node's output is a graph output, expose its input tensor instead
    // (only the first matching entry is rewritten).
    for (auto &graphOutIdx : graphT->outputIndex) {
      if (graphOutIdx == outDataTensorIdx) {
        graphOutIdx = inDataTensorIdx;
        break;
      }
    }
    // Re-point each consumer's first reference to the output tensor at the input tensor.
    auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
    for (auto postNodeIdx : postNodeIdxes) {
      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
      auto &postNode = graphT->nodes.at(postNodeIdx);
      MS_ASSERT(postNode != nullptr);
      for (auto &consumerInIdx : postNode->inputIndex) {
        if (consumerInIdx == outDataTensorIdx) {
          consumerInIdx = inDataTensorIdx;
          break;
        }
      }
    }
  }

  if (removeTensor) {
    // After rewiring, the node's output tensors are no longer needed.
    auto status = RemoveTensor(graphT, outputTensorIdxes);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
      return RET_ERROR;
    }
  }
  node->inputIndex.clear();
  node->outputIndex.clear();
  return RET_OK;
}

// Pointer-based overload: finds the first node in graphT whose name equals
// node->name and isolates it via the index-based overload.
// Returns RET_PARAM_INVALID when no node with that name exists in the graph.
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(node != nullptr);
  for (size_t i = 0; i < graphT->nodes.size(); i++) {
    auto &candidate = graphT->nodes.at(i);
    // Skip null slots instead of dereferencing them, matching GetLinkedPreIdx/GetLinkedPostIdx.
    if (candidate != nullptr && candidate->name == node->name) {
      return IsolateOneWayNode(graphT, i, removeTensor);
    }
  }
  MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
  return RET_PARAM_INVALID;
}

// Erases the listed tensors from graphT->allTensors and renumbers every index
// that referred to a later tensor (graph inputs/outputs, node inputs/outputs and
// the remaining entries of the deletion list itself). Unless forceDelete is set,
// a tensor whose GetRefCount() is greater than 1 is left in place.
// Always returns RET_OK.
// NOTE(review): the erase below does not bounds-check the index against
// allTensors.size() — callers appear to be trusted to pass valid indexes; confirm.
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
  auto cursor = toDeleteTensorIdxes.begin();
  while (cursor != toDeleteTensorIdxes.end()) {
    uint32_t victim = *cursor;
    if (!forceDelete && GetRefCount(graphT, victim) > 1) {
      // Still referenced elsewhere: keep it and move on.
      ++cursor;
      continue;
    }
    // Shift graph-level input indexes that sit above the removed slot.
    for (auto &graphInIdx : graphT->inputIndex) {
      if (graphInIdx > victim) {
        --graphInIdx;
      }
    }
    // Shift graph-level output indexes likewise.
    for (auto &graphOutIdx : graphT->outputIndex) {
      if (graphOutIdx > victim) {
        --graphOutIdx;
      }
    }
    // Drop/shift the index in every node's input and output lists.
    for (auto &nodePtr : graphT->nodes) {
      UpdateNodeIndex(nodePtr.get(), victim);
    }
    // Keep the pending deletions consistent with the new numbering.
    for (auto &pendingIdx : toDeleteTensorIdxes) {
      if (pendingIdx > victim) {
        --pendingIdx;
      }
    }
    graphT->allTensors.erase(graphT->allTensors.begin() + victim);
    cursor = toDeleteTensorIdxes.erase(cursor);
  }
  return RET_OK;
}

// Adjusts a node's tensor indexes after tensor deleteIdx has been removed:
// entries equal to deleteIdx are erased, entries greater than it are decremented.
// Applied to the input list first, then the output list. Always returns RET_OK.
STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
  auto dropAndShift = [deleteIdx](auto &indexes) {
    auto it = indexes.begin();
    while (it != indexes.end()) {
      if (*it == deleteIdx) {
        it = indexes.erase(it);
        continue;
      }
      if (*it > deleteIdx) {
        --(*it);
      }
      ++it;
    }
  };
  dropAndShift(node->inputIndex);
  dropAndShift(node->outputIndex);
  return RET_OK;
}

// Appends `tensor` to graphT->allTensors and registers its new index on the node
// at nodeIdx: as an input when place == kBefore, as an output for any other place.
// Returns RET_PARAM_INVALID (without taking the tensor into the graph) when
// nodeIdx is out of range, RET_OK otherwise.
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
                      InsertPlace place) {
  if (nodeIdx >= graphT->nodes.size()) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  auto targetNode = graphT->nodes.at(nodeIdx).get();
  graphT->allTensors.emplace_back(std::move(tensor));
  uint32_t appendedIdx = graphT->allTensors.size() - 1;
  auto &indexList = (place == kBefore) ? targetNode->inputIndex : targetNode->outputIndex;
  indexList.emplace_back(appendedIdx);
  return RET_OK;
}

// Replaces the tensor stored at inTensorIdx with `tensor`, provided that
// inTensorIdx is one of the inputs of the node at nodeIdx. The previous tensor
// is swapped into the by-value parameter and destroyed on return.
// Returns RET_PARAM_INVALID when either index is out of range or the tensor is
// not an input of the node; RET_OK otherwise.
STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
                           std::unique_ptr<TensorT> tensor) {
  if (nodeIdx >= graphT->nodes.size()) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  auto node = graphT->nodes.at(nodeIdx).get();
  if (inTensorIdx >= graphT->allTensors.size()) {
    // Report the offending tensor index (the original logged nodeIdx here).
    MS_LOG(ERROR) << "inTensorIdx out of range: " << inTensorIdx;
    return RET_PARAM_INVALID;
  }
  if (!IsContain(node->inputIndex, inTensorIdx)) {
    MS_LOG(ERROR) << "inTensorIdx(" << inTensorIdx << ") is not a inputIdx of node(" << nodeIdx << ")";
    return RET_PARAM_INVALID;
  }
  graphT->allTensors.at(inTensorIdx).swap(tensor);
  return RET_OK;
}

// Index-based overload: resolves existNodeIdx to an iterator and delegates to
// the iterator-based InsertNode.
// Returns graphT->nodes.end() when existNodeIdx is out of range; *errorCode is
// left untouched on that path.
NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  if (existNodeIdx >= graphT->nodes.size()) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
    return graphT->nodes.end();
  }
  auto nodeIter = graphT->nodes.begin() + existNodeIdx;
  // The resolved iterator must be dereferenceable; index 0 (== begin()) is valid.
  MS_ASSERT(nodeIter != graphT->nodes.end());
  MS_ASSERT((*nodeIter) != nullptr);
  // Forward the caller's copier; it was previously dropped, so the callee's
  // default copier was used regardless of what the caller supplied.
  return InsertNode(graphT, nodeIter, place, inoutIndex, std::move(toAddNode), errorCode, std::move(opDefCopyer));
}

// Dispatches on `place`: kBefore inserts ahead of the existing node's input slot
// inoutIndexIdx, kAfter behind its output slot inoutIndexIdx. Any other value is
// logged and graphT->nodes.end() is returned.
NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  switch (place) {
    case kBefore:
      return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
    case kAfter:
      return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
    default:
      MS_LOG(ERROR) << "Invalid InsertPlace : " << place;
      return graphT->nodes.end();
  }
}

// Inserts copies of toAddNodeIn (made with opDefCopyer) in front of the existing
// node, on the input slot inputIndexIdx of that node.
//
// For each inserted node a copy of the input tensor definition is appended to
// graphT->allTensors; the new node reads the original tensor and writes the
// copy, and the first entry of the existing node's inputIndex that equals the
// original tensor index is re-pointed at the copy.
//
// On success *errorCode is RET_OK and the returned iterator is the position
// following the last inserted node. On failure *errorCode is RET_NULL_PTR and
// graphT->nodes.end() is returned. errorCode is dereferenced unconditionally.
NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
                          std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  // NOTE: existNode refers to the element behind existNodeIter; it is only used
  // before the nodes.insert() calls below.
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > preTensorIdx);

  // Nodes that produce the selected input tensor.
  auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode.get()), inputIndexIdx);
  if (preNodeIdxes.empty()) {
    // No producer node: the input tensor is not written by any node in the graph.
    auto &preTensor = graphT->allTensors.at(preTensorIdx);
    MS_ASSERT(preTensor != nullptr);
    auto toAddTensor = CopyTensorDefT(preTensor);
    if (toAddTensor == nullptr) {
      MS_LOG(ERROR) << "Copy TensorT failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    // The original tensor loses its refCount and payload in this branch.
    preTensor->refCount = 0;
    preTensor->data.clear();
    // NOTE(review): here the original tensor receives dstT and the copy srcT,
    // the reverse of the assignment in the producer branch below — confirm intended.
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
      toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(toAddNodeIn.get());
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    // Wire the new node: original tensor in, copied tensor out.
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(preTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    // Re-point the existing node's first reference to the original tensor.
    for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
      if (*iter == preTensorIdx) {
        *iter = toAddTensorIdx;
        break;
      }
    }
    // insert() returns the position of the new node; step past it.
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
  } else {
    std::vector<std::unique_ptr<CNodeT>> toAddNodes;
    // NOTE(review): i is incremented both by the loop header and by the
    // std::to_string(i++) below, so every second producer is skipped — confirm intended.
    for (size_t i = 0; i < preNodeIdxes.size(); i++) {
      MS_ASSERT(graphT->nodes.size() > preNodeIdxes.at(i));
      auto &preTensor = graphT->allTensors.at(preTensorIdx);
      MS_ASSERT(preTensor != nullptr);
      auto toAddTensor = CopyTensorDefT(preTensor);
      if (toAddTensor == nullptr) {
        *errorCode = RET_NULL_PTR;
        MS_LOG(ERROR) << "Copy TensorT failed";
        return graphT->nodes.end();
      }
      if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
        preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
        toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
      }
      graphT->allTensors.emplace_back(std::move(toAddTensor));
      size_t toAddTensorIdx = graphT->allTensors.size() - 1;
      auto toAddNode = opDefCopyer(toAddNodeIn.get());
      if (toAddNode == nullptr) {
        MS_LOG(ERROR) << "copy toAddNodeIn failed";
        *errorCode = RET_NULL_PTR;
        return graphT->nodes.end();
      }
      // Give each copy a distinct name derived from the template node's name.
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
      toAddNode->inputIndex.clear();
      toAddNode->inputIndex.push_back(preTensorIdx);
      toAddNode->outputIndex.clear();
      toAddNode->outputIndex.push_back(toAddTensorIdx);
      // Re-point the existing node's first remaining reference to the original tensor.
      for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
        if (*iter == preTensorIdx) {
          *iter = toAddTensorIdx;
          break;
        }
      }
      toAddNodes.emplace_back(std::move(toAddNode));
    }
    // Insert the prepared nodes only after all tensor/index bookkeeping is done.
    for (auto &toAddNode : toAddNodes) {
      existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
      existNodeIter++;
    }
  }
  *errorCode = RET_OK;
  return existNodeIter;
}

// Inserts copies of toAddNodeIn (made with opDefCopyer) behind the existing
// node's output slot outputIndexIdx.
//
// When no node consumes that output tensor, a single node is inserted and the
// first matching entry of graphT->outputIndex is re-pointed at the new tensor.
// Otherwise one node is inserted per consumer, and each consumer's first
// reference to the output tensor is re-pointed at that node's new tensor.
//
// On success *errorCode is RET_OK and the returned iterator is the position
// following the last inserted node. On failure *errorCode is RET_NULL_PTR and
// graphT->nodes.end() is returned. errorCode is dereferenced unconditionally.
//
// NOTE(review): nodes.insert(existNodeIter, ...) places the new nodes at the
// existing node's position, i.e. ahead of it in graphT->nodes, although they
// consume its output — presumably node order is fixed up elsewhere; confirm.
NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
                         std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  // NOTE: existNode refers to the element behind existNodeIter; it is only used
  // before the nodes.insert() calls below.
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > postTensorIdx);

  // Nodes that consume the selected output tensor.
  auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode.get()), outputIndexIdx);
  if (postNodeIdxes.empty()) {
    // No consumer node: only graph-level outputs may reference the tensor.
    auto &postTensor = graphT->allTensors.at(postTensorIdx);
    MS_ASSERT(postTensor != nullptr);
    auto toAddTensor = CopyTensorDefT(postTensor);
    if (toAddTensor == nullptr) {
      MS_LOG(ERROR) << "Copy TensorT failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    // For a QuantDTypeCast node the existing tensor takes the cast's source
    // type and the new tensor its destination type.
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
      toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(toAddNodeIn.get());
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    // Wire the new node: existing output tensor in, copied tensor out.
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(postTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    // Re-point the first graph output that referenced the old tensor.
    for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
      if (*iter == postTensorIdx) {
        *iter = toAddTensorIdx;
        break;
      }
    }
    // insert() returns the position of the new node; step past it.
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
  } else {
    std::vector<std::unique_ptr<schema::CNodeT>> toAddNodes;
    // Running suffix used to give each inserted copy a distinct name.
    int i = 0;
    for (size_t postNodeIdx : postNodeIdxes) {
      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
      auto &postNode = graphT->nodes.at(postNodeIdx);
      MS_ASSERT(postNode != nullptr);
      auto &postTensor = graphT->allTensors.at(postTensorIdx);
      MS_ASSERT(postTensor != nullptr);
      auto toAddTensor = CopyTensorDefT(postTensor);
      if (toAddTensor == nullptr) {
        MS_LOG(ERROR) << "Copy TensorT failed";
        *errorCode = RET_NULL_PTR;
        return graphT->nodes.end();
      }
      if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
        postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
        toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
      }
      graphT->allTensors.emplace_back(std::move(toAddTensor));
      size_t toAddTensorIdx = graphT->allTensors.size() - 1;
      auto toAddNode = opDefCopyer(toAddNodeIn.get());
      if (toAddNode == nullptr) {
        MS_LOG(ERROR) << "copy toAddNodeIn failed";
        *errorCode = RET_NULL_PTR;
        return graphT->nodes.end();
      }
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
      toAddNode->inputIndex.clear();
      toAddNode->inputIndex.push_back(postTensorIdx);
      toAddNode->outputIndex.clear();
      toAddNode->outputIndex.push_back(toAddTensorIdx);
      MS_ASSERT(IsContain(postNode->inputIndex, postTensorIdx));
      // Re-point this consumer's first reference to the old tensor at the new one.
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == postTensorIdx) {
          *iter = toAddTensorIdx;
          break;
        }
      }
      toAddNodes.emplace_back(std::move(toAddNode));
    }
    // Insert the prepared nodes only after all tensor/index bookkeeping is done.
    for (auto &toAddNode : toAddNodes) {
      existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
      existNodeIter++;
    }
  }
  *errorCode = RET_OK;
  return existNodeIter;
}

// Checks that modelFile ends with fileType and is strictly longer than it
// (a name consisting only of the suffix is rejected).
// Returns RET_OK on a match, RET_ERROR otherwise.
STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) {
  if (modelFile.size() <= fileType.size()) {
    return RET_ERROR;
  }
  const size_t suffixPos = modelFile.size() - fileType.size();
  if (modelFile.compare(suffixPos, fileType.size(), fileType) != 0) {
    return RET_ERROR;
  }
  return RET_OK;
}

// Derives a model name from a file path: strips everything up to the last '/',
// strips the last '.' and what follows it, then appends a pseudo-random decimal
// suffix from rand().
//
// The generator is seeded once, on first use. Re-seeding with time() on every
// call made two calls within the same second return identical names and reset
// the process-wide rand() sequence each time.
std::string GetModelName(const std::string &modelFile) {
  std::string modelName = modelFile;
  // find_last_of() returns npos when absent; npos + 1 wraps to 0, keeping the whole string.
  modelName = modelName.substr(modelName.find_last_of('/') + 1);
  modelName = modelName.substr(0, modelName.find_last_of('.'));

  static const bool seeded = []() {
    srand(static_cast<unsigned>(time(nullptr)));
    return true;
  }();
  (void)seeded;
  modelName = modelName + std::to_string(rand());

  return modelName;
}
}  // namespace lite
}  // namespace mindspore


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/common/node_pass.h"

#include <unordered_set>
#include <deque>
#include <algorithm>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace opt {
// Applies the per-node Run(func_graph, node) to the nodes reachable from the
// graph's output, walking breadth-first through CNode inputs and through
// function graphs held in value nodes.
//
// A node counts as changed whenever the per-node Run returns non-null; if the
// returned node differs from the visited one it replaces it in the manager.
// Returns true when at least one node changed.
bool NodePass::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(func_graph);

  std::unordered_set<AnfNodePtr> seen_node;
  std::deque<AnfNodePtr> to_process{func_graph->output()};
  bool changes = false;
  while (!to_process.empty()) {
    AnfNodePtr node = to_process.front();
    to_process.pop_front();
    // Skip nodes already handled or no longer tracked by the manager.
    if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
      continue;
    }
    (void)seen_node.insert(node);
    AnfNodePtr new_node = Run(func_graph, node);
    const bool change = (new_node != nullptr);
    if (new_node != nullptr && new_node != node) {
      (void)manager->Replace(node, new_node);
      (void)seen_node.erase(node);
    } else if (new_node == nullptr) {
      new_node = node;
    }
    if (new_node && IsValueNode<FuncGraph>(new_node)) {
      // Descend into the referenced function graph.
      auto const_func_graph = GetValueNode<FuncGraphPtr>(new_node);
      MS_EXCEPTION_IF_NULL(const_func_graph);
      to_process.push_back(const_func_graph->output());
    } else if (new_node && new_node->isa<CNode>()) {
      if (IsGraphKernel(new_node)) {
        to_process.push_back(new_node);
      }
      auto cnode = new_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      // Bind by reference; the inputs are only read to enqueue them.
      const auto &inputs = cnode->inputs();
      (void)to_process.insert(to_process.end(), inputs.begin(), inputs.end());
    }
    // Log only for the node that actually changed; testing the cumulative flag
    // logged every node visited after the first change.
    if (change) {
      changes = true;
      MS_LOG(DEBUG) << "pass " << this->name() << "changed node:" << new_node->fullname_with_scope();
    }
  }
  return changes;
}
}  // namespace opt
}  // namespace mindspore


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include <memory>
#include <set>
#include <vector>
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "src/kernel_registry.h"
#include "include/context.h"
#include "src/populate_parameter.h"
#include "src/ops/primitive_c.h"

using mindspore::lite::KernelRegistry;
using mindspore::lite::PrimitiveC;
using mindspore::lite::tensor::Tensor;
namespace mindspore::opt {
namespace {
std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
  MS_ASSERT(CNode != nullptr);
  auto tmp_meta_graph = std::make_unique<schema::MetaGraphT>();
  auto tmp_fb_node = std::make_unique<schema::CNodeT>();
  lite::AnfExporter anfExporter;
  anfExporter.SetOpInputNode(CNode, tmp_meta_graph, tmp_fb_node.get());
  std::vector<Tensor *> input_tensors;
  for (auto input_index : tmp_fb_node->inputIndex) {
    auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
    auto tensor_shape = tensorT->dims;
    auto lite_tensor =
        new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
    if (lite_tensor == nullptr) {
      MS_LOG(ERROR) << "lite tensor is nullptr";
      return input_tensors;
    }
    auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
    // when tensorT as graph input
    if (lite_tensor_size <= 0) {
      delete lite_tensor;
      return input_tensors;
    }
    auto tensor_data = new (std::nothrow) uint8_t[lite_tensor_size / sizeof(char)];
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << "tensor_data is nullptr";
      delete lite_tensor;
      return input_tensors;
    }
    auto ret = memcpy_s(tensor_data, lite_tensor_size, tensorT->data.data(), lite_tensor_size);
    if (ret != EOK) {
      delete lite_tensor;
      delete[](tensor_data);
      MS_LOG(EXCEPTION) << "memcpy error: " << ret;
    }
    lite_tensor->SetData(tensor_data);
    input_tensors.emplace_back(lite_tensor);
  }
  return input_tensors;
}

// Adds a parameter to func_graph that mirrors `tensor`: the abstract and the
// default ParamValueLite carry the tensor's shape, data type and format, and,
// when the tensor has data, a freshly allocated copy of it.
// Returns nullptr when the data buffer cannot be allocated or copied.
// NOTE(review): the copy is sized as ElementsNum() * sizeof(float) regardless of
// the tensor's data type — presumably only float tensors reach this point; confirm.
ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
  auto new_param = func_graph->add_parameter();
  std::vector<int> dims(tensor->shape());
  auto dtype = static_cast<TypeId>(tensor->data_type());
  auto dtype_ptr = TypeIdToType(dtype);
  auto abstract = std::make_shared<abstract::AbstractTensor>(dtype_ptr, dims);
  new_param->set_abstract(abstract);

  ParamValueLitePtr value = std::make_shared<ParamValueLite>();
  MS_ASSERT(value != nullptr);
  value->set_tensor_shape(dims);
  value->set_tensor_type(dtype);
  value->set_format(tensor->GetFormat());

  // Without data only the metadata is attached.
  if (tensor->Data() == nullptr) {
    new_param->set_default_param(value);
    return new_param;
  }

  auto element_count = tensor->ElementsNum();
  auto buffer = new (std::nothrow) float[element_count];
  if (buffer == nullptr) {
    MS_LOG(ERROR) << "tensor_data is nullptr";
    return nullptr;
  }
  auto copy_status =
      memcpy_s(buffer, element_count * sizeof(float), tensor->Data(), element_count * sizeof(float));
  if (copy_status != EOK) {
    delete[] buffer;
    MS_LOG(ERROR) << "memcpy error: " << copy_status;
    return nullptr;
  }
  value->set_tensor_addr(buffer);
  value->set_tensor_size(element_count * sizeof(float) / sizeof(uint8_t));
  new_param->set_default_param(value);
  return new_param;
}
// Looks up the CPU kernel creator for the primitive's type and the data type of
// the first input tensor, and uses it to build a kernel.
// Returns nullptr when there is no usable first input, primitive is null, or no
// creator is registered for the key.
// NOTE(review): the lite::Context handed to the creator is a local of this
// function — confirm the created kernel does not keep the pointer.
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs, OpParameter *parameter,
                                  mindspore::lite::PrimitiveC *primitive) {
  // The assert previously named `lite_primitive`, which is not in scope here.
  MS_ASSERT(nullptr != primitive);
  // front() on an empty vector and the dereferences below are undefined; bail out instead.
  if (inputs.empty() || inputs.front() == nullptr || primitive == nullptr) {
    return nullptr;
  }
  auto data_type = inputs.front()->data_type();
  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
  lite::Context context;
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  if (creator != nullptr) {
    auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive);
    return lite_kernel;
  }
  return nullptr;
}

// Replaces the folded node input_node with constant parameters built from
// output_tensors.
//
// Exactly one output tensor: a new parameter named after input_node becomes
// input replace_index of any_node.
// Any other count: for each output k, the single user of that output must be a
// TupleGetItem node, which is replaced in the manager by a parameter named
// "<input_node>_const_<k>".
//
// Returns lite::RET_ERROR when a parameter cannot be created or an output's
// user list is missing, does not hold exactly one user, or that user is not a
// TupleGetItem; lite::RET_OK otherwise.
lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_node, const AnfNodePtr &input_node,
                          std::vector<Tensor *> output_tensors, size_t replace_index) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  if (output_tensors.size() != 1) {
    for (size_t k = 0; k < output_tensors.size(); k++) {
      auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k);
      // Guard against a missing list before dereferencing it.
      if (used_node_list == nullptr || used_node_list->size() != 1) {
        MS_LOG(ERROR) << " output must tuple_getitem";
        return lite::RET_ERROR;
      }
      auto tuple_node = used_node_list->at(0).first;
      if (GetCNodeType(tuple_node) != schema::PrimitiveType_TupleGetItem) {
        MS_LOG(ERROR) << " multi out tensor must connect tuple-getitem: " << input_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      auto new_parameter = CreateNewParamter(func_graph, output_tensors.at(k));
      if (new_parameter == nullptr) {
        MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      new_parameter->set_name(input_node->fullname_with_scope() + "_const_" + std::to_string(k));
      manager->Replace(tuple_node, new_parameter);
    }
  } else {
    auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
    if (new_parameter == nullptr) {
      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    new_parameter->set_name(input_node->fullname_with_scope());
    any_node->set_input(replace_index, new_parameter);
  }
  return lite::RET_OK;
}
}  //  namespace
// Deletes every tensor held in the given vectors and resets each slot to nullptr.
// Either argument may be nullptr, in which case that vector is skipped.
// The vectors themselves are not resized.
void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
  const auto release_all = [](std::vector<Tensor *> *tensors) {
    if (tensors == nullptr) {
      return;
    }
    for (auto &tensor : *tensors) {
      delete tensor;
      tensor = nullptr;
    }
  };
  release_all(input_tensor);
  release_all(output_tensor);
}

// Constant-folds the inputs of `node`.
//
// For every input of `node` (starting at index 1) that is itself a CNode passing
// CheckIsAllInputsParam, the input's kernel is created and run on its input tensors, and the
// input is then replaced by parameter node(s) holding the computed output (see ReplaceCNode).
//
// Returns the node (as CNode) when at least one input was folded. Returns nullptr when `node`
// is not a CNode, when nothing was folded, or when any step of a fold fails (in which case an
// error is logged and all temporaries created for that fold are released).
const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  CheckIfFuncGraphIsNull(func_graph);
  CheckIfAnfNodeIsNull(node);
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  auto any_node = node->cast<CNodePtr>();
  CheckIfCNodeIsNull(any_node);
  bool changed = false;
  for (size_t i = 1; i < any_node->inputs().size(); i++) {
    auto input_node = any_node->input(i);
    if (!input_node->isa<CNode>() || !CheckIsAllInputsParam(input_node)) {
      continue;
    }
    auto input_cnode = input_node->cast<CNodePtr>();
    auto input_tensors = GetCNodeInputTensors(input_cnode);
    // Skip this input unless a tensor was obtained for every one of its operands
    // (inputs()[0] is the primitive, hence the "- 1").
    if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
      FreeTensors(&input_tensors, nullptr);
      continue;
    }
    changed = true;
    auto output_nums = GetOutputTensorNum(input_cnode);
    // Allocate one distinct Tensor per output. Brace-initialising the vector with
    // {count, new Tensor()} would select the fill constructor and store the SAME pointer in
    // every slot, which FreeTensors would then delete more than once.
    std::vector<Tensor *> output_tensors;
    for (decltype(output_nums) j = 0; j < output_nums; ++j) {
      output_tensors.push_back(new Tensor());
    }
    auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
    if (lite_primitive == nullptr) {
      MS_LOG(ERROR) << "lite_primitive is nullptr";
      FreeTensors(&input_tensors, &output_tensors);
      return nullptr;
    }
    // Copy the primitive's quantization parameters onto the matching tensors. The extra size
    // check guards against a primitive that carries more quant-param groups than there are tensors.
    auto inputQuantParams = lite_primitive->GetInputQuantParams();
    for (size_t m = 0; m < inputQuantParams.size() && m < input_tensors.size(); m++) {
      for (const auto &inputQuantParam : inputQuantParams[m]) {
        lite::tensor::QuantArg quant_arg{};
        quant_arg.scale = inputQuantParam.scale;
        quant_arg.zeroPoint = inputQuantParam.zeroPoint;
        input_tensors[m]->AddQuantParam(quant_arg);
      }
    }
    auto outputQuantParams = lite_primitive->GetOutputQuantParams();
    for (size_t m = 0; m < outputQuantParams.size() && m < output_tensors.size(); m++) {
      for (const auto &outputQuantParam : outputQuantParams[m]) {
        lite::tensor::QuantArg quant_arg{};
        quant_arg.scale = outputQuantParam.scale;
        quant_arg.zeroPoint = outputQuantParam.zeroPoint;
        output_tensors[m]->AddQuantParam(quant_arg);
      }
    }
    // here, input_tensor's format need to be transposed nhwc according to fmkType,
    // but for the time being, we only transpose the tensor with 0/1/2/3D.
    // Others should be added in future.
    for (size_t j = 0; j < input_tensors.size(); ++j) {
      input_tensors[j]->SetFormat(schema::Format_NHWC);
      if (input_tensors[j]->shape().size() == 4) {
        MS_LOG(INFO) << "init input_tensor format to nhwc";
      }
    }
    // NOTE(review): the result of InferShape is not checked; confirm whether a failure here
    // should abort the fold.
    lite_primitive->InferShape(input_tensors, output_tensors);
    auto parameter = kernel::PopulateParameter(lite_primitive.get());
    if (parameter == nullptr) {
      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
                    << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(lite_primitive->Type()));
      // Release the temporaries; this path previously leaked them.
      FreeTensors(&input_tensors, &output_tensors);
      return nullptr;
    }
    auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
    if (lite_kernel == nullptr) {
      // NOTE(review): ownership of `parameter` on this path is not visible from here; confirm
      // that it is released elsewhere.
      MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
      FreeTensors(&input_tensors, &output_tensors);
      return nullptr;
    }
    auto ret = lite_kernel->Run();
    if (0 != ret) {
      MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
      FreeTensors(&input_tensors, &output_tensors);
      // The kernel was previously leaked on this path; release it as the other paths do.
      delete (lite_kernel);
      return nullptr;
    }
    // replace cnode by new param
    if (ReplaceCNode(func_graph, any_node, input_node, output_tensors, i) != lite::RET_OK) {
      FreeTensors(&input_tensors, &output_tensors);
      delete (lite_kernel);
      MS_LOG(ERROR) << "constant_folding replace cnode failed";
      return nullptr;
    }
    MS_LOG(DEBUG) << "fold node:" << input_node->fullname_with_scope() << " success ";
    FreeTensors(&input_tensors, &output_tensors);
    delete (lite_kernel);
  }
  return changed ? any_node : nullptr;
}
}  // namespace mindspore::opt


# Data processing