<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"

namespace mindspore::opt {
namespace {
// Pattern predicate: true when `n` is a CNodePtr or ValueNodePtr and
// opt::GetCNodeType reports schema::PrimitiveType_Stack for it.
bool IsStackNode(const BaseRef &n) {
  const bool is_candidate = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_candidate) {
    return false;
  }
  return opt::GetCNodeType(n) == schema::PrimitiveType_Stack;
}
// Pattern predicate: true when `n` is a CNodePtr or ValueNodePtr and
// opt::GetCNodeType reports schema::PrimitiveType_FullConnection for it.
bool IsFullConnectNode(const BaseRef &n) {
  const bool is_candidate = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_candidate) {
    return false;
  }
  return opt::GetCNodeType(n) == schema::PrimitiveType_FullConnection;
}
void *GetInputAddr(const AnfNodePtr &node, int input_index) {
  MS_ASSERT(node != nullptr);
  if (!node->isa<CNode>()) {
    MS_LOG(ERROR) << "GetInputAddr not cnode";
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (input_index >= cnode->inputs().size()) {
    MS_LOG(ERROR) << "input index error";
    return nullptr;
  }
  if (cnode->input(input_index)->isa<Parameter>()) {
    auto param_input = cnode->input(input_index)->cast<ParameterPtr>();
    auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param_input->default_param());;
    if (param_value == nullptr) {
      MS_LOG(ERROR) << "param not paramValueLite";
      return nullptr;
    }
    return param_value->tensor_addr();
  }
  MS_LOG(ERROR) << "input not paramter";
  return nullptr;
}
// Fills `rmatmul_input` with a constant tensor made by concatenating, in input order, the
// weight (input 2) of every node feeding `stack_node`.
//
// stack_node:    node whose inputs 1..N are read; input 1 must be a CNode whose input 2 is a
//                Parameter carrying a ParamValueLite.
// rmatmul_input: parameter that receives the abstract, name and default param.
//
// The resulting shape is the first weight's shape with N prepended; type and format are taken
// from the first weight. Returns RET_OK on success and RET_ERROR on any failure.
//
// NOTE(review): every weight is copied using the byte size of the FIRST weight; the sizes of
// the remaining weights are not verified here - confirm all stacked weights share one shape.
STATUS GetRightMatmulInputParamter(CNodePtr &stack_node, ParameterPtr &rmatmul_input) {
  if (stack_node == nullptr || rmatmul_input == nullptr) {
    MS_LOG(ERROR) << "stack node or right matmul parameter is nullptr";
    return RET_ERROR;
  }
  // inputs()[0] is skipped below, so at least one more input is needed (also prevents size()-1 underflow).
  if (stack_node->inputs().size() < 2) {
    MS_LOG(ERROR) << "stack node has no input to joint";
    return RET_ERROR;
  }
  const size_t joint_fullconnect_size = stack_node->inputs().size() - 1;

  auto first_input = stack_node->input(1);
  if (first_input == nullptr) {
    MS_LOG(ERROR) << "stack node first input is nullptr";
    return RET_ERROR;
  }
  auto fc = first_input->cast<CNodePtr>();
  if (fc == nullptr || fc->inputs().size() < 3 || fc->input(2) == nullptr) {
    MS_LOG(ERROR) << "stack node first input is not a cnode with a weight input";
    return RET_ERROR;
  }
  auto fc_weight = fc->input(2)->cast<ParameterPtr>();
  if (fc_weight == nullptr) {
    MS_LOG(ERROR) << "fullconnect weight is not a parameter";
    return RET_ERROR;
  }
  auto fc_weight_param = std::dynamic_pointer_cast<ParamValueLite>(fc_weight->default_param());
  if (fc_weight_param == nullptr) {
    MS_LOG(ERROR) << "fullconnect weight param is not paramValueLite";
    return RET_ERROR;
  }

  const auto tensor_size = fc_weight_param->tensor_size();
  auto rmatmul_input_shape = fc_weight_param->tensor_shape();

  // Owned by a unique_ptr so that every early return below frees the buffer.
  std::unique_ptr<int8_t[]> new_tensor_data(new (std::nothrow) int8_t[joint_fullconnect_size * tensor_size]);
  if (new_tensor_data == nullptr) {
    MS_LOG(ERROR) << "tensor_data is nullptr";
    return RET_ERROR;
  }
  for (size_t i = 1; i <= joint_fullconnect_size; i++) {
    auto tensor_addr = GetInputAddr(stack_node->input(i), 2);
    if (tensor_addr == nullptr) {
      MS_LOG(ERROR) << "input tensor addr nullptr";
      return RET_ERROR;
    }
    if (EOK != memcpy_s(new_tensor_data.get() + (i - 1) * tensor_size, tensor_size, tensor_addr, tensor_size)) {
      MS_LOG(ERROR) << "memcpy_s data failed";
      return RET_ERROR;
    }
  }

  rmatmul_input_shape.insert(rmatmul_input_shape.begin(), static_cast<int>(joint_fullconnect_size));
  auto type_ptr = TypeIdToType(fc_weight_param->tensor_type());
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, rmatmul_input_shape);
  rmatmul_input->set_abstract(abstract_tensor);
  rmatmul_input->set_name(stack_node->fullname_with_scope() + "right_parameter");

  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  param_value->set_tensor_shape(rmatmul_input_shape);
  param_value->set_tensor_type(fc_weight_param->tensor_type());
  param_value->set_format(fc_weight_param->format());
  // The raw buffer is handed over to param_value from here on (presumably it owns and frees
  // the address it is given - confirm against ParamValueLite).
  param_value->set_tensor_addr(new_tensor_data.release());
  param_value->set_tensor_size(joint_fullconnect_size * tensor_size);
  rmatmul_input->set_default_param(param_value);
  return RET_OK;
}
}  // namespace
// Builds the match pattern for this pass: a node accepted by IsStackNode, followed by two
// slots bound to the same IsFullConnectNode condition variable, followed by a SeqVar
// (presumably covering whatever inputs remain - confirm against the pattern engine).
const BaseRef BatchMatMulFusion::DefinePattern() const {
  auto stack_cond = std::make_shared<CondVar>(IsStackNode);
  auto fc_cond = std::make_shared<CondVar>(IsFullConnectNode);
  auto trailing_inputs = std::make_shared<SeqVar>();
  return VectorRef({stack_cond, fc_cond, fc_cond, trailing_inputs});
}

// slice + fullconnect -> batchmatmul
//
// Replaces a Stack node whose inputs are all FullConnection nodes with one MatMul node.
//
// func_graph: graph that receives the new MatMul CNode (and, for constant weights, a new parameter).
// node:       the matched Stack node.
//
// The MatMul's left operand is input 1 of the node feeding the first FullConnection's input 1.
// The right operand is either
//   * a new constant parameter built by GetRightMatmulInputParamter when the first
//     FullConnection's weight is a Parameter (transposeB is then set to true), or
//   * input 1 of the node found by walking weight -> input(1) -> input(1).
//
// Input quant params are those of the first FullConnection with its last two entries dropped
// and one entry appended that collects element [1][0] of every stacked FullConnection.
//
// Returns the new MatMul CNode, nullptr when the subgraph does not have the expected
// structure, or `node` unchanged when building the constant right operand fails.
const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "batchmatmulfusion func_graph or node is nullptr";
    return nullptr;
  }
  auto stack_cnode = node->cast<CNodePtr>();
  if (stack_cnode == nullptr || stack_cnode->inputs().size() < 2) {
    MS_LOG(WARNING) << "batchmatmulfusion stack node must be a cnode with at least one input";
    return nullptr;
  }
  const size_t stack_input_size = stack_cnode->inputs().size();

  // check stack node all inputs must fullconnect
  for (size_t i = 1; i < stack_input_size; i++) {
    auto input_node = stack_cnode->input(i);
    if (input_node == nullptr || !IsFullConnectNode(input_node)) {
      MS_LOG(WARNING) << "batchmatmulfusion stack node all inputs must fullconnect type";
      return nullptr;
    }
  }

  auto fullconnect_cnode = stack_cnode->input(1)->cast<CNodePtr>();
  // input(1) and input(2) are read below, so at least three inputs are required.
  if (fullconnect_cnode == nullptr || fullconnect_cnode->inputs().size() < 3) {
    MS_LOG(WARNING) << "batchmatmulfusion fullconnect node must be a cnode with input and weight";
    return nullptr;
  }
  auto left_slice_node = fullconnect_cnode->input(1);
  auto right_reshape_node = fullconnect_cnode->input(2);
  if (left_slice_node == nullptr || right_reshape_node == nullptr) {
    MS_LOG(WARNING) << "batchmatmulfusion fullconnect inputs are nullptr";
    return nullptr;
  }
  auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
  if (left_slice_cnode == nullptr || left_slice_cnode->inputs().size() < 2) {
    MS_LOG(WARNING) << "batchmatmulfusion left input of fullconnect is not a cnode with an input";
    return nullptr;
  }
  auto left_matmul_input = left_slice_cnode->input(1);

  auto matmul_primitive = std::make_unique<schema::PrimitiveT>();
  std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
  matmul_primitive->value.type = schema::PrimitiveType_MatMul;
  matmul_primitive->value.value = attr.release();
  // Take ownership right away so the early returns below cannot leak the primitive.
  std::shared_ptr<lite::PrimitiveC> matmul_cvalue(lite::PrimitiveC::Create(matmul_primitive.release()));
  if (matmul_cvalue == nullptr) {
    MS_LOG(ERROR) << "batchmatmulfusion create matmul primitive failed";
    return nullptr;
  }

  // get matmul quantParams: walk every stacked fullconnect instead of assuming exactly eight.
  std::vector<schema::QuantParamT> jointed_quant_params;
  for (size_t i = 1; i < stack_input_size; i++) {
    auto stacked_fc_cnode = stack_cnode->input(i)->cast<CNodePtr>();
    if (stacked_fc_cnode == nullptr || stacked_fc_cnode->inputs().empty()) {
      MS_LOG(WARNING) << "batchmatmulfusion stack input is not a cnode";
      return nullptr;
    }
    auto stacked_fc_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(stacked_fc_cnode->input(0));
    if (stacked_fc_prim == nullptr) {
      MS_LOG(WARNING) << "batchmatmulfusion fullconnect primitive is nullptr";
      return nullptr;
    }
    auto fc_input_quantParams = stacked_fc_prim->GetInputQuantParams();
    if (fc_input_quantParams.size() > 1 && !fc_input_quantParams[1].empty()) {
      jointed_quant_params.push_back(fc_input_quantParams[1][0]);
    }
  }
  auto fc_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(fullconnect_cnode->input(0));
  if (fc_prim == nullptr) {
    MS_LOG(WARNING) << "batchmatmulfusion fullconnect primitive is nullptr";
    return nullptr;
  }
  auto rmatmul_quant_params = fc_prim->GetInputQuantParams();
  // Drop the last two entries; guarded because pop_back() on an empty vector is undefined.
  for (int dropped = 0; dropped < 2 && !rmatmul_quant_params.empty(); dropped++) {
    rmatmul_quant_params.pop_back();
  }
  // no bias quantParams
  rmatmul_quant_params.emplace_back(jointed_quant_params);
  matmul_cvalue->SetInputQuantParam(rmatmul_quant_params);
  matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams());
  auto matmul_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(matmul_cvalue));
  std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};

  // batchmatmul right node may be const
  if (right_reshape_node->isa<Parameter>()) {
    // NOTE(review): the parameter is added to the graph before it is populated, so it stays
    // in func_graph if the call below fails - confirm that is acceptable.
    auto rmatmul_paramter = func_graph->add_parameter();
    if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
      MS_LOG(ERROR) << "GetRightMatmulInputParamter failed";
      return node;
    }
    matmul_cvalue->GetPrimitiveT()->value.AsMatMul()->transposeB = true;
    matmul_inputs.push_back(rmatmul_paramter);
  } else {
    auto right_reshape_cnode = right_reshape_node->cast<CNodePtr>();
    if (right_reshape_cnode == nullptr || right_reshape_cnode->inputs().size() < 2 ||
        right_reshape_cnode->input(1) == nullptr) {
      MS_LOG(WARNING) << "batchmatmulfusion right input of fullconnect is not a cnode with an input";
      return nullptr;
    }
    auto right_transpose_cnode = right_reshape_cnode->input(1)->cast<CNodePtr>();
    if (right_transpose_cnode == nullptr || right_transpose_cnode->inputs().size() < 2 ||
        right_transpose_cnode->input(1) == nullptr) {
      MS_LOG(WARNING) << "batchmatmulfusion right transpose node is not a cnode with an input";
      return nullptr;
    }
    auto right_slice_cnode = right_transpose_cnode->input(1)->cast<CNodePtr>();
    if (right_slice_cnode == nullptr || right_slice_cnode->inputs().size() < 2) {
      MS_LOG(WARNING) << "batchmatmulfusion right slice node is not a cnode with an input";
      return nullptr;
    }
    matmul_inputs.push_back(right_slice_cnode->input(1));
  }

  auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
  if (matmul_cnode == nullptr) {
    MS_LOG(ERROR) << "batchmatmulfusion create matmul cnode failed";
    return nullptr;
  }
  matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
  MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
  return matmul_cnode;
}
}  // namespace mindspore::opt


In [None]:
 // Per-batch dequantization branch: taken when the tensor shape has three dimensions and the
 // number of quant params equals the size of the first dimension.
 // NOTE(review): input_tensor, quant_datas and dequant_datas are declared outside this
 // snippet; their types are not visible here - confirm against the enclosing function.
 if (input_tensor->shape().size() == 3
      && input_tensor->GetQuantParams().size() == input_tensor->shape()[0]) { // per batch matmul
    // Number of batches == number of quant params (guaranteed by the condition above).
    auto per_batch_size = input_tensor->shape()[0];
    auto quant_param = input_tensor->GetQuantParams();
    for (size_t i = 0; i < per_batch_size; i++) {
      // Batch i is dequantized with its own scale / zero point.
      auto param = quant_param.at(i);
      auto scale = param.scale;
      auto zero_point = param.zeroPoint;
      // Elements per batch: total element count divided by the batch count.
      auto matrix_size = input_tensor->ElementsNum() / per_batch_size;
      for (int64_t j = 0; j < matrix_size; j++) {
        // real = (quantized - zero_point) * scale, stored at the same flat index.
        dequant_datas[i * matrix_size + j] =
            static_cast<float>((quant_datas[i * matrix_size + j] - zero_point) * scale);
      }
    }
    return dequant_datas;
  }

In [None]:
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/matmul.h"
#include <memory>
#include <utility>
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Reads the transposeA field of the MatMul attribute stored in primitive_.
bool MatMul::GetTransposeA() const {
  const auto *matmul_attr = this->primitive_->value.AsMatMul();
  return matmul_attr->transposeA;
}
// Reads the transposeB field of the MatMul attribute stored in primitive_.
bool MatMul::GetTransposeB() const {
  const auto *matmul_attr = this->primitive_->value.AsMatMul();
  return matmul_attr->transposeB;
}

// Writes `transpose_a` into the transposeA field of the MatMul attribute stored in primitive_.
void MatMul::SetTransposeA(bool transpose_a) {
  auto *matmul_attr = this->primitive_->value.AsMatMul();
  matmul_attr->transposeA = transpose_a;
}
// Writes `transpose_b` into the transposeB field of the MatMul attribute stored in primitive_.
void MatMul::SetTransposeB(bool transpose_b) {
  auto *matmul_attr = this->primitive_->value.AsMatMul();
  matmul_attr->transposeB = transpose_b;
}

int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new(std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_MatMul;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_MatMul) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new(std::nothrow) schema::MatMulT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->transposeA = GetValue<bool>(prim.GetAttr("transpose_a"));
    attr->transposeB = GetValue<bool>(prim.GetAttr("transpose_b"));
    this->primitive_->value.value = attr;
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "primitive value is nullptr";
      return RET_ERROR;
    }
  }
  if (GetQuantType() == schema::QuantType_AwareTraining) {
    std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
    std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
    SetInputQuantParam(vecInputQuantParam);
    SetOutputQuantParam(vecOutputQuantParam);
  }
  return RET_OK;
}

#else

// Reads transposeA() from the MatMul table of the read-only primitive_.
bool MatMul::GetTransposeA() const {
  const auto *matmul_attr = this->primitive_->value_as_MatMul();
  return matmul_attr->transposeA();
}
// Reads transposeB() from the MatMul table of the read-only primitive_.
bool MatMul::GetTransposeB() const {
  const auto *matmul_attr = this->primitive_->value_as_MatMul();
  return matmul_attr->transposeB();
}

// Re-serializes the MatMul attribute of `primitive` into `fbb`.
//
// primitive: source primitive whose value is read through value_as_MatMul().
// fbb:       builder that receives a new MatMul table wrapped in a Primitive and is finished
//            with that Primitive as root.
//
// Returns RET_OK on success, RET_ERROR when `primitive` holds no MatMul value.
int MatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  const auto *matmul_attr = primitive->value_as_MatMul();
  if (matmul_attr == nullptr) {
    MS_LOG(ERROR) << "value_as_MatMul return nullptr";
    return RET_ERROR;
  }
  // Copy the three fields into a fresh MatMul table, then wrap it in a Primitive.
  auto matmul_offset =
    schema::CreateMatMul(*fbb, matmul_attr->broadcast(), matmul_attr->transposeA(), matmul_attr->transposeB());
  auto primitive_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_MatMul, matmul_offset.o);
  fbb->Finish(primitive_offset);
  return RET_OK;
}

#endif

// Infers the output tensor's data type, format and shape for C = A x B.
//
// inputs_:  at least two tensors; inputs_[0] is A and inputs_[1] is B.
// outputs_: at least one tensor; outputs_[0] receives the result.
//
// The output always takes A's data type and format. The shape is only computed when
// GetInferFlag() is true. After applying transposeA / transposeB to the last two dimensions,
// the output shape is A's shape with its last dimension replaced by B's last dimension, and:
//   * when B has more dimensions than A, B's leading dimension is prepended;
//   * when both have the same rank (> 2) and exactly one leading dimension is 1, the leading
//     dimension is taken from the other operand.
//
// Returns RET_OK on success, RET_INPUT_TENSOR_ERROR for missing/null tensors, rank < 2, or
// (same rank, no broadcasting) mismatching leading dimensions.
//
// NOTE(review): the inner dimensions (A's last vs. B's second-to-last) are not validated
// here, and only the first dimension is treated as broadcastable.
int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  // front()/at(1) on under-sized vectors would be undefined behaviour or throw.
  if (inputs_.size() < 2 || outputs_.empty()) {
    MS_LOG(ERROR) << "MatMul needs two inputs and one output";
    return RET_INPUT_TENSOR_ERROR;
  }
  auto input0 = inputs_.front();
  auto input1 = inputs_.at(1);
  auto output = outputs_.front();
  if (input0 == nullptr || input1 == nullptr || output == nullptr) {
    MS_LOG(ERROR) << "MatMul input or output tensor is nullptr";
    return RET_INPUT_TENSOR_ERROR;
  }

  output->set_data_type(input0->data_type());
  output->SetFormat(input0->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }

  std::vector<int> a_shape = input0->shape();
  std::vector<int> b_shape = input1->shape();
  if (a_shape.size() < 2 || b_shape.size() < 2) {
    MS_LOG(ERROR) << "inputs shape is invalid";
    return RET_INPUT_TENSOR_ERROR;
  }
  const size_t a_rank = a_shape.size();
  const size_t b_rank = b_shape.size();
  const bool rank_mismatch = a_rank != b_rank;
  // Dimension 0 is a batch dimension only for rank > 2. For plain 2-D operands it is a matrix
  // dimension, so [1, K] x [K, N] must not be treated as a broadcast (it would yield [K, N]).
  const bool leading_dim_broadcast = !rank_mismatch && a_rank > 2 && ((a_shape[0] == 1) != (b_shape[0] == 1));
  const bool need_broadcast = rank_mismatch || leading_dim_broadcast;
  if (!need_broadcast) {
    for (size_t i = 0; i + 2 < a_rank; ++i) {
      if (a_shape[i] != b_shape[i]) {
        MS_LOG(ERROR) << "Op MatMul's dimensions must be equal";
        return RET_INPUT_TENSOR_ERROR;
      }
    }
  }

  if (GetTransposeA()) {
    std::swap(a_shape[a_rank - 1], a_shape[a_rank - 2]);
  }
  if (GetTransposeB()) {
    std::swap(b_shape[b_rank - 1], b_shape[b_rank - 2]);
  }
  std::vector<int> c_shape(a_shape);
  c_shape.back() = b_shape.back();
  if (rank_mismatch) {
    // Only a longer B contributes an extra leading dimension; a longer A already carries it.
    if (b_rank > a_rank) {
      c_shape.insert(c_shape.begin(), b_shape[0]);
    }
  } else if (leading_dim_broadcast) {
    c_shape[0] = a_shape[0] != 1 ? a_shape[0] : b_shape[0];
  }
  output->set_shape(c_shape);
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
