<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/anf_exporter/anf_exporter.h"

namespace mindspore {
class ConstantFoldingFusionTest : public mindspore::CommonTest {
 public:
  ConstantFoldingFusionTest() = default;
};
// Shared handle to a mutable flatbuffer meta graph; this is the type the graph builders below return.
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
// Unique handle to a graph node (not referenced in the visible part of this file).
using CNodeTptr = std::unique_ptr<schema::CNodeT>;

namespace {

// Builds a meta graph named "graph" that holds a single node of kind `op_type` fed by two
// constant tensors.
//
// op_type: primitive type stored on the node.
// op_node: attribute object matching `op_type`; it is stored as the primitive's union value.
//          NOTE(review): presumably the union owns it afterwards - confirm against the schema.
//
// Tensor layout: indices 0 and 1 are float32 NHWC ValueNode tensors of dims {1, 2, 2, 3}
// holding 0, 1, ..., 11; index 2 is the node output (Parameter tensor without data).
MetaGraphTptr BuildGraph(schema::PrimitiveType op_type, void *op_node) {
  constexpr size_t kElemCount = 2 * 2 * 3;

  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";

  // the node under test
  auto example_node = std::make_unique<schema::CNodeT>();
  example_node->inputIndex = {0, 1};
  example_node->outputIndex = {2};
  example_node->primitive = std::make_unique<schema::PrimitiveT>();
  example_node->primitive->value.type = op_type;
  example_node->primitive->value.value = op_node;
  example_node->name = "example";
  meta_graph->nodes.emplace_back(std::move(example_node));

  meta_graph->inputIndex = {0, 1};
  meta_graph->outputIndex = {2};

  // Payload shared by both inputs. A stack buffer replaces the former new(std::nothrow)
  // allocation, which was written to without checking for nullptr.
  float values[kElemCount];
  for (size_t i = 0; i < kElemCount; ++i) {
    values[i] = static_cast<float>(i);
  }

  // input 0 (data1) and input 1 (data2): two identical constant tensors
  for (int input_index = 0; input_index < 2; ++input_index) {
    auto input = std::make_unique<schema::TensorT>();
    input->nodeType = schema::NodeType::NodeType_ValueNode;
    input->format = schema::Format_NHWC;
    input->dataType = TypeId::kNumberTypeFloat32;
    input->dims = {1, 2, 2, 3};
    input->offset = -1;
    input->data.resize(sizeof(values));
    memcpy(input->data.data(), values, sizeof(values));
    meta_graph->allTensors.emplace_back(std::move(input));
  }

  // output of the node under test
  auto node_output = std::make_unique<schema::TensorT>();
  node_output->nodeType = schema::NodeType::NodeType_Parameter;
  node_output->format = schema::Format_NHWC;
  node_output->dataType = TypeId::kNumberTypeFloat32;
  node_output->dims = {1, 2, 2, 3};
  meta_graph->allTensors.emplace_back(std::move(node_output));

  return meta_graph;
}

// Builds a meta graph named "graph" that holds a single node of kind `op_type` fed by one
// constant tensor.
//
// op_type: primitive type stored on the node.
// op_node: attribute object matching `op_type`; it is stored as the primitive's union value.
//          NOTE(review): presumably the union owns it afterwards - confirm against the schema.
//
// Tensor layout: index 0 is a float32 NHWC ValueNode tensor of dims {1, 2, 2, 3} holding
// 1, 2, ..., 12; index 1 is the node output (Parameter tensor without data).
MetaGraphTptr BuildGraphForOneInput(schema::PrimitiveType op_type, void *op_node) {
  constexpr size_t kElemCount = 2 * 2 * 3;

  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";

  // the node under test
  auto example_node = std::make_unique<schema::CNodeT>();
  example_node->inputIndex = {0};
  example_node->outputIndex = {1};
  example_node->primitive = std::make_unique<schema::PrimitiveT>();
  example_node->primitive->value.type = op_type;
  example_node->primitive->value.value = op_node;
  example_node->name = "example";
  meta_graph->nodes.emplace_back(std::move(example_node));

  meta_graph->inputIndex = {0};
  meta_graph->outputIndex = {1};

  // input 0: data1. A stack buffer replaces the former new(std::nothrow) allocation,
  // which was written to without checking for nullptr.
  float values[kElemCount];
  for (size_t i = 0; i < kElemCount; ++i) {
    values[i] = static_cast<float>(i + 1);
  }
  auto input0 = std::make_unique<schema::TensorT>();
  input0->nodeType = schema::NodeType::NodeType_ValueNode;
  input0->format = schema::Format_NHWC;
  input0->dataType = TypeId::kNumberTypeFloat32;
  input0->dims = {1, 2, 2, 3};
  input0->offset = -1;
  input0->data.resize(sizeof(values));
  memcpy(input0->data.data(), values, sizeof(values));
  meta_graph->allTensors.emplace_back(std::move(input0));

  // output of the node under test
  auto node_output = std::make_unique<schema::TensorT>();
  node_output->nodeType = schema::NodeType::NodeType_Parameter;
  node_output->format = schema::Format_NHWC;
  node_output->dataType = TypeId::kNumberTypeFloat32;
  node_output->dims = {1, 2, 2, 3};
  meta_graph->allTensors.emplace_back(std::move(node_output));

  return meta_graph;
}

// Builds a two-node meta graph named "graph": add(t0, t1) -> t2, then mul(t2, t3) -> t4.
//
// Tensor layout (all float32 NHWC, dims {1, 2, 2, 3}):
//   0, 1: constants holding 0, 1, ..., 11
//   2   : add output (Parameter) carrying a zero-filled 12-float payload
//   3   : constant holding 10 in every element
//   4   : mul output (Parameter tensor without data)
MetaGraphTptr BuildMixGraph() {
  constexpr size_t kElemCount = 2 * 2 * 3;

  // Creates a float32 NHWC tensor of dims {1, 2, 2, 3} with the given node type and no data.
  auto new_tensor = [](schema::NodeType node_type) {
    auto tensor = std::make_unique<schema::TensorT>();
    tensor->nodeType = node_type;
    tensor->format = schema::Format_NHWC;
    tensor->dataType = TypeId::kNumberTypeFloat32;
    tensor->dims = {1, 2, 2, 3};
    return tensor;
  };
  // Creates a constant (ValueNode) tensor whose payload is a copy of values[0..count).
  auto new_const_tensor = [&new_tensor](const float *values, size_t count) {
    auto tensor = new_tensor(schema::NodeType::NodeType_ValueNode);
    tensor->offset = -1;
    tensor->data.resize(count * sizeof(float));
    memcpy(tensor->data.data(), values, count * sizeof(float));
    return tensor;
  };

  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";

  // add node
  auto add_node = std::make_unique<schema::CNodeT>();
  add_node->inputIndex = {0, 1};
  add_node->outputIndex = {2};
  add_node->primitive = std::make_unique<schema::PrimitiveT>();
  add_node->primitive->value.type = schema::PrimitiveType_Add;
  add_node->primitive->value.value = new schema::AddT;
  add_node->name = "add";
  meta_graph->nodes.emplace_back(std::move(add_node));

  // NOTE(review): index 2 is the add output while the third constant lives at index 3;
  // {0, 1, 3} may have been intended. Kept as in the original - confirm.
  meta_graph->inputIndex = {0, 1, 2};
  meta_graph->outputIndex = {4};

  // mul node
  auto mul_node = std::make_unique<schema::CNodeT>();
  mul_node->inputIndex = {2, 3};
  mul_node->outputIndex = {4};
  mul_node->primitive = std::make_unique<schema::PrimitiveT>();
  mul_node->primitive->value.type = schema::PrimitiveType_Mul;
  mul_node->primitive->value.value = new schema::MulT;
  mul_node->name = "mul";
  meta_graph->nodes.emplace_back(std::move(mul_node));

  // Constant payloads live on the stack; the former new(std::nothrow) buffers were used
  // without a nullptr check.
  float ramp[kElemCount];
  float tens[kElemCount];
  for (size_t i = 0; i < kElemCount; ++i) {
    ramp[i] = static_cast<float>(i);
    tens[i] = 10.0f;
  }

  // input 0: data1
  meta_graph->allTensors.emplace_back(new_const_tensor(ramp, kElemCount));
  // input 1: data2
  meta_graph->allTensors.emplace_back(new_const_tensor(ramp, kElemCount));

  // add output: resize() value-initializes the bytes, so the payload is all zeros. The
  // original copied an uninitialized heap buffer here, leaving indeterminate contents.
  auto add_output = new_tensor(schema::NodeType::NodeType_Parameter);
  add_output->offset = -1;
  add_output->data.resize(kElemCount * sizeof(float));
  meta_graph->allTensors.emplace_back(std::move(add_output));

  // input 2: data3
  meta_graph->allTensors.emplace_back(new_const_tensor(tens, kElemCount));

  // final mul output
  meta_graph->allTensors.emplace_back(new_tensor(schema::NodeType::NodeType_Parameter));

  return meta_graph;
}
// Builds a three-node meta graph named "graph":
//   split(t0) -> t1, t2 (numberSplit = 2, splitDim = 1); mul1(t1, t3) -> t5; mul2(t2, t4) -> t6.
//
// Tensor layout (all float32 NHWC):
//   0   : constant, dims {1, 2, 2, 3}, holding 0, 1, ..., 11
//   1, 2: split outputs (Parameter), dims {1, 1, 2, 3}, zero-filled 6-float payload
//   3   : constant, dims {1, 1, 2, 3}, holding 0, 1, ..., 5
//   4   : constant, dims {1, 1, 2, 3}, holding 10 in every element
//   5, 6: mul outputs (Parameter tensors without data), dims {1, 1, 2, 3}
MetaGraphTptr BuildSplitGraph() {
  constexpr size_t kFullCount = 2 * 2 * 3;
  constexpr size_t kHalfCount = 1 * 2 * 3;

  // Creates a float32 NHWC tensor of dims {1, 1, 2, 3} with the given node type and no data.
  auto new_half_tensor = [](schema::NodeType node_type) {
    auto tensor = std::make_unique<schema::TensorT>();
    tensor->nodeType = node_type;
    tensor->format = schema::Format_NHWC;
    tensor->dataType = TypeId::kNumberTypeFloat32;
    tensor->dims = {1, 1, 2, 3};
    return tensor;
  };
  // Gives `tensor` a payload of `count` floats: a copy of `values`, or zeros when values is nullptr.
  auto set_float_payload = [](schema::TensorT *tensor, const float *values, size_t count) {
    tensor->offset = -1;
    tensor->data.resize(count * sizeof(float));  // value-initializes, i.e. zero-fills
    if (values != nullptr) {
      memcpy(tensor->data.data(), values, count * sizeof(float));
    }
  };

  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";

  // split node
  auto split_node = std::make_unique<schema::CNodeT>();
  split_node->inputIndex = {0};
  split_node->outputIndex = {1, 2};
  split_node->primitive = std::make_unique<schema::PrimitiveT>();
  split_node->primitive->value.type = schema::PrimitiveType_Split;
  std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>();
  attr->numberSplit = 2;
  attr->splitDim = 1;
  split_node->primitive->value.value = attr.release();
  split_node->name = "split";
  meta_graph->nodes.emplace_back(std::move(split_node));

  meta_graph->inputIndex = {0, 3, 4};
  meta_graph->outputIndex = {5, 6};

  // mul node 1
  auto mul_node1 = std::make_unique<schema::CNodeT>();
  mul_node1->inputIndex = {1, 3};
  mul_node1->outputIndex = {5};
  mul_node1->primitive = std::make_unique<schema::PrimitiveT>();
  mul_node1->primitive->value.type = schema::PrimitiveType_Mul;
  std::unique_ptr<schema::MulT> mul_attr = std::make_unique<schema::MulT>();
  mul_node1->primitive->value.value = mul_attr.release();
  mul_node1->name = "mul1";
  meta_graph->nodes.emplace_back(std::move(mul_node1));

  // mul node 2
  auto mul_node2 = std::make_unique<schema::CNodeT>();
  mul_node2->inputIndex = {2, 4};
  mul_node2->outputIndex = {6};
  mul_node2->primitive = std::make_unique<schema::PrimitiveT>();
  mul_node2->primitive->value.type = schema::PrimitiveType_Mul;
  std::unique_ptr<schema::MulT> mul2_attr = std::make_unique<schema::MulT>();
  mul_node2->primitive->value.value = mul2_attr.release();
  mul_node2->name = "mul2";
  meta_graph->nodes.emplace_back(std::move(mul_node2));

  // Constant payloads live on the stack; the former new(std::nothrow) buffers were used
  // without a nullptr check.
  float full_ramp[kFullCount];
  for (size_t i = 0; i < kFullCount; ++i) {
    full_ramp[i] = static_cast<float>(i);
  }
  float half_ramp[kHalfCount];
  float half_tens[kHalfCount];
  for (size_t i = 0; i < kHalfCount; ++i) {
    half_ramp[i] = static_cast<float>(i);
    half_tens[i] = 10.0f;
  }

  // input 0: data1 (the only tensor with the full {1, 2, 2, 3} shape)
  auto input0 = new_half_tensor(schema::NodeType::NodeType_ValueNode);
  input0->dims = {1, 2, 2, 3};
  set_float_payload(input0.get(), full_ramp, kFullCount);
  meta_graph->allTensors.emplace_back(std::move(input0));

  // split output1 and output2: zero-filled payloads. The original copied uninitialized
  // heap buffers here, leaving indeterminate contents.
  for (int output_index = 0; output_index < 2; ++output_index) {
    auto split_output = new_half_tensor(schema::NodeType::NodeType_Parameter);
    set_float_payload(split_output.get(), nullptr, kHalfCount);
    meta_graph->allTensors.emplace_back(std::move(split_output));
  }

  // input 1: data2
  auto input1 = new_half_tensor(schema::NodeType::NodeType_ValueNode);
  set_float_payload(input1.get(), half_ramp, kHalfCount);
  meta_graph->allTensors.emplace_back(std::move(input1));

  // input 2: data3
  auto input2 = new_half_tensor(schema::NodeType::NodeType_ValueNode);
  set_float_payload(input2.get(), half_tens, kHalfCount);
  meta_graph->allTensors.emplace_back(std::move(input2));

  // final mul output1
  meta_graph->allTensors.emplace_back(new_half_tensor(schema::NodeType::NodeType_Parameter));
  // final mul output2
  meta_graph->allTensors.emplace_back(new_half_tensor(schema::NodeType::NodeType_Parameter));

  return meta_graph;
}
}  //  namespace
// Runs ConstFoldPass over a graph with one Add node fed by two constant tensors and expects
// the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) {
  auto meta_graph = BuildGraph(schema::PrimitiveType_Add, new schema::AddT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over the add -> mul graph from BuildMixGraph() and expects the exported
// graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestMixedConstantFold) {
  auto meta_graph = BuildMixGraph();
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Sub node fed by two constant tensors and expects
// the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestSubConstantFold) {
  auto meta_graph = BuildGraph(schema::PrimitiveType_Sub, new schema::SubT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Mul node fed by two constant tensors and expects
// the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestMulConstantFold) {
  auto meta_graph = BuildGraph(schema::PrimitiveType_Mul, new schema::MulT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Transpose node (perm = {3, 0, 1, 2}) built by
// BuildGraph() and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestTransposeConstantFold) {
  auto transposeT = new schema::TransposeT;
  transposeT->perm = {3, 0, 1, 2};
  auto meta_graph = BuildGraph(schema::PrimitiveType_Transpose, transposeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Tile node (multiples = {1, 2, 2, 2}) built by
// BuildGraph() and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestTileConstantFold) {
  auto tileT = new schema::TileT;
  tileT->multiples = {1, 2, 2, 2};
  auto meta_graph = BuildGraph(schema::PrimitiveType_Tile, tileT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one StridedSlice node (begin = {1}, end = {3},
// stride = {1}) on a single constant input and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestStridedSliceConstantFold) {
  auto stridedSliceT = new schema::StridedSliceT;
  stridedSliceT->begin = {1};
  stridedSliceT->end = {3};
  stridedSliceT->stride = {1};
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_StridedSlice, stridedSliceT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Stack node (axis = 1) fed by two constant tensors
// and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestStackConstantFold) {
  auto stackT = new schema::StackT;
  stackT->axis = 1;
  auto meta_graph = BuildGraph(schema::PrimitiveType_Stack, stackT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Slice node (default-constructed attributes) built
// by BuildGraph() and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestSliceConstantFold) {
  auto sliceT = new schema::SliceT;
  auto meta_graph = BuildGraph(schema::PrimitiveType_Slice, sliceT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Shape node on a single constant input and expects
// the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestShapeConstantFold) {
  auto shapeT = new schema::ShapeT;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Shape, shapeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Rsqrt node on a single constant input and expects
// the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestRsqrtConstantFold) {
  auto rsqrtT = new schema::RsqrtT;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Rsqrt, rsqrtT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Reshape node (shape = {2, 6}) on a single constant
// input and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestReshapeConstantFold) {
  auto reshapeT = new schema::ReshapeT;
  reshapeT->shape = {2, 6};
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Reshape, reshapeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Range node (start = 1, limit = 10, delta = 1) on a
// single constant input and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestRangeConstantFold) {
  auto rangeT = new schema::RangeT;
  rangeT->limit = 10;
  rangeT->start = 1;
  rangeT->delta = 1;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Range, rangeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}
// Runs ConstFoldPass over a graph with one MatMul node (default-constructed attributes) fed by
// two constant tensors and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestMatmulConstantFold) {
  auto matmulT = new schema::MatMulT;
  auto meta_graph = BuildGraph(schema::PrimitiveType_MatMul, matmulT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one ExpandDims node (default-constructed attributes) on
// a single constant input and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestExpandDimsConstantFold) {
  auto expandDimsT = new schema::ExpandDimsT;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_ExpandDims, expandDimsT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Concat node (default-constructed attributes) fed by
// two constant tensors and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestConcatDimsConstantFold) {
  auto concatT = new schema::ConcatT;
  auto meta_graph = BuildGraph(schema::PrimitiveType_Concat, concatT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over a graph with one Cast node (UInt8 -> Float32) on a single constant
// input and expects the exported graph to contain no nodes.
TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) {
  auto castT = new schema::CastT;
  castT->srcT = kNumberTypeUInt8;
  castT->dstT = kNumberTypeFloat32;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Cast, castT);
  // Relabel the constant input as UInt8 so it matches the cast's source type.
  // NOTE(review): only dataType changes; the payload is still the float-sized buffer the
  // builder wrote - confirm this mismatch is intended.
  auto input_tensor = meta_graph->allTensors.at(0).get();
  input_tensor->dataType = kNumberTypeUInt8;
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}

// Runs ConstFoldPass over the split -> (mul1, mul2) graph from BuildSplitGraph() and expects
// the exported graph to contain no nodes. Unlike the other tests in this file, the pass
// manager is constructed with the explicit arguments ("test", false).
TEST_F(ConstantFoldingFusionTest, TestSplitConstantFold) {
  auto meta_graph = BuildSplitGraph();
  // BuildSplitGraph() already creates tensor 0 as Float32; the assignment is kept as written.
  auto input_tensor = meta_graph->allTensors.at(0).get();
  input_tensor->dataType = kNumberTypeFloat32;
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  ASSERT_NE(nullptr, func_graph);  // stop here rather than hand a null graph to the optimizer
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>("test", false);
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  // NOTE(review): if lite::Export returns an owning raw pointer it is never released here - confirm.
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_NE(nullptr, new_meta_graph);  // guards the dereference below
  ASSERT_EQ(new_meta_graph->nodes.size(), 0u);
}
}  // namespace mindspore


In [None]:
def reduce_fn(vals):
    """Return the arithmetic mean of ``vals``, i.e. ``sum(vals) / len(vals)``.

    ``vals`` must be a non-empty sized collection of numbers; an empty one
    raises ``ZeroDivisionError``.
    """
    total = sum(vals)
    count = len(vals)
    return total / count

In [None]:
from transformers import *
import tokenizers

In [None]:
!mkdir -p ./input/roberta-base

In [None]:
# Load the 'roberta-base' tokenizer, model and configuration by name and write each of them
# into save_path, the directory created by the previous cell.
# The `class config` cell below reads vocab.json and merges.txt from this same directory.
# NOTE: the name `config` bound here is rebound by that later `class config` definition.
save_path = './input/roberta-base'
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
config = RobertaConfig.from_pretrained('roberta-base')
# The three saves target the same directory, so their order is left untouched.
tokenizer.save_vocabulary(save_path)
model.save_pretrained(save_path)
config.save_pretrained(save_path)

In [None]:
class config:
    """Namespace of constants shared by the dataset, model and training cells.

    It is used through its class attributes only (e.g. ``config.MAX_LEN``).
    """
    # Fold identifier (its consumer is not visible in this part of the file).
    FOLD = 0
    # Evaluates to 6e-6.
    LEARNING_RATE = 0.2 * 3e-5
    # Sequence length that process_data pads the token lists up to.
    MAX_LEN = 192
    TRAIN_BATCH_SIZE = 16
    VALID_BATCH_SIZE = 8
    EPOCHS = 3
    TRAINING_FILE = "./tweet-sentiment/train_folds.csv"
    # Directory holding the saved roberta-base files (vocab.json, merges.txt, weights).
    ROBERTA_PATH = "./input/roberta-base"
    # Built once, when the class body executes; the vocab and merges files must already exist.
    TOKENIZER = tokenizers.ByteLevelBPETokenizer(
        vocab_file=f"{ROBERTA_PATH}/vocab.json", 
        merges_file=f"{ROBERTA_PATH}/merges.txt", 
        lowercase=True,
        add_prefix_space=True
    )

# Data processing

In [None]:
def process_data(tweet, selected_text, sentiment, tokenizer, max_len):
    """Turn one (tweet, selected_text, sentiment) example into model inputs and span targets.

    Args:
        tweet: Raw tweet; converted with ``str()`` and whitespace-normalised.
        selected_text: The part of the tweet to predict; normalised the same way.
        sentiment: One of ``'positive'``, ``'negative'``, ``'neutral'``
            (any other value raises ``KeyError``).
        tokenizer: Object whose ``encode(text)`` returns a result with ``ids``
            (list of token ids) and ``offsets`` (list of ``(start, end)``
            character offsets into ``text``).
        max_len: Length the sequences are right-padded to. Sequences that are
            already longer are returned unpadded and untruncated.

    Returns:
        dict with ``ids``, ``mask``, ``token_type_ids``, ``offsets`` (parallel
        lists), ``targets_start`` / ``targets_end`` (token indices into ``ids``),
        and the normalised ``orig_tweet``, ``orig_selected`` plus ``sentiment``.

    When the selection is empty or does not occur in the tweet, the targets fall
    back to the whole tweet span instead of raising ``IndexError``.
    """
    tweet = " " + " ".join(str(tweet).split())
    selected_text = " " + " ".join(str(selected_text).split())

    # First occurrence of the selection (minus its leading space) in the tweet.
    # An empty selection is treated as "not found" (str.find("") would return 0).
    needle = selected_text[1:]
    idx0 = tweet.find(needle) if needle else -1

    # Per-character mask: 1 for characters that belong to the selection.
    char_targets = [0] * len(tweet)
    if idx0 >= 0:
        for ct in range(idx0, idx0 + len(needle)):
            char_targets[ct] = 1

    tok_tweet = tokenizer.encode(tweet)
    input_ids_orig = tok_tweet.ids
    tweet_offsets = tok_tweet.offsets

    # Tokens that overlap at least one selected character.
    target_idx = [
        j for j, (offset1, offset2) in enumerate(tweet_offsets)
        if any(char_targets[offset1: offset2])
    ]

    if target_idx:
        targets_start = target_idx[0]
        targets_end = target_idx[-1]
    else:
        # No token overlaps the selection: label the whole tweet.
        targets_start = 0
        targets_end = max(len(input_ids_orig) - 1, 0)

    sentiment_id = {
        'positive': 1313,
        'negative': 2430,
        'neutral': 7974
    }

    # Layout: [0] [sentiment id] [2] [2] tweet tokens ... [2]
    input_ids = [0] + [sentiment_id[sentiment]] + [2] + [2] + input_ids_orig + [2]
    token_type_ids = [0, 0, 0, 0] + [0] * (len(input_ids_orig) + 1)
    mask = [1] * len(token_type_ids)
    tweet_offsets = [(0, 0)] * 4 + tweet_offsets + [(0, 0)]
    # Shift the targets past the four prefix tokens.
    targets_start += 4
    targets_end += 4

    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + ([1] * padding_length)
        mask = mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([0] * padding_length)
        tweet_offsets = tweet_offsets + ([(0, 0)] * padding_length)

    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'targets_start': targets_start,
        'targets_end': targets_end,
        'orig_tweet': tweet,
        'orig_selected': selected_text,
        'sentiment': sentiment,
        'offsets': tweet_offsets
    }

# Data loader

In [None]:
class TweetDataset:
    """Index-addressable dataset over parallel tweet / sentiment / selected_text sequences.

    Each item is encoded on access with ``process_data`` using the tokenizer
    and maximum length taken from ``config``.
    """

    def __init__(self, tweet, sentiment, selected_text):
        self.tweet = tweet
        self.sentiment = sentiment
        self.selected_text = selected_text
        self.tokenizer = config.TOKENIZER
        self.max_len = config.MAX_LEN

    def __len__(self):
        """Number of examples, taken from the tweet sequence."""
        return len(self.tweet)

    def __getitem__(self, item):
        """Encode example ``item``; numeric fields become ``torch.long`` tensors, text fields stay ``str``."""
        encoded = process_data(
            self.tweet[item],
            self.selected_text[item],
            self.sentiment[item],
            self.tokenizer,
            self.max_len,
        )

        def as_long(key):
            return torch.tensor(encoded[key], dtype=torch.long)

        return {
            'ids': as_long("ids"),
            'mask': as_long("mask"),
            'token_type_ids': as_long("token_type_ids"),
            'targets_start': as_long("targets_start"),
            'targets_end': as_long("targets_end"),
            'orig_tweet': encoded["orig_tweet"],
            'orig_selected': encoded["orig_selected"],
            'sentiment': encoded["sentiment"],
            'offsets': as_long("offsets"),
        }


In [None]:
class TweetModel(transformers.BertPreTrainedModel):
    """RoBERTa encoder followed by a linear head that scores span start/end.

    The head emits two values per token position; ``forward`` returns them as
    two separate tensors (start logits and end logits).
    """

    def __init__(self, conf):
        super().__init__(conf)
        self.roberta = transformers.RobertaModel.from_pretrained(config.ROBERTA_PATH, config=conf)
        self.drop_out = nn.Dropout(0.1)
        # Input width is 768 * 2 because forward() concatenates two tensors
        # along the last axis; assumes the encoder hidden size is 768 -
        # TODO confirm against the loaded configuration.
        self.l0 = nn.Linear(768 * 2, 2)
        nn.init.normal_(self.l0.weight, std=0.02)

    def forward(self, ids, mask, token_type_ids):
        """Return ``(start_logits, end_logits)`` for the given encoded batch."""
        # Only the third encoder output is used. NOTE(review): presumably the
        # per-layer hidden states (the config used later in this notebook sets
        # output_hidden_states = True); the 3-way unpacking depends on the
        # installed transformers version returning a tuple - confirm.
        _, _, hidden_states = self.roberta(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
        )

        # Join the last two entries feature-wise before the head.
        features = torch.cat((hidden_states[-1], hidden_states[-2]), dim=-1)
        logits = self.l0(self.drop_out(features))

        # Split the two output channels and drop the trailing singleton axis.
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

In [None]:
def loss_fn(start_logits, end_logits, start_positions, end_positions):
    """Sum of the cross-entropy losses for the span start and the span end.

    Each term is computed with a default-configured ``nn.CrossEntropyLoss``
    between the logits and the corresponding target positions.
    """
    criterion = nn.CrossEntropyLoss()
    return criterion(start_logits, start_positions) + criterion(end_logits, end_positions)

In [None]:
def train_fn(data_loader, model, optimizer, device, num_batches, scheduler=None):
    """Run one training pass over ``data_loader``.

    Args:
        data_loader: iterable of batch dicts providing at least the keys
            "ids", "token_type_ids", "mask", "targets_start" and
            "targets_end" (tensors).
        model: callable returning ``(start_logits, end_logits)`` when invoked
            with ``ids``, ``mask`` and ``token_type_ids``.
        optimizer: optimizer handed to ``xm.optimizer_step`` after each
            backward pass.
        device: device the batch tensors are moved to.
        num_batches: value passed to tqdm as the progress-bar total.
        scheduler: optional learning-rate scheduler; when given it is stepped
            once per batch, when ``None`` no scheduling is performed.
    """
    model.train()
    # The progress bar is only shown on the master ordinal.
    tk0 = tqdm(data_loader, total=num_batches, desc="Training", disable=not xm.is_master_ordinal())
    for d in tk0:
        ids = d["ids"].to(device, dtype=torch.long)
        token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
        mask = d["mask"].to(device, dtype=torch.long)
        targets_start = d["targets_start"].to(device, dtype=torch.long)
        targets_end = d["targets_end"].to(device, dtype=torch.long)

        model.zero_grad()
        outputs_start, outputs_end = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids,
        )
        loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)
        loss.backward()
        xm.optimizer_step(optimizer)
        # scheduler defaults to None, so it must not be stepped blindly.
        if scheduler is not None:
            scheduler.step()
        # Combine the loss across processes (via reduce_fn, defined elsewhere
        # in this file) for display only.
        print_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
        tk0.set_postfix(loss=print_loss.item())

In [None]:
def calculate_jaccard_score(
    original_tweet, 
    target_string, 
    sentiment_val, 
    idx_start, 
    idx_end, 
    offsets,
    verbose=False):
    """Rebuild the predicted text span and score it against the target.

    The span covers token positions ``idx_start`` through ``idx_end``
    (inclusive); when ``idx_end`` lies before ``idx_start`` only the start
    token is used. Each token contributes the slice of ``original_tweet``
    given by its ``(begin, end)`` pair in ``offsets``, and a single space is
    added whenever the following token's offset begins after the current one
    ends. Tweets with fewer than two whitespace-separated words are used
    whole. ``sentiment_val`` and ``verbose`` are accepted but not used.

    Returns:
        ``(score, text)`` where ``score`` is ``jaccard`` applied to the
        stripped target and the stripped predicted text, and ``text`` is the
        unstripped predicted text.
    """
    last_idx = max(idx_end, idx_start)

    pieces = []
    for pos in range(idx_start, last_idx + 1):
        begin, finish = offsets[pos][0], offsets[pos][1]
        pieces.append(original_tweet[begin:finish])
        has_next = (pos + 1) < len(offsets)
        if has_next and finish < offsets[pos + 1][0]:
            pieces.append(" ")

    if len(original_tweet.split()) < 2:
        filtered_output = original_tweet
    else:
        filtered_output = "".join(pieces)

    return jaccard(target_string.strip(), filtered_output.strip()), filtered_output


def eval_fn(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` and return the mean Jaccard score.

    For each batch the model's start/end logits are soft-maxed, the most
    probable start and end positions are picked per sample, and
    ``calculate_jaccard_score`` scores the resulting span against the
    reference selection. Per-batch means are accumulated in an
    ``AverageMeter`` weighted by batch size; its ``avg`` is returned. The loss
    is tracked in a second meter but is not part of the return value.
    """
    model.eval()
    losses = AverageMeter()
    jaccards = AverageMeter()

    tensor_keys = ("ids", "token_type_ids", "mask", "targets_start", "targets_end")

    with torch.no_grad():
        for batch in data_loader:
            sentiments = batch["sentiment"]
            selected_texts = batch["orig_selected"]
            tweets = batch["orig_tweet"]
            # Offsets are consumed on the host when slicing the tweet text.
            offsets = batch["offsets"].cpu().numpy()

            ids, token_type_ids, mask, targets_start, targets_end = [
                batch[key].to(device, dtype=torch.long) for key in tensor_keys
            ]

            start_logits, end_logits = model(
                ids=ids,
                mask=mask,
                token_type_ids=token_type_ids
            )
            loss = loss_fn(start_logits, end_logits, targets_start, targets_end)

            start_probs = torch.softmax(start_logits, dim=1).cpu().detach().numpy()
            end_probs = torch.softmax(end_logits, dim=1).cpu().detach().numpy()

            jaccard_scores = []
            for row, tweet in enumerate(tweets):
                score, _ = calculate_jaccard_score(
                    original_tweet=tweet,
                    target_string=selected_texts[row],
                    sentiment_val=sentiments[row],
                    idx_start=np.argmax(start_probs[row, :]),
                    idx_end=np.argmax(end_probs[row, :]),
                    offsets=offsets[row]
                )
                jaccard_scores.append(score)

            batch_size = ids.size(0)
            jaccards.update(np.mean(jaccard_scores), batch_size)
            losses.update(loss.item(), batch_size)

    return jaccards.avg

In [None]:
# Load the RoBERTa configuration and ask the encoder to return hidden states,
# which TweetModel.forward reads.
model_config = transformers.RobertaConfig.from_pretrained(config.ROBERTA_PATH)
model_config.output_hidden_states = True
MX = TweetModel(conf=model_config)

# Split the training file by fold: rows tagged with config.FOLD are held out
# for validation, every other row is used for training.
dfx = pd.read_csv(config.TRAINING_FILE)
in_valid_fold = dfx.kfold == config.FOLD

df_train = dfx[~in_valid_fold].reset_index(drop=True)
df_valid = dfx[in_valid_fold].reset_index(drop=True)

Training

In [None]:
def run():
    """Train and validate the tweet model in the current XLA process.

    Builds distributed samplers and data loaders for ``df_train`` and
    ``df_valid``, sets up AdamW (no weight decay on biases and LayerNorm
    parameters) with a linear schedule, then for ``config.EPOCHS`` epochs
    trains, evaluates, reduces the Jaccard score across processes with
    ``reduce_fn`` (defined elsewhere in this file) and saves the weights
    whenever the score improves.
    """
    device = xm.xla_device()
    model = MX.to(device)
    world_size = xm.xrt_world_size()

    train_dataset = TweetDataset(
        tweet=df_train.text.values,
        sentiment=df_train.sentiment.values,
        selected_text=df_train.selected_text.values
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=world_size,
      rank=xm.get_ordinal(),
      shuffle=True
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=2
    )

    valid_dataset = TweetDataset(
        tweet=df_valid.text.values,
        sentiment=df_valid.sentiment.values,
        selected_text=df_valid.selected_text.values
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
      valid_dataset,
      num_replicas=world_size,
      rank=xm.get_ordinal(),
      shuffle=False
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=1
    )

    # Parameters whose name contains one of these fragments get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = [
        "bias",
        "LayerNorm.bias",
        "LayerNorm.weight"
    ]
    optimizer_parameters = [
        {
            'params': [
                p for n, p in param_optimizer if not any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.001
        },
        {
            'params': [
                p for n, p in param_optimizer if any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.0
        },
    ]

    # Scheduler horizon: the data is split across world_size processes.
    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE / world_size * config.EPOCHS
    )
    # The learning rate is scaled by the number of participating processes.
    optimizer = AdamW(
        optimizer_parameters,
        lr=config.LEARNING_RATE * world_size
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    best_jac = 0
    # Used only as the progress-bar total inside train_fn.
    num_batches = int(len(df_train) / (config.TRAIN_BATCH_SIZE * world_size))

    xm.master_print("Training is Starting....")

    for epoch in range(config.EPOCHS):
        # Without set_epoch the shuffling sampler reuses the same ordering
        # in every epoch; it must be called before the loader is iterated.
        train_sampler.set_epoch(epoch)

        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(
            para_loader.per_device_loader(device),
            model,
            optimizer,
            device,
            num_batches,
            scheduler
        )

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        jac = eval_fn(
            para_loader.per_device_loader(device),
            model,
            device
        )
        jac = xm.mesh_reduce('jac_reduce', jac, reduce_fn)
        xm.master_print(f'Epoch={epoch}, Jaccard={jac}')
        if jac > best_jac:
            xm.master_print("Model Improved!!! Saving Model")
            xm.save(model.state_dict(), f"model_{config.FOLD}.bin")
            best_jac = jac

In [None]:
def _mp_fn(rank, flags):
    """Per-process entry point handed to ``xmp.spawn``.

    ``rank`` and ``flags`` are accepted to match the spawn callback signature
    but are not used; the function sets the default tensor type and then runs
    the training loop.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    run()

In [None]:
# Start training: fork 8 processes, each calling _mp_fn with its index
# followed by the (currently empty) FLAGS dict.
FLAGS = {}
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=8,
    start_method='fork',
)

Training is Starting....


HBox(children=(FloatProgress(value=0.0, description='Training', max=171.0, style=ProgressStyle(description_wid…