<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include <memory>
#include "src/ops/conv2d.h"
#include "src/ops/depthwise_conv2d.h"
#include "src/ops/deconv2d.h"
#include "src/ops/add.h"
#include "src/ops/primitive_c.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"

namespace mindspore::opt {
namespace {
// Input count expected on the Add/BiasAdd node; passed to CheckInputSize() in Process().
constexpr size_t kAddInputsLength = 3;
// Input index on the Add/BiasAdd node that GenConvNewBias() reads as the constant addend.
constexpr size_t kAddWEIGHTINDEX = 2;
// Input index of the weight on the convolution node.
constexpr size_t kConvWeightIndex = 2;
// Input index of the bias on the convolution node (read only when it has kConvWithBiasLen inputs).
constexpr size_t kConvBiasIndex = 3;
// inputs().size() of a convolution node that has no bias input.
constexpr size_t kConvNoBiasLen = 3;
// inputs().size() of a convolution node that already has a bias input.
constexpr size_t kConvWithBiasLen = 4;
// Pattern predicate: true when `n` is a CNode or ValueNode and GetCNodeType() reports
// Conv2D, DepthwiseConv2D or DeConv2D for it.
bool IsConvExtendNode(const BaseRef &n) {
  const bool is_graph_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_graph_node) {
    return false;
  }
  const auto prim_type = opt::GetCNodeType(n);
  if (prim_type == schema::PrimitiveType_Conv2D) {
    return true;
  }
  if (prim_type == schema::PrimitiveType_DepthwiseConv2D) {
    return true;
  }
  return prim_type == schema::PrimitiveType_DeConv2D;
}
// Pattern predicate: true when `n` is a CNode or ValueNode and GetCNodeType() reports
// Add or BiasAdd for it.
bool IsAddNode(const BaseRef &n) {
  const bool is_graph_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_graph_node) {
    return false;
  }
  const auto prim_type = opt::GetCNodeType(n);
  if (prim_type == schema::PrimitiveType_Add) {
    return true;
  }
  return prim_type == schema::PrimitiveType_BiasAdd;
}

int Get_Kenrnel_nums(const CNodePtr &conv_node) {
  MS_ASSERT(conv_node != nullptr);
  auto value_primitive = conv_node->input(0);
  auto value_node = value_primitive->cast<ValueNodePtr>();
  MS_ASSERT(value_node != nullptr);
  auto value = value_node->value();
  MS_ASSERT(value != nullptr);
  auto primitive = value->cast<PrimitiveCPtr>();
  MS_ASSERT(primitive != nullptr);
  auto type = (schema::PrimitiveType) primitive->Type();
  if (type == schema::PrimitiveType_Conv2D) {
    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive));
    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive);
    MS_ASSERT(primc != nullptr);
    return primc->GetChannelOut();
  } else if (type == schema::PrimitiveType_DepthwiseConv2D) {
    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive));
    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive);
    MS_ASSERT(primc != nullptr);
    return primc->GetChannelMultiplier() * primc->GetChannelIn();
  } else if (type == schema::PrimitiveType_DeConv2D) {
    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive));
    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive);
    MS_ASSERT(primc != nullptr);
    return primc->GetChannelOut();
  } else {
    MS_LOG(ERROR) << "Unsupported opType, " << type;
    return 0;
  }
}
// Folds the constant operand of `bias_node` (an Add/BiasAdd) into the bias of `conv_node`.
//
// The addend is expanded to one float per kernel (a scalar is broadcast, anything else is
// copied element-wise). If the convolution already has a bias input, the addend is accumulated
// into that tensor in place; otherwise a new bias parameter is created through AddNewBiasNode()
// and appended to the convolution inputs.
//
// func_graph: graph that receives the new bias parameter when one has to be created.
// conv_node:  convolution node with 3 (no bias) or 4 (with bias) inputs.
// bias_node:  Add/BiasAdd node whose input kAddWEIGHTINDEX is the constant addend.
//
// Raises via MS_LOG(EXCEPTION) on malformed inputs (unexpected input count, non-positive kernel
// count, missing or undersized tensors). Returns silently after logging if the scratch buffer
// cannot be allocated.
void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
  AnfNodePtr conv_bias_node = nullptr;
  AnfNodePtr conv_weight_node = nullptr;
  if (conv_node->inputs().size() == kConvNoBiasLen) {
    conv_weight_node = conv_node->input(kConvWeightIndex);
  } else if (conv_node->inputs().size() == kConvWithBiasLen) {
    conv_weight_node = conv_node->input(kConvWeightIndex);
    conv_bias_node = conv_node->input(kConvBiasIndex);
  } else {
    MS_LOG(EXCEPTION) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4";
  }
  auto kernel_nums = Get_Kenrnel_nums(conv_node);
  if (kernel_nums <= 0) {
    MS_LOG(EXCEPTION) << "kernel num less than 0";
  }
  // Owned by a unique_ptr so that every exception path below releases the buffer.
  std::unique_ptr<float[]> add_bias_data(new (std::nothrow) float[kernel_nums]);
  if (add_bias_data == nullptr) {
    MS_LOG(ERROR) << "tensor_data is nullptr";
    return;
  }

  auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
  CheckIfNodeIsParam(bias_add_weight);
  if (bias_add_weight == nullptr) {
    MS_LOG(EXCEPTION) << "add weight node is nullptr";
  }
  auto add_weight_param_node = bias_add_weight->cast<ParameterPtr>();
  if (add_weight_param_node == nullptr) {
    MS_LOG(EXCEPTION) << "add weight node is not a parameter";
  }
  auto add_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(add_weight_param_node->default_param());
  if (add_weight_tensor == nullptr) {
    MS_LOG(EXCEPTION) << "add weight tensor is nullptr";
  }
  auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->tensor_addr());
  if (add_weight_data == nullptr) {
    MS_LOG(EXCEPTION) << "add weight data is nullptr";
  }
  auto add_weight_shape = add_weight_tensor->tensor_shape();
  if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
    // Scalar addend: broadcast it to every kernel.
    for (int i = 0; i < kernel_nums; i++) {
      add_bias_data[i] = *add_weight_data;
    }
  } else {
    // Refuse to read past the end of the addend tensor.
    long long element_count = 1;
    for (const auto dim : add_weight_shape) {
      element_count *= dim;
    }
    if (element_count < kernel_nums) {
      MS_LOG(EXCEPTION) << "add weight has " << element_count << " elements, expected at least " << kernel_nums;
    }
    if (EOK != memcpy_s(add_bias_data.get(), kernel_nums * sizeof(float), add_weight_data,
                        kernel_nums * sizeof(float))) {
      MS_LOG(EXCEPTION) << "memcpy_s add_bias_data failed";
    }
  }

  if (conv_bias_node != nullptr) {
    CheckIfNodeIsParam(conv_bias_node);
    auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>();
    if (conv_bias_param_node == nullptr) {
      MS_LOG(EXCEPTION) << "conv bias node is not a parameter";
    }
    auto conv_bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_bias_param_node->default_param());
    if (conv_bias_tensor == nullptr) {
      MS_LOG(EXCEPTION) << "conv bias tensor is nullptr";
    }
    if (conv_bias_tensor->tensor_shape().empty() || conv_bias_tensor->tensor_shape()[0] != kernel_nums) {
      MS_LOG(EXCEPTION) << "conv_bias_node shape error";
    }
    auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->tensor_addr());
    if (conv_bias_data == nullptr) {
      MS_LOG(EXCEPTION) << "conv bias data is nullptr";
    }
    // Accumulate in place; the scratch buffer is freed when add_bias_data goes out of scope.
    for (int i = 0; i < kernel_nums; i++) {
      conv_bias_data[i] += add_bias_data[i];
    }
  } else {
    if (conv_weight_node == nullptr) {
      MS_LOG(EXCEPTION) << "conv weight node is nullptr";
    }
    auto conv_weight_param_node = conv_weight_node->cast<ParameterPtr>();
    if (conv_weight_param_node == nullptr) {
      MS_LOG(EXCEPTION) << "conv weight node is not a parameter";
    }
    auto conv_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param_node->default_param());
    // NOTE(review): the buffer is handed over with release() because the previous code never
    // freed it on this path, so AddNewBiasNode() presumably takes ownership - confirm.
    auto conv_new_bias = AddNewBiasNode(add_bias_data.release(), func_graph, kernel_nums, conv_weight_tensor);
    conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias");
    conv_node->add_input(conv_new_bias);
  }
}
}  // namespace
// Builds the pattern matched by this pass: a node accepted by IsAddNode whose operands are a
// node accepted by IsConvExtendNode followed by a node accepted by IsParamNode.
const BaseRef ConvBiasaddFusion::DefinePattern() const {
  auto conv_cond = std::make_shared<CondVar>(IsConvExtendNode);
  auto add_cond = std::make_shared<CondVar>(IsAddNode);
  auto param_cond = std::make_shared<CondVar>(IsParamNode);
  VectorRef pattern({add_cond, conv_cond, param_cond});
  return pattern;
}

// Handles one match of the pattern built by DefinePattern().
//
// `node` is the matched Add/BiasAdd. Its constant operand is merged into the bias of the
// convolution feeding input 1 (see GenConvNewBias), the convolution primitive is flagged as
// having a bias, and the convolution node is returned.
//
// Returns:
//   - the Add node itself when it is an Add whose activation type is not NO_ACTIVATION;
//   - nullptr when IsMultiOutputTensors() is true for the convolution, or when the
//     convolution primitive type is not Conv2D / DepthwiseConv2D / DeConv2D;
//   - the convolution node otherwise.
const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
  MS_LOG(DEBUG) << "Enter pass process";
  CheckIfFuncGraphIsNull(func_graph);

  CheckIfAnfNodeIsNull(node);
  auto add_cnode = node->cast<CNodePtr>();
  CheckIfCNodeIsNull(add_cnode);
  CheckInputSize(add_cnode, kAddInputsLength);

  const bool is_plain_add = GetCNodeType(add_cnode) == schema::PrimitiveType_Add;
  if (is_plain_add) {
    auto add_prim_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(add_cnode->input(0));
    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Add>>(add_prim_c));
    auto add_prim = utils::cast<std::shared_ptr<mindspore::lite::Add>>(add_prim_c);
    MS_ASSERT(add_prim != nullptr);
    const bool has_activation = add_prim->GetActivationType() != schema::ActivationType_NO_ACTIVATION;
    if (has_activation) {
      return add_cnode;
    }
  }

  AnfNodePtr conv_anf = add_cnode->input(1);
  CheckIfAnfNodeIsNull(conv_anf);
  if (IsMultiOutputTensors(func_graph, conv_anf)) {
    return nullptr;
  }
  auto conv_cnode = conv_anf->cast<CNodePtr>();
  CheckIfCNodeIsNull(conv_cnode);
  GenConvNewBias(func_graph, conv_cnode, add_cnode);

  // Mark the convolution primitive as carrying a bias.
  auto conv_prim_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_cnode->input(0));
  MS_ASSERT(conv_prim_c != nullptr);
  auto type = conv_prim_c->Type();
  if (type == schema::PrimitiveType_Conv2D) {
    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(conv_prim_c));
    auto conv = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(conv_prim_c);
    MS_ASSERT(conv != nullptr);
    conv->SetHasBias(true);
    return conv_cnode;
  }
  if (type == schema::PrimitiveType_DepthwiseConv2D) {
    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(conv_prim_c));
    auto depthwise = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(conv_prim_c);
    MS_ASSERT(depthwise != nullptr);
    depthwise->SetHasBias(true);
    return conv_cnode;
  }
  if (type == schema::PrimitiveType_DeConv2D) {
    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(conv_prim_c));
    auto deconv = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(conv_prim_c);
    MS_ASSERT(deconv != nullptr);
    deconv->SetHasBias(true);
    return conv_cnode;
  }

  MS_LOG(ERROR) << "Unsupported opType, " << type;
  return nullptr;
}
}  // namespace mindspore::opt


In [None]:
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/log.h"
#include <memory>

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Initialises this primitive as a Log operator.
//
// Allocates `primitive_` if it does not exist yet, verifies that it is tagged as
// PrimitiveType_Log and installs a fresh, default-constructed schema::LogT as its value.
// `prim` and `inputs` are not read: Log has no attributes to unpack.
//
// Returns RET_OK on success, RET_ERROR on allocation failure or when `primitive_` is tagged
// with a different primitive type.
int Log::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Log;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Log) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  // Non-throwing allocation so that the failure branch below is reachable and the function
  // reports RET_ERROR the same way it does for primitive_.
  auto *attr = new (std::nothrow) schema::LogT();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new primitiveT value failed";
    return RET_ERROR;
  }
  // The type tag was verified above, so any value already held is a LogT; free it before
  // replacing it (deleting a null pointer is a no-op).
  delete this->primitive_->value.AsLog();
  this->primitive_->value.value = attr;
  return RET_OK;
}
#else
// Writes a Log primitive into `fbb` and finishes the buffer with it as the root.
// `primitive` is only asserted to be non-null; none of its fields are read.
// Always returns RET_OK.
int Log::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  const auto log_offset = schema::CreateLog(*fbb);
  const auto root_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Log, log_offset.o);
  fbb->Finish(root_offset);
  return RET_OK;
}
#endif
}  // namespace lite
}  // namespace mindspore


In [None]:
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/elu.h"
#include <memory>

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float Elu::GetAlpha() const { return this->primitive_->value.AsElu()->alpha; }

void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha; }

// Initialises this primitive as an Elu operator.
//
// Allocates `primitive_` if it does not exist yet, verifies that it is tagged as
// PrimitiveType_Elu and installs a fresh, default-constructed schema::EluT as its value.
// `prim` and `inputs` are not read here.
//
// Returns RET_OK on success, RET_ERROR on allocation failure or when `primitive_` is tagged
// with a different primitive type.
int Elu::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Elu;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Elu) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  // Non-throwing allocation so that the failure branch below is reachable and the function
  // reports RET_ERROR the same way it does for primitive_.
  auto *attr = new (std::nothrow) schema::EluT();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new primitiveT value failed";
    return RET_ERROR;
  }
  // The type tag was verified above, so any value already held is an EluT; free it before
  // replacing it (deleting a null pointer is a no-op).
  delete this->primitive_->value.AsElu();
  this->primitive_->value.value = attr;
  return RET_OK;
}
#else
// Copies the Elu attribute (alpha) of `primitive` into `fbb` and finishes the buffer with the
// resulting primitive as the root.
// Returns RET_ERROR when `primitive` does not hold an Elu value, RET_OK otherwise.
int Elu::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  const auto *elu_attr = primitive->value_as_Elu();
  if (elu_attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Elu return nullptr";
    return RET_ERROR;
  }
  const auto elu_offset = schema::CreateElu(*fbb, elu_attr->alpha());
  const auto root_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Elu, elu_offset.o);
  fbb->Finish(root_offset);
  return RET_OK;
}
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }

#endif
}  // namespace lite
}  // namespace mindspore


In [None]:
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/deconv2d.h"
#include <memory>
#include <string>

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int DeConv2D::GetFormat() const { return this->primitive_->value.AsDeConv2D()->format; }
int DeConv2D::GetGroup() const { return this->primitive_->value.AsDeConv2D()->group; }
int DeConv2D::GetChannelIn() const { return this->primitive_->value.AsDeConv2D()->channelIn; }
int DeConv2D::GetChannelOut() const { return this->primitive_->value.AsDeConv2D()->channelOut; }
int DeConv2D::GetKernelW() const { return this->primitive_->value.AsDeConv2D()->kernelW; }
int DeConv2D::GetKernelH() const { return this->primitive_->value.AsDeConv2D()->kernelH; }
int DeConv2D::GetStrideW() const { return this->primitive_->value.AsDeConv2D()->strideW; }
int DeConv2D::GetStrideH() const { return this->primitive_->value.AsDeConv2D()->strideH; }
int DeConv2D::GetPadMode() const { return this->primitive_->value.AsDeConv2D()->padMode; }
int DeConv2D::GetPadUp() const { return this->primitive_->value.AsDeConv2D()->padUp; }
int DeConv2D::GetPadDown() const { return this->primitive_->value.AsDeConv2D()->padDown; }
int DeConv2D::GetPadLeft() const { return this->primitive_->value.AsDeConv2D()->padLeft; }
int DeConv2D::GetPadRight() const { return this->primitive_->value.AsDeConv2D()->padRight; }
int DeConv2D::GetDilateW() const { return this->primitive_->value.AsDeConv2D()->dilateW; }
int DeConv2D::GetDilateH() const { return this->primitive_->value.AsDeConv2D()->dilateH; }
bool DeConv2D::GetHasBias() const { return this->primitive_->value.AsDeConv2D()->hasBias; }
int DeConv2D::GetActivationType() const { return this->primitive_->value.AsDeConv2D()->activationType; }

void DeConv2D::SetFormat(int format) { this->primitive_->value.AsDeConv2D()->format = (schema::Format)format; }
void DeConv2D::SetGroup(int group) { this->primitive_->value.AsDeConv2D()->group = group; }
void DeConv2D::SetChannelIn(int channel_in) { this->primitive_->value.AsDeConv2D()->channelIn = channel_in; }
void DeConv2D::SetChannelOut(int channel_out) { this->primitive_->value.AsDeConv2D()->channelOut = channel_out; }
void DeConv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsDeConv2D()->kernelW = kernel_w; }
void DeConv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsDeConv2D()->kernelH = kernel_h; }
void DeConv2D::SetStrideW(int stride_w) { this->primitive_->value.AsDeConv2D()->strideW = stride_w; }
void DeConv2D::SetStrideH(int stride_h) { this->primitive_->value.AsDeConv2D()->strideH = stride_h; }
void DeConv2D::SetPadMode(int pad_mode) { this->primitive_->value.AsDeConv2D()->padMode = (schema::PadMode)pad_mode; }
void DeConv2D::SetPadUp(int pad_up) { this->primitive_->value.AsDeConv2D()->padUp = pad_up; }
void DeConv2D::SetPadDown(int pad_down) { this->primitive_->value.AsDeConv2D()->padDown = pad_down; }
void DeConv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsDeConv2D()->padLeft = pad_left; }
void DeConv2D::SetPadRight(int pad_right) { this->primitive_->value.AsDeConv2D()->padRight = pad_right; }
void DeConv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsDeConv2D()->dilateW = dilate_w; }
void DeConv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsDeConv2D()->dilateH = dilate_h; }
void DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()->hasBias = has_bias; }
void DeConv2D::SetActivationType(int activation_type) {
  this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type;
}
void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
  auto attr = std::make_unique<schema::DeConv2DT>();
  attr->group = group;
  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
  if (format == "NCHW") {
    attr->format = schema::Format_NCHW;
  } else if (format == "NHWC") {
    attr->format = schema::Format_NHWC;
  } else {
    attr->format = schema::Format_NUM_OF_FORMAT;
  }
  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
  attr->padUp = pad_list[0];
  attr->padDown = pad_list[1];
  attr->padLeft = pad_list[2];
  attr->padRight = pad_list[3];

  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
  attr->dilateH = dilation[0];
  attr->dilateW = dilation[1];

  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
  attr->kernelH = kernel_size[0];
  attr->kernelW = kernel_size[1];

  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
  attr->strideH = stride[0];
  attr->strideW = stride[1];

  attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));

  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
  if (pad_mode == "valid" || pad_mode == "VALID") {
    attr->padMode = schema::PadMode_VALID;
  } else if (pad_mode == "same" || pad_mode == "SAME") {
    attr->padMode = schema::PadMode_SAME;
  } else {
    attr->padMode = schema::PadMode_NOTSET;
  }

  if (prim.GetAttr("activation_name") != nullptr) {
    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
    attr->activationType = kActivationTypeMap[activate_name];
  } else {
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
  }

  //  attr->padMode = schema::PadMode_SAME;
  //  attr->activationType = schema::ActivationType_RELU;
  primitive->value.type = schema::PrimitiveType_DeConv2D;
  primitive->value.value = attr.release();
}

int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_DeConv2D;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_DeConv2D) {
    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
    return RET_ERROR;
  }
  int group = GetValue<int>(prim.GetAttr("group"));
  if (group == 1) {
    PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
  }

  if (GetQuantType() == schema::QuantType_AwareTraining) {
    std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
    std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
    SetInputQuantParam(vecInputQuantParam);
    SetOutputQuantParam(vecOutputQuantParam);
  }
  return RET_OK;
}
#else
// Copies every DeConv2D attribute of `primitive` into `fbb` and finishes the buffer with the
// resulting primitive as the root.
// Returns RET_ERROR when `primitive` does not hold a DeConv2D value, RET_OK otherwise.
int DeConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  const auto *deconv_attr = primitive->value_as_DeConv2D();
  if (deconv_attr == nullptr) {
    MS_LOG(ERROR) << "value_as_DeConv2D return nullptr";
    return RET_ERROR;
  }
  const auto deconv_offset = schema::CreateDeConv2D(
    *fbb, deconv_attr->format(), deconv_attr->group(), deconv_attr->channelIn(), deconv_attr->channelOut(),
    deconv_attr->kernelW(), deconv_attr->kernelH(), deconv_attr->strideW(), deconv_attr->strideH(),
    deconv_attr->padMode(), deconv_attr->padUp(), deconv_attr->padDown(), deconv_attr->padLeft(),
    deconv_attr->padRight(), deconv_attr->dilateW(), deconv_attr->dilateH(), deconv_attr->hasBias(),
    deconv_attr->activationType());
  const auto root_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DeConv2D, deconv_offset.o);
  fbb->Finish(root_offset);
  return RET_OK;
}
int DeConv2D::GetFormat() const { return this->primitive_->value_as_DeConv2D()->format(); }
int DeConv2D::GetGroup() const { return this->primitive_->value_as_DeConv2D()->group(); }
int DeConv2D::GetChannelIn() const { return this->primitive_->value_as_DeConv2D()->channelIn(); }
int DeConv2D::GetChannelOut() const { return this->primitive_->value_as_DeConv2D()->channelOut(); }
int DeConv2D::GetKernelW() const { return this->primitive_->value_as_DeConv2D()->kernelW(); }
int DeConv2D::GetKernelH() const { return this->primitive_->value_as_DeConv2D()->kernelH(); }
int DeConv2D::GetStrideW() const { return this->primitive_->value_as_DeConv2D()->strideW(); }
int DeConv2D::GetStrideH() const { return this->primitive_->value_as_DeConv2D()->strideH(); }
int DeConv2D::GetPadMode() const { return this->primitive_->value_as_DeConv2D()->padMode(); }
int DeConv2D::GetPadUp() const { return this->primitive_->value_as_DeConv2D()->padUp(); }
int DeConv2D::GetPadDown() const { return this->primitive_->value_as_DeConv2D()->padDown(); }
int DeConv2D::GetPadLeft() const { return this->primitive_->value_as_DeConv2D()->padLeft(); }
int DeConv2D::GetPadRight() const { return this->primitive_->value_as_DeConv2D()->padRight(); }
int DeConv2D::GetDilateW() const { return this->primitive_->value_as_DeConv2D()->dilateW(); }
int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()->dilateH(); }
bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); }
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }

#endif
// Infers the output tensor of a deconvolution.
//
// inputs_:  at least two tensors; [0] is the feature map, [1] is the weight.
// outputs_: at least one tensor; [0] receives the format and data type of the input and,
//           when GetInferFlag() is set, the inferred shape {batch, height, width, channel}
//           with channel taken from weight->Channel().
//
// Output height/width per pad mode:
//   CAFFE: (in - 1) * stride + ((kernel - 1) * dilate + 1) - pad_begin - pad_end
//   SAME:  in * stride
//   VALID: (in - 1) * stride + kernel
// The cached pads pad_u_ / pad_l_ are recomputed for SAME and zeroed for VALID.
//
// Returns RET_OK on success (including the early exit when GetInferFlag() is false) and
// RET_ERROR when a tensor is missing or null or the pad mode is not one of the three above;
// in the pad-mode case the output shape is left unset.
int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  if (inputs_.size() < 2 || outputs_.empty()) {
    MS_LOG(ERROR) << "deconv needs at least 2 inputs and 1 output, got " << inputs_.size() << " and "
                  << outputs_.size();
    return RET_ERROR;
  }
  auto input = inputs_[0];
  auto weight = inputs_[1];
  auto output = outputs_[0];
  if (input == nullptr || weight == nullptr || output == nullptr) {
    MS_LOG(ERROR) << "deconv input, weight or output tensor is nullptr";
    return RET_ERROR;
  }
  output->SetFormat(input->GetFormat());
  output->set_data_type(input->data_type());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  int32_t input_h = input->Height();
  int32_t input_w = input->Width();

  int32_t output_n = input->Batch();
  int32_t output_h = 0;
  int32_t output_w = 0;
  int32_t output_c = weight->Channel();

  int kernel_w = GetKernelW();
  int kernel_h = GetKernelH();
  int stride_w = GetStrideW();
  int stride_h = GetStrideH();
  int dilate_w = GetDilateW();
  int dilate_h = GetDilateH();
  pad_l_ = GetPadLeft();
  pad_u_ = GetPadUp();
  pad_d_ = GetPadDown();
  pad_r_ = GetPadRight();
  auto pad_mode = static_cast<schema::PadMode>(GetPadMode());
  if (pad_mode == schema::PadMode_CAFFE) {
    // Explicit padding: the pads read from the attributes above are used as they are.
    output_h = (input_h - 1) * stride_h + ((kernel_h - 1) * dilate_h + 1) - pad_u_ - pad_d_;
    output_w = (input_w - 1) * stride_w + ((kernel_w - 1) * dilate_w + 1) - pad_l_ - pad_r_;
  } else if (pad_mode == schema::PadMode_SAME) {
    output_h = input_h * stride_h;
    output_w = input_w * stride_w;
    pad_u_ = ((input_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - output_h) / 2;
    pad_l_ = ((input_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - output_w) / 2;
  } else if (pad_mode == schema::PadMode_VALID) {
    output_h = (input_h - 1) * stride_h + kernel_h;
    output_w = (input_w - 1) * stride_w + kernel_w;
    pad_u_ = 0;
    pad_l_ = 0;
  } else {
    // Fail instead of publishing a zero-sized shape and reporting success.
    MS_LOG(ERROR) << "unsupported pad mode for deconv";
    return RET_ERROR;
  }
  std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
  output->set_shape(out_shape);

  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore


# Data processing