<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit → Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"

namespace mindspore::opt {
namespace {
// Pattern predicate: true when `n` is a CNode or ValueNode whose GetCNodeType() is
// PrimitiveType_Stack; anything else (including other node kinds) yields false.
bool IsStackNode(const BaseRef &n) {
  const bool is_anf_node = utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n);
  if (!is_anf_node) {
    return false;
  }
  return opt::GetCNodeType(n) == schema::PrimitiveType_Stack;
}
// Pattern predicate: true when `n` is a CNode or ValueNode whose GetCNodeType() is
// PrimitiveType_FullConnection; any other input yields false.
bool IsFullConnectNode(const BaseRef &n) {
  bool matches = false;
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    const auto node_type = opt::GetCNodeType(n);
    matches = (node_type == schema::PrimitiveType_FullConnection);
  }
  return matches;
}
void *GetInputAddr(AnfNodePtr node, int input_index) {
  MS_ASSERT(node != nullptr);
  if (!node->isa<CNode>()) {
    MS_LOG(ERROR) << "GetInputAddr not cnode";
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (input_index >= cnode->inputs().size()) {
    MS_LOG(ERROR) << "input index error";
    return nullptr;
  }
  if (cnode->input(input_index)->isa<Parameter>()) {
    auto param_input = cnode->input(input_index)->cast<ParameterPtr>();
    auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param_input->default_param());;
    if (param_value == nullptr) {
      MS_LOG(ERROR) << "param not paramValueLite";
      return nullptr;
    }
    return param_value->tensor_addr();
  }
  MS_LOG(ERROR) << "input not paramter";
  return nullptr;
}
/// Builds the constant right operand of the fused MatMul by concatenating the weights of the
/// nodes feeding `stack_node`.
///
/// The weight (input 2) of each of the kFcNum stack inputs is copied, in stack order, into one
/// buffer. The buffer's shape is the first weight's shape with an extra leading dimension of
/// kFcNum. The buffer, a matching abstract and a derived name are attached to `rmatmul_input`.
///
/// @param stack_node    Stack cnode with exactly kFcNum data inputs. Each input is expected to be
///                      a FullConnection cnode (only its structure is checked here) whose weight
///                      is a Parameter holding a ParamValueLite of the same byte size as the first.
/// @param rmatmul_input Parameter to populate; must be non-null.
/// @return RET_OK on success. RET_ERROR (with an error logged) otherwise, in which case
///         `rmatmul_input` is left untouched and no buffer is leaked.
STATUS GetRightMatmulInputParamter(CNodePtr &stack_node, ParameterPtr &rmatmul_input) {
  // Number of FullConnection branches this fusion stacks together.
  constexpr int kFcNum = 8;
  // Position of the weight operand within a FullConnection cnode (index 0 is the primitive).
  constexpr int kFcWeightIndex = 2;
  MS_ASSERT(stack_node != nullptr);
  MS_ASSERT(rmatmul_input != nullptr);
  if (stack_node->inputs().size() != static_cast<size_t>(kFcNum) + 1) {
    MS_LOG(ERROR) << "stack node inputs error";
    return RET_ERROR;
  }
  // Returns the ParamValueLite behind the weight input of `fc_node`, or nullptr if the node
  // does not have that structure.
  auto weight_param_of = [](const AnfNodePtr &fc_node) -> ParamValueLitePtr {
    if (fc_node == nullptr) {
      return nullptr;
    }
    auto fc_cnode = fc_node->cast<CNodePtr>();
    if (fc_cnode == nullptr || fc_cnode->inputs().size() <= static_cast<size_t>(kFcWeightIndex)) {
      return nullptr;
    }
    auto weight_node = fc_cnode->input(kFcWeightIndex);
    if (weight_node == nullptr) {
      return nullptr;
    }
    auto weight = weight_node->cast<ParameterPtr>();
    if (weight == nullptr) {
      return nullptr;
    }
    return std::dynamic_pointer_cast<ParamValueLite>(weight->default_param());
  };

  auto fc_weight_param = weight_param_of(stack_node->input(1));
  if (fc_weight_param == nullptr) {
    MS_LOG(ERROR) << "first stack input has no constant weight";
    return RET_ERROR;
  }
  const auto tensor_size = fc_weight_param->tensor_size();
  // Owned here until it is handed to param_value below, so every early return frees it.
  std::unique_ptr<uint8_t[]> new_tensor_data(new (std::nothrow) uint8_t[kFcNum * tensor_size]);
  if (new_tensor_data == nullptr) {
    MS_LOG(ERROR) << "tensor_data is nullptr";
    return RET_ERROR;
  }
  for (int i = 1; i <= kFcNum; i++) {
    // Every weight must match the first one's byte size. Otherwise the copy below would read
    // past the end of a smaller source buffer.
    auto weight_param = weight_param_of(stack_node->input(i));
    if (weight_param == nullptr || weight_param->tensor_size() != tensor_size) {
      MS_LOG(ERROR) << "stack input " << i << " weight missing or size mismatch";
      return RET_ERROR;
    }
    auto tensor_addr = GetInputAddr(stack_node->input(i), kFcWeightIndex);
    if (tensor_addr == nullptr) {
      MS_LOG(ERROR) << "input tensor addr nullptr";
      return RET_ERROR;
    }
    if (EOK != memcpy_s(new_tensor_data.get() + (i - 1) * tensor_size, tensor_size, tensor_addr, tensor_size)) {
      MS_LOG(ERROR) << "memcpy_s data failed";
      return RET_ERROR;
    }
  }
  auto rmatmul_input_shape = fc_weight_param->tensor_shape();
  rmatmul_input_shape.insert(rmatmul_input_shape.begin(), kFcNum);
  auto type_ptr = TypeIdToType(fc_weight_param->tensor_type());
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, rmatmul_input_shape);
  rmatmul_input->set_abstract(abstract_tensor);
  rmatmul_input->set_name(stack_node->fullname_with_scope() + "right_parameter");
  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(param_value != nullptr);
  param_value->set_tensor_shape(rmatmul_input_shape);
  param_value->set_tensor_type(fc_weight_param->tensor_type());
  param_value->set_format(fc_weight_param->format());
  // Ownership of the raw buffer passes to param_value here.
  // NOTE(review): assumes ParamValueLite releases tensor_addr on destruction; confirm.
  param_value->set_tensor_addr(new_tensor_data.release());
  param_value->set_tensor_size(kFcNum * tensor_size);
  rmatmul_input->set_default_param(param_value);
  return RET_OK;
}
}  // namespace
// Pattern matched by this pass: a Stack node whose first input is a FullConnection node,
// followed by any remaining inputs captured loosely by a SeqVar (their kinds are not
// constrained by the pattern itself).
const BaseRef BatchMatMulFusion::DefinePattern() const {
  auto stack_var = std::make_shared<CondVar>(IsStackNode);
  auto first_fc_var = std::make_shared<CondVar>(IsFullConnectNode);
  auto remaining_inputs_var = std::make_shared<SeqVar>();
  return VectorRef({stack_var, first_fc_var, remaining_inputs_var});
}

// Fuses Stack(FullConnection(Slice(x), w_1), ...) into a single MatMul.
//
// The left MatMul operand is the tensor feeding the first FullConnection's first input
// (assumed to be a Slice). The right operand is chosen as follows:
//  (a) If the first FullConnection's weight is a Parameter, the operand is a constant built by
//      stacking every branch's weight, and MatMul uses transposeB.
//  (b) Otherwise, it is the tensor reached by walking three first-inputs back from that weight
//      (assumed reshape -> transpose -> slice).
// Node kinds along these paths come from the pattern; only their structure (CNode-ness, input
// count) is checked here.
//
// Returns the new MatMul cnode. If the matched nodes do not have the expected structure, or the
// constant weight cannot be built, returns nullptr without modifying func_graph.
const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  // Null-safe downcast: nullptr for a missing node or a non-CNode.
  auto as_cnode = [](const AnfNodePtr &anf) -> CNodePtr {
    if (anf == nullptr) {
      return nullptr;
    }
    return anf->cast<CNodePtr>();
  };
  // First data input (index 1; index 0 holds the primitive) of `cnode` as a CNode, or nullptr.
  auto first_input_cnode = [&as_cnode](const CNodePtr &cnode) -> CNodePtr {
    if (cnode == nullptr || cnode->inputs().size() < 2) {
      return nullptr;
    }
    return as_cnode(cnode->input(1));
  };

  auto stack_cnode = as_cnode(node);
  auto fullconnect_cnode = first_input_cnode(stack_cnode);
  // Only the (primitive, input, weight) form is handled. Any extra operand (presumably a bias)
  // would be silently dropped by the plain MatMul built below.
  if (fullconnect_cnode == nullptr || fullconnect_cnode->inputs().size() != 3) {
    MS_LOG(ERROR) << "unexpected fullconnect node in batchmatmul fusion";
    return nullptr;
  }
  auto left_slice_cnode = first_input_cnode(fullconnect_cnode);
  if (left_slice_cnode == nullptr || left_slice_cnode->inputs().size() < 2 || left_slice_cnode->input(1) == nullptr) {
    MS_LOG(ERROR) << "unexpected left input in batchmatmul fusion";
    return nullptr;
  }
  auto left_matmul_input = left_slice_cnode->input(1);
  auto right_reshape_node = fullconnect_cnode->input(2);
  if (right_reshape_node == nullptr) {
    MS_LOG(ERROR) << "fullconnect right input is null";
    return nullptr;
  }
  auto fc_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(fullconnect_cnode->input(0));
  if (fc_prim == nullptr) {
    MS_LOG(ERROR) << "fullconnect primitive is null";
    return nullptr;
  }

  auto matmul_primitive = std::make_unique<schema::PrimitiveT>();
  auto attr = std::make_unique<schema::MatMulT>();
  matmul_primitive->value.type = schema::PrimitiveType_MatMul;
  matmul_primitive->value.value = attr.release();
  // Take ownership of the created primitive immediately so later early returns cannot leak it.
  std::shared_ptr<lite::PrimitiveC> matmul_cvalue(lite::PrimitiveC::Create(matmul_primitive.release()));
  if (matmul_cvalue == nullptr) {
    MS_LOG(ERROR) << "create matmul primitive failed";
    return nullptr;
  }

  auto fc_input_quant_params = fc_prim->GetInputQuantParams();
  if (!fc_input_quant_params.empty()) {
    // Drop the last input quant param entry (presumably the FC weight's; TODO confirm).
    // Guarded because pop_back on an empty vector is undefined behavior.
    fc_input_quant_params.pop_back();
  }
  matmul_cvalue->SetInputQuantParam(fc_input_quant_params);
  matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams());

  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(matmul_cvalue), left_matmul_input};

  if (right_reshape_node->isa<Parameter>()) {
    // Right node is const: build the stacked weight in a detached Parameter and attach it to
    // the graph only on success, so a failed fusion leaves no orphan parameter behind.
    auto rmatmul_paramter = std::make_shared<Parameter>(func_graph);
    if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
      MS_LOG(ERROR) << "GetRightMatmulInputParamter failed";
      return nullptr;
    }
    func_graph->add_parameter(rmatmul_paramter);
    // A const right input of the transformer batchmatmul needs transposeB set.
    matmul_cvalue->GetPrimitiveT()->value.AsMatMul()->transposeB = true;
    matmul_inputs.push_back(rmatmul_paramter);
  } else {
    // Walk reshape -> transpose -> slice back to the un-sliced right operand.
    auto right_slice_cnode = first_input_cnode(first_input_cnode(as_cnode(right_reshape_node)));
    if (right_slice_cnode == nullptr || right_slice_cnode->inputs().size() < 2 ||
        right_slice_cnode->input(1) == nullptr) {
      MS_LOG(ERROR) << "unexpected right input structure in batchmatmul fusion";
      return nullptr;
    }
    matmul_inputs.push_back(right_slice_cnode->input(1));
  }
  auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
  matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
  return matmul_cnode;
}
}  // namespace mindspore::opt


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_

#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_context.h"

namespace mindspore {
namespace opt {
/// Pattern-based pass that rewrites a Stack of FullConnection nodes into a single MatMul.
/// The matched pattern is produced by DefinePattern() and the rewrite by Process().
class BatchMatMulFusion : public PatternProcessPass {
 public:
  /// @param multigraph forwarded unchanged to PatternProcessPass.
  explicit BatchMatMulFusion(bool multigraph = true) : PatternProcessPass("slice_fullconnect_fusion", multigraph) {}
  ~BatchMatMulFusion() override = default;
  /// Returns the Stack(FullConnection, ...) pattern this pass matches.
  const BaseRef DefinePattern() const override;
  /// Replaces a matched Stack node with the fused MatMul, or returns nullptr to leave it as is.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
