<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
#include "src/param_value_lite.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
// Pass that stamps a hard-coded weight format onto the weight parameter of
// convolution-family nodes. The format is selected from the configured source
// framework (fmk_type) and quantization type (quant_type).
class WeightFormatHardCodePass : public Pass {
 public:
  WeightFormatHardCodePass() : Pass("weight_format_hardcode_pass") {}
  ~WeightFormatHardCodePass() override = default;

  // Walks the graph and assigns a weight format to each matching node.
  bool Run(const FuncGraphPtr &graph) override;

  // Configuration; both values default to the member initializers below.
  void SetQuantType(QuantType type);
  void SetFmkType(FmkType fmkType);

 private:
  // One helper per supported source framework. Each sets the format on
  // `param_value` and returns a lite::STATUS describing the outcome.
  lite::STATUS HardCodeCAFFE(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
  lite::STATUS HardCodeONNX(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
  lite::STATUS HardCodeMS(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
  lite::STATUS HardCodeTFLITE(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;

  QuantType quant_type = schema::QuantType_QUANT_NONE;
  FmkType fmk_type = lite::converter::FmkType_TF;
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
namespace mindspore::opt {
namespace {
// Position of the weight input in a convolution CNode's input list.
constexpr size_t kConvWeightIndex = 2;

// Returns true when `n` is a CNode or ValueNode whose primitive type is
// Conv2D, DepthwiseConv2D, DeConv2D or DeDepthwiseConv2D; false otherwise.
bool IsConvExtendNode(const BaseRef &n) {
  const bool is_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_node) {
    return false;
  }
  switch (opt::GetCNodeType(n)) {
    case schema::PrimitiveType_Conv2D:
    case schema::PrimitiveType_DepthwiseConv2D:
    case schema::PrimitiveType_DeConv2D:
    case schema::PrimitiveType_DeDepthwiseConv2D:
      return true;
    default:
      return false;
  }
}
}  // namespace
// Records the quantization type consulted by the HardCode* helpers.
void WeightFormatHardCodePass::SetQuantType(QuantType type) { quant_type = type; }
// Records the source framework that Run() dispatches on.
void WeightFormatHardCodePass::SetFmkType(FmkType type) { fmk_type = type; }
// Sets the weight format for a convolution node that originates from Caffe.
//
// conv_node:   node owning the weight; used for diagnostics.
// param_value: weight parameter whose format is set.
// Returns lite::RET_OK on success. Returns lite::RET_ERROR when an argument is
// null or when quant_type is neither WeightQuant nor QUANT_NONE.
lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node,
                                                     const ParamValueLitePtr &param_value) const {
  // Explicit guards: both pointers are dereferenced below.
  // NOTE(review): MS_ASSERT is presumably compiled out in some builds - confirm.
  if (conv_node == nullptr) {
    MS_LOG(ERROR) << "conv_node is nullptr";
    return lite::RET_ERROR;
  }
  if (param_value == nullptr) {
    MS_LOG(ERROR) << "param_value is nullptr, node: " << conv_node->fullname_with_scope();
    return lite::RET_ERROR;
  }
  switch (quant_type) {
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE:
      param_value->set_format(schema::Format::Format_KCHW);
      break;
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Sets the weight format for a convolution node that originates from ONNX.
//
// conv_node:   node owning the weight; its primitive type selects the format.
// param_value: weight parameter whose format is set.
// Returns lite::RET_OK on success. Returns lite::RET_ERROR when an argument is
// null, or when the op type / quant type combination is not handled.
lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
                                                    const ParamValueLitePtr &param_value) const {
  // Explicit guards: both pointers are dereferenced below.
  // NOTE(review): MS_ASSERT is presumably compiled out in some builds - confirm.
  if (conv_node == nullptr) {
    MS_LOG(ERROR) << "conv_node is nullptr";
    return lite::RET_ERROR;
  }
  if (param_value == nullptr) {
    MS_LOG(ERROR) << "param_value is nullptr, node: " << conv_node->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto op_type = GetCNodeType(conv_node);
  switch (this->quant_type) {
    case QuantType_AwareTraining: {
      // sum up from current onnx quant models
      if (op_type == schema::PrimitiveType_Conv2D) {
        param_value->set_format(schema::Format::Format_KHWC);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_CHWK);
      } else if (op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      // conv (K x C/group x kH x kW) group = 1
      // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
      // deconv (C x K/group x kH x kW) group = 1
      // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
      if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_KCHW);
      } else if (op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Sets the weight format for a convolution node that originates from MindSpore.
//
// conv_node:   node owning the weight; its primitive type selects the format.
// param_value: weight parameter whose format is set.
// Returns lite::RET_OK on success. Returns lite::RET_ERROR when an argument is
// null, or when the op type / quant type combination is not handled.
lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
                                                  const ParamValueLitePtr &param_value) const {
  // Explicit guards: both pointers are dereferenced below.
  // NOTE(review): MS_ASSERT is presumably compiled out in some builds - confirm.
  if (conv_node == nullptr) {
    MS_LOG(ERROR) << "conv_node is nullptr";
    return lite::RET_ERROR;
  }
  if (param_value == nullptr) {
    MS_LOG(ERROR) << "param_value is nullptr, node: " << conv_node->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto op_type = GetCNodeType(conv_node);
  switch (this->quant_type) {
    case QuantType_AwareTraining: {
      // Depthwise conv gets CKHW; every other op type gets KCHW.
      if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else {
        param_value->set_format(schema::Format::Format_KCHW);
      }
      break;
    }
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      // sum up from current ms quant models
      if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_KCHW);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_CKHW);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Sets the weight format for a convolution node that originates from TFLite.
//
// conv_node:   node owning the weight; its primitive type selects the format.
// param_value: weight parameter whose format is set.
// Returns lite::RET_OK on success. Returns lite::RET_ERROR when an argument is
// null, or when the op type / quant type is not handled.
lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_node,
                                                      const ParamValueLitePtr &param_value) const {
  // Explicit guards: both pointers are dereferenced below.
  // NOTE(review): MS_ASSERT is presumably compiled out in some builds - confirm.
  if (conv_node == nullptr) {
    MS_LOG(ERROR) << "conv_node is nullptr";
    return lite::RET_ERROR;
  }
  if (param_value == nullptr) {
    MS_LOG(ERROR) << "param_value is nullptr, node: " << conv_node->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto op_type = GetCNodeType(conv_node);
  switch (this->quant_type) {
    case QuantType_AwareTraining:
    case QuantType_PostTraining:
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      if (op_type == schema::PrimitiveType_Conv2D) {
        param_value->set_format(schema::Format::Format_KHWC);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D || op_type == schema::PrimitiveType_DeConv2D) {
        param_value->set_format(schema::Format::Format_CHWK);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
                      << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
      break;
    }
    default: {
      // This branch is reached for an unhandled quant type, so report the
      // quant type rather than the op type.
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
                    << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
  MS_ASSERT(graph != nullptr);
  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNode>(node)) {
      continue;
    }
    auto conv_cnode = node->cast<CNodePtr>();
    auto type = opt::GetCNodeType(node);
    if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
        && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
      continue;
    }
    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
    auto param_value = GetLiteParamValue(conv_cnode->input(kConvWeightIndex));
    lite::STATUS status;
    switch (fmk_type) {
      case FmkType_CAFFE:status = HardCodeCAFFE(node, param_value);
        break;
      case FmkType_TFLITE:status = HardCodeTFLITE(node, param_value);
        break;
      case FmkType_ONNX:status = HardCodeONNX(node, param_value);
        break;
      case FmkType_MS:status = HardCodeMS(node, param_value);
        break;
      default:MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
        return false;
    }
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "schema::Format hardCode faild: " << status << ", node: " << node->fullname_with_scope();
      return false;
    }
  }
  return false;
}
}  // namespace mindspore::opt


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
// Pass that converts the weight tensors of convolution-family nodes to a
// destination filter format. When no destination is configured (dst_format is
// left at Format_NUM_OF_FORMAT) the implementation falls back to its default.
class WeightFormatTransformPass : public Pass {
 public:
  WeightFormatTransformPass() : Pass("weight_format_transform_pass") {}
  ~WeightFormatTransformPass() override = default;

  // Runs the weight conversion over the whole graph.
  bool Run(const FuncGraphPtr &graph) override;

  // Configuration; each value defaults to the member initializers below.
  void SetQuantType(QuantType type);
  void SetFmkType(FmkType fmkType);
  void SetDstFormat(schema::Format format);

 private:
  // Converts the weights of matching nodes; `quant_flag` narrows the set of
  // op types that are processed.
  lite::STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph, bool quant_flag);

  QuantType quant_type = schema::QuantType_QUANT_NONE;
  FmkType fmk_type = lite::converter::FmkType_TF;
  schema::Format dst_format = schema::Format::Format_NUM_OF_FORMAT;
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_TRANSFORM_PASS_H_


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
namespace mindspore::opt {
namespace {
// Position of the weight input in a convolution CNode's input list; used as
// conv_cnode->input(kConvWeightIndex) when fetching the weight node.
constexpr size_t kConvWeightIndex = 2;
}  // namespace
// Records the quantization type that Run() uses to derive its quant flag.
void WeightFormatTransformPass::SetQuantType(QuantType type) { quant_type = type; }
// Records the source framework type.
void WeightFormatTransformPass::SetFmkType(FmkType type) { fmk_type = type; }
// Records the destination weight format; Format_NUM_OF_FORMAT means "not set".
void WeightFormatTransformPass::SetDstFormat(schema::Format format) { dst_format = format; }
// Converts the weight tensor of every matching convolution-family CNode to the
// destination filter format and refreshes the weight node's abstract.
//
// graph:      graph whose nodes are visited in TopoSort order from the return node.
// quant_flag: when true only Conv2D and DepthwiseConv2D are processed; when
//             false DeConv2D and DeDepthwiseConv2D are processed as well.
// Returns lite::RET_OK when every matching node was converted. Returns
// lite::RET_ERROR when the graph is null, a node has no usable weight, or
// TransFilterFormat fails.
lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph, bool quant_flag) {
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr";
    return lite::RET_ERROR;
  }
  // Destination is KHWC unless a format was configured through SetDstFormat.
  const schema::Format weight_dst_format =
    (dst_format != schema::Format::Format_NUM_OF_FORMAT) ? dst_format : schema::Format::Format_KHWC;

  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto type = opt::GetCNodeType(node);
    const bool is_conv = type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D;
    const bool is_deconv = type == schema::PrimitiveType_DeConv2D || type == schema::PrimitiveType_DeDepthwiseConv2D;
    // Deconvolution variants are only handled on the non-quant path.
    if (!is_conv && (quant_flag || !is_deconv)) {
      continue;
    }
    auto conv_cnode = node->cast<CNodePtr>();
    // Checked explicitly rather than via MS_ASSERT so that indexing the weight
    // input below is always guarded.
    if (conv_cnode == nullptr || conv_cnode->inputs().size() <= kConvWeightIndex) {
      MS_LOG(ERROR) << "conv node has no weight input, node: " << node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    auto weight_node = conv_cnode->input(kConvWeightIndex);
    auto weight_value = GetLiteParamValue(weight_node);
    if (weight_value == nullptr) {
      MS_LOG(ERROR) << "weight param value is nullptr, node: " << node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32
                  || weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
    const lite::STATUS status = TransFilterFormat(weight_value, weight_dst_format);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To"
                    << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope()
                    << "quant type:" << quant_type;
      // Return the lite status code, matching the other error paths in this file.
      return lite::RET_ERROR;
    }
    weight_value->set_format(weight_dst_format);
    // Rebuild the abstract from the weight's current type and shape.
    auto type_id = static_cast<TypeId>(weight_value->tensor_type());
    auto type_ptr = TypeIdToType(type_id);
    auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, weight_value->tensor_shape());
    weight_node->set_abstract(abstract_tensor);
  }
  return lite::RET_OK;
}

// Runs the convolution weight format conversion over `func_graph`.
//
// The quant flag passed to ConvWeightFormatTrans is set when quant_type is
// AwareTraining or WeightQuant.
// Returns false on every path, including success.
// NOTE(review): success and failure are therefore indistinguishable to the
// caller - confirm against the Pass::Run contract before changing.
bool WeightFormatTransformPass::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  const bool quant_flag =
    this->quant_type == QuantType_AwareTraining || this->quant_type == QuantType_WeightQuant;
  auto status = ConvWeightFormatTrans(func_graph, quant_flag);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status;
    // Return false explicitly: converting the non-zero status to bool would
    // yield true on failure while success yields false.
    return false;
  }
  return false;
}
}  // namespace mindspore::opt
