<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit → Notebook settings.
- Select TPU from the Hardware accelerator drop-down.

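Next, we'll check that we can connect to the TPU:

> **Note:** the code cells below currently contain MindSpore Lite C++ source (an ONNX-protobuf model importer) rather than the TPU connection check described here.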

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/anf_importer/import_from_protobuf.h"

#include <fcntl.h>
#include <unistd.h>

#include <fstream>
#include <map>
#include <memory>
#include <stack>
#include <unordered_map>
#include <vector>
#include "src/ops/primitive_c.h"
#include "frontend/operator/ops.h"
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "schema/inner/model_generated.h"
#include "securec/include/securec.h"
#include "src/ir/tensor.h"
#include "src/param_value_lite.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "utils/log_adapter.h"
#include "tools/common/protobuf_utils.h"

using string = std::string;
using int32 = int32_t;
using int64 = int64_t;
using uint64 = uint64_t;

namespace mindspore::lite {

// op_type of ONNX nodes that are imported as ValueNodes instead of CNodes.
static constexpr char kConstantValueNode[] = "Constant";

// Encoding of an attribute payload, selected from the "<form>:" prefix of its ref_attr_name.
enum ParseForm : int {
  FORM_PARSE_TYPE = 0,
  FORM_PARSE_SCALAR = 1,
  FORM_PARSE_TENSOR = 2,
};

// ref_attr_name prefix (without the trailing ':') -> payload encoding.
// Not declared const: code in this file may look entries up with operator[], which needs a mutable map.
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
    {"type", FORM_PARSE_TYPE},
    {"scalar", FORM_PARSE_SCALAR},
    {"tensor", FORM_PARSE_TENSOR},
};

// ONNX TensorProto element data type -> MindSpore TypeId. Types missing here are unsupported.
// Not declared const for the same reason as kParseTypeSwitchMap.
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
    {onnx::TensorProto_DataType_BOOL, kNumberTypeBool},
    {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
    {onnx::TensorProto_DataType_INT16, kNumberTypeInt16},
    {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
    {onnx::TensorProto_DataType_INT64, kNumberTypeInt64},
    {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
    {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16},
    {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
    {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64},
    {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
    {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
    {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
    {onnx::TensorProto_DataType_STRING, kObjectTypeString},
};

std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
                                                  const std::unordered_map<string, ValuePtr> &kv) {
  std::string str = attr_name;
  auto replace = [&](const string &orgStr, const string &newStr) {
    std::string::size_type pos(0);
    while ((pos = str.find(orgStr)) != std::string::npos) {
      str.replace(pos, orgStr.length(), newStr);
    }
    return str;
  };
  // remove "scalar:"
  str = replace("scalar:", "");
  // remove "Tuple"
  str = replace("Tuple", "");
  // remove "List"
  str = replace("List", "");
  std::stack<std::string> rules;
  std::stack<ValuePtr> value;
  int num = 0, count = 0;
  for (size_t i = 0; i < str.length(); i++) {
    if (str[i] == '[') {
      rules.push("[");
    } else if (str[i] == ']') {
      // rules
      std::vector<ValuePtr> vec;
      while (rules.top() != "[") {
        rules.pop();
        vec.push_back(value.top());
        value.pop();
      }
      // pop "["
      rules.pop();
      // make tuple for names
      std::string res = "dummy";
      // make tuple for values
      reverse(vec.begin(), vec.end());
      auto vt = std::make_shared<ValueTuple>(vec);
      if (rules.empty() && value.empty()) {
        return vt;
      }
      rules.push(res);
      value.push(vt);
    } else if (str[i] == ',') {
      continue;
    } else {
      count++;
      if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
        auto value_name = str.substr(i - count + 1, count);
        value.push(kv.at(value_name));
        rules.push(value_name);
        count = 0;
        num++;
      }
    }
  }
  return {};
}

std::shared_ptr<abstract::AbstractTuple>
ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, abstract::AbstractTensorPtr> &kv) {
  std::string str = attr_name;
  auto replace = [&](const string &orgStr, const string &newStr) {
    std::string::size_type pos(0);
    while ((pos = str.find(orgStr)) != std::string::npos) {
      str.replace(pos, orgStr.length(), newStr);
    }
    return str;
  };
  // remove "scalar:"
  str = replace("shape:", "");
  // remove "Tuple"
  str = replace("Tuple", "");
  // remove "List"
  str = replace("List", "");
  std::stack<std::string> rules;
  std::stack<abstract::AbstractBasePtr> value;
  int num = 0, count = 0;
  for (size_t i = 0; i < str.length(); i++) {
    if (str[i] == '[') {
      rules.push("[");
    } else if (str[i] == ']') {
      // rules
      std::vector<abstract::AbstractBasePtr> vec;
      while (rules.top() != "[") {
        rules.pop();
        vec.push_back(value.top());
        value.pop();
      }
      // pop "["
      rules.pop();
      // make tuple for names
      std::string res = "dummy";
      // make tuple for values
      reverse(vec.begin(), vec.end());
      auto vt = std::make_shared<abstract::AbstractTuple>(vec);
      if (rules.empty() && value.empty()) {
        return vt;
      }
      rules.push(res);
      value.push(vt);
    } else if (str[i] == ',') {
      continue;
    } else {
      count++;
      if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
        auto value_name = str.substr(i - count + 1, count);
        value.push(kv.at(value_name));
        rules.push(value_name);
        count = 0;
        num++;
      }
    }
  }
  return {};
}

// Defines `ValuePtr ParseAttrInScalar_<type>_<valuetype>(const onnx::TensorProto &)`.
// The generated function reads the only element of the TensorProto's `<type>_data` repeated field,
// converts it to `valuetype` and wraps it with MakeValue. If that field does not hold exactly one
// element, it logs an error and returns an empty ValuePtr.
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype)                                    \
  ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
    if (attr_tensor.type##_data_size() != 1) {                                            \
      MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!";                          \
      return {};                                                                          \
    }                                                                                     \
    return MakeValue<valuetype>(static_cast<valuetype>(attr_tensor.type##_data(0)));      \
  }

PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
PARSE_ONNXATTR_IN_SCALAR_FORM(string, string)
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32)
// BOOL scalars are read from int32_data and converted to bool.
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)

// Configures `node` as a graph parameter described by `value_proto`: name, element type and shape.
// If an initializer with the same name was recorded in default_para_map_, its raw_data is copied into a
// new buffer that becomes the parameter's default value (a ParamValueLite).
// The node is registered in anfnode_build_map_ under its name.
// Returns false (after logging) when the proto lacks a name, type or shape, when the element type is
// unsupported, or when the initializer data does not fit the tensor.
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
                                                         const onnx::ValueInfoProto &value_proto) {
  MS_EXCEPTION_IF_NULL(node);
  if (!value_proto.has_type() || !value_proto.has_name()) {
    MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
    return false;
  }
  node->set_name(value_proto.name());
  const auto &type_proto = value_proto.type();
  if (!type_proto.has_tensor_type()) {
    MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! ";
    return false;
  }
  const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
  if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) {
    MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! ";
    return false;
  }
  const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape();
  std::vector<int> shape;
  shape.reserve(tensor_shape.dim_size());
  for (int i = 0; i < tensor_shape.dim_size(); ++i) {
    // dim_value is 64-bit in the proto; shapes here are int.
    shape.push_back(static_cast<int>(tensor_shape.dim(i).dim_value()));
  }

  const auto type_iter = kDefaultValueSwitchMap.find(tensor_typeproto.elem_type());
  if (type_iter == kDefaultValueSwitchMap.end()) {
    MS_LOG(ERROR) << "onnx TypeProto_Tensor  elem_type is not support yet!";
    return false;
  }
  const TypeId type_id = type_iter->second;
  node->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape));

  const auto default_iter = default_para_map_.find(value_proto.name());
  if (default_iter != default_para_map_.end()) {
    // Bind by reference: initializers may be large weight tensors, so avoid copying the proto or its data.
    const std::string &initial_data = default_iter->second.raw_data();
    auto tensor_info = std::make_unique<tensor::Tensor>(type_id, shape);
    tensor_info->MallocData();
    auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
    MS_EXCEPTION_IF_NULL(tensor_data_buf);
    auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size());
    if (EOK != ret) {
      MS_LOG(ERROR) << "memcpy_s error";
      // tensor_info still owns the buffer, so it is released through the tensor's own deallocation path.
      // The previous code used `delete` on a buffer obtained from MallocData(), which is likely a
      // mismatched deallocation.
      return false;
    }
    // Hand the buffer over to the ParamValueLite: detach it so destroying tensor_info leaves it alone.
    tensor_info->SetData(nullptr);

    ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
    MS_EXCEPTION_IF_NULL(param_value);
    param_value->set_tensor_addr(tensor_data_buf);
    param_value->set_tensor_size(tensor_info->Size());
    param_value->set_tensor_type(tensor_info->data_type());
    param_value->set_tensor_shape(tensor_info->shape());
    node->set_default_param(param_value);
  }
  anfnode_build_map_[value_proto.name()] = node;
  return true;
}

// Creates the graph parameters of `outputFuncGraph` from `importProto`.
// Step 1 records every initializer by name in default_para_map_.
// Step 2 adds one parameter per graph input, built by BuildParameterForFuncGraph.
// Returns false (after logging) on an unnamed initializer or on the first parameter that fails to build.
bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
                                                       const onnx::GraphProto &importProto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  const int initializer_count = importProto.initializer_size();
  MS_LOG(INFO) << "Parameters had default paramerer size is: " << initializer_count;

  // Step 1: remember named initializers so the parameters built below can pick up default values.
  for (int idx = 0; idx < initializer_count; ++idx) {
    const auto &initializer = importProto.initializer(idx);
    if (!initializer.has_name()) {
      MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << idx;
      return false;
    }
    default_para_map_[initializer.name()] = initializer;
  }

  // Step 2: one graph parameter per graph input.
  const int input_count = importProto.input_size();
  MS_LOG(INFO) << "all parameters size: " << input_count;
  for (int idx = 0; idx < input_count; ++idx) {
    if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), importProto.input(idx))) {
      MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << idx;
      return false;
    }
  }
  return true;
}

// Stores on `prim`, under `attr_name`, the MindSpore type matching attr_tensor's ONNX data type.
// Returns false (after logging) if that data type has no entry in kDefaultValueSwitchMap.
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
                                                        const onnx::TensorProto &attr_tensor) {
  MS_EXCEPTION_IF_NULL(prim);
  const int attr_tensor_type = attr_tensor.data_type();
  const auto found = kDefaultValueSwitchMap.find(attr_tensor_type);
  if (found != kDefaultValueSwitchMap.end()) {
    prim->AddAttr(attr_name, TypeIdToType(found->second));
    return true;
  }
  MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
  return false;
}

// Converts a single-element scalar TensorProto into a Value, dispatching on its ONNX data type.
// Returns an empty ValuePtr (after logging) for unsupported data types. The per-type readers also
// return an empty ValuePtr when the tensor does not hold exactly one element.
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
  // ONNX element type -> reader for the matching typed repeated field of the TensorProto.
  // BOOL scalars are carried in int32_data, hence the int32 -> bool reader.
  using ScalarReader = ValuePtr (*)(const onnx::TensorProto &);
  static const std::unordered_map<int, ScalarReader> kScalarReaders{
      {onnx::TensorProto_DataType_STRING, ParseAttrInScalar_string_string},
      {onnx::TensorProto_DataType_INT32, ParseAttrInScalar_int32_int32},
      {onnx::TensorProto_DataType_INT64, ParseAttrInScalar_int64_int64},
      {onnx::TensorProto_DataType_UINT64, ParseAttrInScalar_uint64_uint64},
      {onnx::TensorProto_DataType_FLOAT, ParseAttrInScalar_float_float},
      {onnx::TensorProto_DataType_DOUBLE, ParseAttrInScalar_double_double},
      {onnx::TensorProto_DataType_BOOL, ParseAttrInScalar_int32_bool},
  };
  const int attr_tensor_type = attr_tensor.data_type();
  const auto reader = kScalarReaders.find(attr_tensor_type);
  if (reader == kScalarReaders.end()) {
    MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
    return {};
  }
  return reader->second(attr_tensor);
}

// Stores a tensor-form attribute on `prim` under `attr_name`.
// - Ranked tensors (dims_size() != 0) are copied from raw_data into a new tensor::Tensor.
// - Rank-0 DOUBLE, INT64 and BOOL tensors are stored as scalars. INT64 is narrowed to int32,
//   as before.
// - Other rank-0 types are skipped with a warning.
// Returns false (after logging) on an unsupported ranked type or when raw_data does not fit the
// destination.
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
                                                          const onnx::TensorProto &attr_tensor) {
  MS_EXCEPTION_IF_NULL(prim);
  const int attr_tensor_type = attr_tensor.data_type();
  const std::string &tensor_buf = attr_tensor.raw_data();

  if (attr_tensor.dims_size() != 0) {
    // Use find() rather than operator[] so an unsupported type is reported instead of being inserted
    // into the shared table.
    const auto type_iter = kDefaultValueSwitchMap.find(attr_tensor_type);
    if (type_iter == kDefaultValueSwitchMap.end()) {
      MS_LOG(ERROR) << "Obtain attr in tensor-form has not support input type: " << attr_tensor_type;
      return false;
    }
    std::vector<int> shape;
    shape.reserve(attr_tensor.dims_size());
    for (int i = 0; i < attr_tensor.dims_size(); ++i) {
      shape.push_back(static_cast<int>(attr_tensor.dims(i)));
    }
    tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(type_iter->second, shape);
    tensor_info->MallocData();
    auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
    if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()) != EOK) {
      MS_LOG(ERROR) << "memcpy_s error";
      return false;
    }
    prim->set_attr(attr_name, MakeValue(tensor_info));
    return true;
  }

  // Rank-0 tensor: raw_data holds one scalar. Each destination is sized by its own type, so memcpy_s
  // rejects a payload that does not fit instead of overrunning the local variable.
  auto copy_scalar = [&tensor_buf](auto *out) {
    return memcpy_s(out, sizeof(*out), tensor_buf.data(), tensor_buf.size()) == EOK;
  };
  switch (attr_tensor_type) {
    case onnx::TensorProto_DataType_DOUBLE: {
      double attr_value = 0.0;
      if (!copy_scalar(&attr_value)) {
        MS_LOG(ERROR) << "memcpy_s error";
        return false;
      }
      prim->set_attr(attr_name, MakeValue<double>(attr_value));
      break;
    }
    case onnx::TensorProto_DataType_INT64: {
      // Read the full 8-byte payload into an int64_t. The previous code copied it into an int32_t
      // while declaring an 8-byte destination, overflowing the local. The attr stays int32.
      int64_t raw_value = 0;
      if (!copy_scalar(&raw_value)) {
        MS_LOG(ERROR) << "memcpy_s error";
        return false;
      }
      prim->set_attr(attr_name, MakeValue<int32_t>(static_cast<int32_t>(raw_value)));
      break;
    }
    case onnx::TensorProto_DataType_BOOL: {
      bool attr_value = false;
      if (!copy_scalar(&attr_value)) {
        MS_LOG(ERROR) << "memcpy_s error";
        return false;
      }
      prim->set_attr(attr_name, MakeValue<bool>(attr_value));
      break;
    }
    default:
      // Previously ignored silently; still not fatal, but now visible.
      MS_LOG(WARNING) << "Obtain attr in tensor-form ignores scalar of type " << attr_tensor_type
                      << " for attr: " << attr_name;
      break;
  }
  return true;
}

// Decodes one ONNX attribute into an attribute of `prim`, named attr_proto.name().
// The encoding form ("type", "scalar" or "tensor") comes from the "<form>:" prefix of ref_attr_name.
// - Type and tensor forms are decoded from the first tensor only.
// - Scalar forms collect every tensor by name: a single value is stored directly, otherwise the values
//   are assembled into the nested tuple described by ref_attr_name.
// Returns false (after logging) on a missing or unknown form or on any decoding failure.
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
  MS_EXCEPTION_IF_NULL(prim);
  const std::string &attr_name = attr_proto.name();
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  // Keep only "<form>" from the "<form>:" prefix (the ':' is dropped).
  string type;
  std::size_t pos(0);
  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("type:").length() - 1);
  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
  }
  // find() instead of operator[]: operator[] would insert an unknown form into the shared table with
  // value FORM_PARSE_TYPE, silently treating it as type-form and making the error branch unreachable.
  const auto form_iter = kParseTypeSwitchMap.find(type);
  if (form_iter == kParseTypeSwitchMap.end()) {
    MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
    return false;
  }
  const ParseForm form = form_iter->second;

  std::unordered_map<std::string, ValuePtr> kv;
  for (int i = 0; i < attr_proto.tensors_size(); i++) {
    const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
    switch (form) {
      case FORM_PARSE_TYPE:
        return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
      case FORM_PARSE_TENSOR:
        return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
      case FORM_PARSE_SCALAR: {
        auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
        if (res == nullptr) {
          MS_LOG(ERROR) << "Parse scalar value failed for attr: " << attr_name;
          return false;
        }
        kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
        break;
      }
      default:
        MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
        return false;
    }
  }
  if (form == FORM_PARSE_SCALAR) {
    ValuePtr value;
    if (kv.size() == 1) {
      value = kv.begin()->second;
    } else {
      value = ParserScalarAttrValue(ref_attr_name, kv);
    }
    if (value == nullptr) {
      MS_LOG(ERROR) << "Parse scalar attr value failed for attr: " << attr_name;
      return false;
    }
    prim->AddAttr(attr_name, value);
  }
  return true;
}

// Builds a ValueNode holding a tensor::Tensor copied from attr_tensor.raw_data, with a matching
// AbstractTensor, and registers it in anfnode_build_map_ under `value_node_name`.
// Returns false (after logging) for an unsupported data type or when raw_data does not fit the tensor.
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
                                                          const onnx::TensorProto &attr_tensor) {
  const int attr_tensor_type = attr_tensor.data_type();
  // find() rather than operator[] so an unsupported type is reported instead of being inserted into the
  // shared table as an unknown type.
  const auto type_iter = kDefaultValueSwitchMap.find(attr_tensor_type);
  if (type_iter == kDefaultValueSwitchMap.end()) {
    MS_LOG(ERROR) << "Obtain ValueNode attr in tensor-form has not support input type: " << attr_tensor_type;
    return false;
  }
  const TypeId type_id = type_iter->second;
  std::vector<int> shape;
  shape.reserve(attr_tensor.dims_size());
  for (int i = 0; i < attr_tensor.dims_size(); ++i) {
    shape.push_back(static_cast<int>(attr_tensor.dims(i)));
  }
  tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(type_id, shape);
  tensor_info->MallocData();
  const std::string &tensor_buf = attr_tensor.raw_data();
  auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
  auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size());
  if (EOK != ret) {
    MS_LOG(ERROR) << "memcpy_s error";
    return false;
  }
  auto new_value_node = NewValueNode(MakeValue(tensor_info));
  MS_EXCEPTION_IF_NULL(new_value_node);
  new_value_node->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape));
  anfnode_build_map_[value_node_name] = new_value_node;
  return true;
}

// Builds a ValueNode whose value is the MindSpore type matching attr_tensor's ONNX data type, and
// registers it in anfnode_build_map_ under `value_node_name`. The node's abstract is an AbstractType
// over TypeType. Returns false (after logging) when the data type is unsupported.
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
                                                        const onnx::TensorProto &attr_tensor) {
  const int attr_tensor_type = attr_tensor.data_type();
  const auto found = kDefaultValueSwitchMap.find(attr_tensor_type);
  if (found == kDefaultValueSwitchMap.end()) {
    MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
    return false;
  }
  auto type_value_node = NewValueNode(TypeIdToType(found->second));
  type_value_node->set_abstract(std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()));
  anfnode_build_map_[value_node_name] = type_value_node;
  return true;
}

// Decodes a Constant node's attribute into a ValueNode registered under `value_node_name`.
// The encoding form ("type", "scalar" or "tensor") comes from the "<form>:" prefix of ref_attr_name.
// - Type and tensor forms are delegated to ObtainValueNodeIn{Type,Tensor}Form with the first tensor.
// - Scalar forms collect every tensor by name: a single value is used directly, otherwise the nested
//   tuple described by ref_attr_name is built.
// Returns false (after logging) on a missing or unknown form or on any decoding failure.
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
                                                       const onnx::AttributeProto &attr_proto) {
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  // Keep only "<form>" from the "<form>:" prefix (the ':' is dropped).
  string type;
  std::size_t pos(0);
  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("type:").length() - 1);
  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
  }
  // find() instead of operator[]: operator[] would insert an unknown form into the shared table with
  // value FORM_PARSE_TYPE, silently treating it as type-form.
  const auto form_iter = kParseTypeSwitchMap.find(type);
  if (form_iter == kParseTypeSwitchMap.end()) {
    MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
    return false;
  }
  const ParseForm form = form_iter->second;

  std::unordered_map<std::string, ValuePtr> kv;
  for (int i = 0; i < attr_proto.tensors_size(); i++) {
    const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
    switch (form) {
      case FORM_PARSE_TYPE:
        return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
      case FORM_PARSE_TENSOR:
        return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
      case FORM_PARSE_SCALAR: {
        auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
        if (res == nullptr) {
          MS_LOG(ERROR) << "Parse scalar value failed for ValueNode: " << value_node_name;
          return false;
        }
        kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
        break;
      }
      default:
        MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
        return false;
    }
  }

  if (form != FORM_PARSE_SCALAR) {
    return true;
  }
  ValuePtr value;
  if (kv.size() == 1) {
    value = kv.begin()->second;
  } else {
    value = ParserScalarAttrValue(ref_attr_name, kv);
  }
  // Guard before ToAbstract(): a failed parse yields nullptr, which was previously dereferenced.
  if (value == nullptr) {
    MS_LOG(ERROR) << "Parse scalar value failed for ValueNode: " << value_node_name;
    return false;
  }
  ValueNodePtr new_value_node = NewValueNode(value);
  new_value_node->set_abstract(value->ToAbstract());
  anfnode_build_map_[value_node_name] = new_value_node;
  return true;
}

// Builds the ValueNode for a Constant node. The value comes from the node's first attribute, and the
// node's first output name is used as the ValueNode's name.
// Returns false (after logging) when the node has no output or attribute, or the attribute has no
// ref_attr_name.
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
  // Indexing an empty protobuf repeated field is undefined, so check the sizes first.
  if (node_proto.output_size() == 0 || node_proto.attribute_size() == 0) {
    MS_LOG(ERROR) << "parse ValueNode need at least one output and one attribute";
    return false;
  }
  const std::string &value_node_name = node_proto.output(0);
  const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "parse ValueNode  don't have ref_attr_name";
    return false;
  }
  return GetAttrValueForValueNode(value_node_name, attr_proto);
}

// Builds one AbstractTensor per tensor in a "shape:" attribute, keyed by tensor name, from each
// tensor's data type and dims. Element types missing from kDefaultValueSwitchMap map to kTypeUnknown
// (with a warning).
std::unordered_map<std::string, abstract::AbstractTensorPtr>
AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
  std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
  for (int i = 0; i < attr_proto.tensors_size(); i++) {
    const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
    std::vector<int> shape_vec;
    shape_vec.reserve(attr_tensor.dims_size());
    for (int j = 0; j < attr_tensor.dims_size(); ++j) {
      shape_vec.push_back(static_cast<int>(attr_tensor.dims(j)));
    }
    // find() instead of operator[] so the shared table is not modified. An unknown type still yields
    // kTypeUnknown, the value-initialised TypeId that operator[] produced before.
    TypeId type_id = kTypeUnknown;
    const auto type_iter = kDefaultValueSwitchMap.find(attr_tensor.data_type());
    if (type_iter != kDefaultValueSwitchMap.end()) {
      type_id = type_iter->second;
    } else {
      MS_LOG(WARNING) << "Unsupported data type " << attr_tensor.data_type() << " in shape attr "
                      << attr_tensor.name();
    }
    auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape_vec);
    kv.insert(std::pair<string, abstract::AbstractTensorPtr>(attr_tensor.name(), abstract_tensor));
  }
  return kv;
}

// Builds a CNode for a non-Constant ONNX node and registers it under the node's first output name.
// - Attributes whose ref_attr_name contains "shape:" describe the node's output abstract.
// - All other attributes become attributes of the Primitive, which is converted to a lite PrimitiveC.
// - Inputs are resolved by name in anfnode_build_map_, so they must already be built.
// The output abstract is chosen as follows:
// - no shape attr: a tuple of the input abstracts;
// - one shape tensor: that AbstractTensor;
// - several: the nested tuple described by the shape ref_attr_name.
// Returns nullptr (after logging) on any failure.
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                                         const onnx::NodeProto &node_proto,
                                                         const schema::QuantType &quantType) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  if (!node_proto.has_op_type()) {
    MS_LOG(ERROR) << "Get CNode op_type failed!";
    return nullptr;
  }
  // Indexing an empty protobuf repeated field is undefined.
  if (node_proto.output_size() == 0) {
    MS_LOG(ERROR) << "CNode " << node_proto.op_type() << " has no output name!";
    return nullptr;
  }
  const std::string &node_name = node_proto.output(0);
  const std::string &fullname_with_scope = node_proto.domain();
  const std::string &node_type = node_proto.op_type();
  PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type);
  MS_EXCEPTION_IF_NULL(prim);
  prim->set_instance_name(node_type);

  std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
  string shape_ref_attr_name;
  for (int i = 0; i < node_proto.attribute_size(); ++i) {
    const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
    if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
      shape_ref_attr_name = attr_proto.ref_attr_name();
      kv = GetAbstractForCNode(attr_proto);
      continue;
    }
    if (!GetAttrValueForCNode(prim, attr_proto)) {
      MS_LOG(ERROR) << "Get CNode attr failed!";
      return nullptr;
    }
  }

  std::vector<AnfNodePtr> inputs;
  inputs.reserve(node_proto.input_size() + 1);
  for (int i = 0; i < node_proto.input_size(); ++i) {
    const std::string &input_name = node_proto.input(i);
    const auto input_iter = anfnode_build_map_.find(input_name);
    if (input_iter == anfnode_build_map_.end()) {
      MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
      return nullptr;
    }
    inputs.push_back(input_iter->second);
  }
  auto primitivec_ptr = PrimitiveC::UnPackFromPrimitive(*prim, inputs, quantType);
  if (primitivec_ptr == nullptr) {
    MS_LOG(ERROR) << "Create PrimitiveC return nullptr, " << prim->name();
    return nullptr;
  }
  // The primitive is always the CNode's first input.
  inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
  CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode_ptr);

  abstract::AbstractBasePtr node_abstract;
  if (kv.empty()) {
    AbstractBasePtrList elem;
    for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
      elem.push_back(cnode_ptr->input(index)->abstract());
    }
    node_abstract = std::make_shared<abstract::AbstractTuple>(elem);
  } else if (kv.size() == 1) {
    node_abstract = kv.begin()->second;
  } else {
    node_abstract = ParserAttrShape(shape_ref_attr_name, kv);
    if (node_abstract == nullptr) {
      MS_LOG(ERROR) << "Parse shape attr failed for CNode: " << node_name;
      return nullptr;
    }
  }
  cnode_ptr->set_abstract(node_abstract);

  cnode_ptr->set_fullname_with_scope(fullname_with_scope);
  anfnode_build_map_[node_name] = cnode_ptr;
  return cnode_ptr;
}

// Creates the graph's Return node.
// - Several graph outputs: each is looked up by name in anfnode_build_map_, packed into a MakeTuple,
//   and that tuple is returned.
// - One graph output: `cnode_ptr` is returned, and the Return node's abstract comes from the graph
//   output's declared elem_type and shape.
// Returns false (after logging) when the graph declares no output or an output name was never built.
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                                      const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  if (importProto.output_size() == 0) {
    MS_LOG(ERROR) << "onnx GraphProto has no output, can't build return node!";
    return false;
  }
  // Wraps a newly created lite primitive of `primitive_type` in a ValueNode.
  auto new_primitive_node = [](schema::PrimitiveType primitive_type) {
    auto primitive_t = std::make_unique<schema::PrimitiveT>();
    primitive_t->value.type = primitive_type;
    std::shared_ptr<PrimitiveC> primitive_c = std::make_shared<PrimitiveC>(primitive_t.release());
    return NewValueNode(primitive_c);
  };

  if (importProto.output_size() > 1) {
    std::vector<AnfNodePtr> tuple_inputs;
    tuple_inputs.push_back(new_primitive_node(schema::PrimitiveType_MakeTuple));
    AbstractBasePtrList elem;
    for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
      const std::string &out_tuple = importProto.output(out_size).name();
      // find() instead of operator[]: a missing output would otherwise be a null node dereferenced below.
      const auto node_iter = anfnode_build_map_.find(out_tuple);
      if (node_iter == anfnode_build_map_.end() || node_iter->second == nullptr) {
        MS_LOG(ERROR) << "graph output " << out_tuple << " can't find in nodes have parsed";
        return false;
      }
      tuple_inputs.push_back(node_iter->second);
      elem.push_back(node_iter->second->abstract());
    }
    auto maketuple_ptr = outputFuncGraph->NewCNode(tuple_inputs);
    MS_EXCEPTION_IF_NULL(maketuple_ptr);
    maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));

    std::vector<AnfNodePtr> return_inputs;
    return_inputs.push_back(new_primitive_node(schema::PrimitiveType_Return));
    return_inputs.push_back(maketuple_ptr);
    auto return_node = outputFuncGraph->NewCNode(return_inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    outputFuncGraph->set_return(return_node);
    MS_LOG(INFO) << "Construct funcgraph finined, all success.";
    return true;
  }

  const onnx::TypeProto &output_typeproto = importProto.output(0).type();
  const int output_type = output_typeproto.tensor_type().elem_type();
  const onnx::TensorShapeProto &output_dims = output_typeproto.tensor_type().shape();
  std::vector<int> output_shape;
  output_shape.reserve(output_dims.dim_size());
  for (int i = 0; i < output_dims.dim_size(); ++i) {
    output_shape.push_back(static_cast<int>(output_dims.dim(i).dim_value()));
  }
  // find() so the shared table is not modified; an unknown type still becomes kTypeUnknown as before.
  TypeId output_type_id = kTypeUnknown;
  const auto type_iter = kDefaultValueSwitchMap.find(output_type);
  if (type_iter != kDefaultValueSwitchMap.end()) {
    output_type_id = type_iter->second;
  } else {
    MS_LOG(WARNING) << "graph output elem_type " << output_type << " is not supported, use unknown type";
  }
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(output_type_id), output_shape);

  std::vector<AnfNodePtr> return_inputs;
  return_inputs.push_back(new_primitive_node(schema::PrimitiveType_Return));
  return_inputs.push_back(cnode_ptr);
  auto return_node = outputFuncGraph->NewCNode(return_inputs);
  MS_EXCEPTION_IF_NULL(return_node);
  return_node->set_abstract(abstract_tensor);
  outputFuncGraph->set_return(return_node);
  MS_LOG(INFO) << "Construct funcgraph finined, all success!";
  return true;
}

// Imports every node of `importProto` in order, then builds the graph's Return node.
// Constant nodes become ValueNodes; all other nodes become CNodes.
// Returns false (after logging) if any node fails to build, no CNode was produced, or the Return node
// cannot be built.
bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
                                                  const onnx::GraphProto &importProto,
                                                  const schema::QuantType &quantType) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
  CNodePtr cnode_ptr = nullptr;
  for (int i = 0; i < importProto.node_size(); ++i) {
    const onnx::NodeProto &node_proto = importProto.node(i);
    const std::string &node_type = node_proto.op_type();
    if (node_type == kConstantValueNode) {
      if (!BuildValueNodeForFuncGraph(node_proto)) {
        MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
        return false;
      }
      continue;
    }
    cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType);
    if (cnode_ptr == nullptr) {
      MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
      return false;
    }
  }

  // BuildReturnForFuncGraph requires the last built CNode; report its absence instead of letting the
  // callee's null check throw.
  if (cnode_ptr == nullptr) {
    MS_LOG(ERROR) << "No CNode was built for funcgraph, can't build return node!";
    return false;
  }
  // Previously the result was ignored, so a failed Return node still reported success.
  if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
    MS_LOG(ERROR) << "Build return node for funcgraph failed!";
    return false;
  }
  return true;
}

// Fills `outputFuncGraph` from `importProto`: copies the graph name into the debug info, then imports
// parameters followed by nodes. Parameters go first because nodes resolve their inputs by name in
// anfnode_build_map_. A missing graph name is logged but not fatal.
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
                                             const schema::QuantType &quantType) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  GraphDebugInfoPtr graph_debug_info = outputFuncGraph->debug_info();
  MS_EXCEPTION_IF_NULL(graph_debug_info);
  if (!importProto.has_name()) {
    MS_LOG(ERROR) << "FuncGraph under converting has not name!";
  } else {
    graph_debug_info->set_name(importProto.name());
  }
  return ImportParametersForGraph(outputFuncGraph, importProto) &&
         ImportNodesForGraph(outputFuncGraph, importProto, quantType);
}

// Records producer_name, model_version and ir_version from `model_proto`, in that order. Each field is
// required. Values are stored as they are validated, so a failure leaves the later members unchanged.
// Returns false (after logging) at the first missing field.
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
  if (model_proto.has_producer_name()) {
    producer_name_ = model_proto.producer_name();
  } else {
    MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
    return false;
  }

  if (model_proto.has_model_version()) {
    model_version_ = model_proto.model_version();
  } else {
    MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
    return false;
  }

  if (model_proto.has_ir_version()) {
    ir_version_ = model_proto.ir_version();
  } else {
    MS_LOG(ERROR) << "Parse model version from pb file failed!";
    return false;
  }
  return true;
}

// Converts onnx_model_ into a new FuncGraph and stores it in func_graph_.
// Returns RET_OK on success. On failure func_graph_ is reset to nullptr and
// RET_ERROR is returned. A failure to read the model configure info is logged
// but does not stop the import.
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
  // ReadOnnxFromBinary() returns nullptr when the file cannot be read; reject a
  // null model here instead of dereferencing it below.
  if (onnx_model_ == nullptr) {
    MS_LOG(ERROR) << "onnx model is nullptr, nothing to import!";
    func_graph_ = nullptr;
    return RET_ERROR;
  }
  FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
  MS_EXCEPTION_IF_NULL(dstGraph);
  if (!ParseModelConfigureInfo(*onnx_model_)) {
    MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
  }
  const onnx::GraphProto &graphBuild = onnx_model_->graph();
  if (!BuildFuncGraph(dstGraph, graphBuild, quantType)) {
    MS_LOG(ERROR) << "Build funcgraph failed!";
    func_graph_ = nullptr;
    return RET_ERROR;
  }
  func_graph_ = dstGraph;
  MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
  return RET_OK;
}

// Reads a binary protobuf model from model_path.
// Returns a heap-allocated ModelProto that the caller owns (and must delete), or
// nullptr if the file could not be read. The temporary is held in a unique_ptr so
// it is freed on the failure path instead of leaking.
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
  auto onnx_model = std::make_unique<onnx::ModelProto>();
  if (ReadProtoFromBinaryFile(model_path.c_str(), onnx_model.get()) != RET_OK) {
    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
    return nullptr;
  }
  return onnx_model.release();
}

// Returns func_graph_: the graph passed to the constructor until Import() runs,
// afterwards the graph built by Import() (nullptr if building it failed).
FuncGraphPtr AnfImporterFromProtobuf::GetResult() {
  return func_graph_;
}
}  // namespace mindspore::lite


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_

#include <map>
#include <string>
#include <unordered_map>
#include <utility>

#include "include/errorcode.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "tools/anf_importer/anf_importer.h"
#include "abstract/abstract_value.h"

namespace mindspore::lite {
// Importer that converts an onnx::ModelProto (e.g. one returned by
// ReadOnnxFromBinary) into a FuncGraph.
class AnfImporterFromProtobuf : public AnfImporter {
 public:
  // onnx_model: model to convert. It is stored as a raw pointer and is not deleted
  //   by this class (the destructor is defaulted), so the caller keeps ownership.
  // func_graph: the value GetResult() returns until Import() replaces it.
  explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph)
      : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}

  ~AnfImporterFromProtobuf() override = default;

  // Reads a binary protobuf model from model_path. Returns a heap-allocated model
  // that the caller must delete, or nullptr on failure.
  static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path);

  FuncGraphPtr GetResult() override;

  // Converts onnx_model_ into a FuncGraph; returns RET_OK or RET_ERROR.
  int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override;

 private:
  // These AnfImporter hooks are stubbed to return RET_ERROR; this importer's
  // conversion is implemented in Import() instead.
  int ConverterConstTensor() override { return RET_ERROR; }
  int ConverterCNode() override { return RET_ERROR; }
  int AddReturnCNode() override { return RET_ERROR; }
  bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
  bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
                      const schema::QuantType &quantType);
  bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
  bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
                           const schema::QuantType &quantType);
  bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
  CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
                                  const schema::QuantType &quantType);
  bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
                               const CNodePtr &cnode_ptr);
  bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
  bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
                                 const onnx::TensorProto &attr_tensor);
  ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
  bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
                                   const onnx::TensorProto &attr_tensor);
  bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
  bool ObtainValueNodeInTensorForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor);
  bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto);
  bool ObtainValueNodeInTypeForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor);
  std::unordered_map<std::string, abstract::AbstractTensorPtr> GetAbstractForCNode(
      const onnx::AttributeProto &attr_proto);

 private:
  std::string producer_name_;
  int model_version_{};
  int ir_version_{};
  // NOTE(review): presumably maps ONNX value names to the ANF nodes built for them — confirm in the .cpp.
  std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
  std::map<std::string, onnx::TensorProto> default_para_map_;
  onnx::ModelProto *onnx_model_;  // Not owned; never deleted by this class.
  FuncGraphPtr func_graph_;
};
}  // namespace mindspore::lite

#endif  // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_


In [None]:
from transformers import *
import tokenizers

In [None]:
# Create the local directory that the next cell saves the pretrained roberta-base files into.
!mkdir -p ./input/roberta-base

In [None]:
# Download roberta-base once and save its tokenizer vocabulary, weights and config
# into save_path (the same directory as config.ROBERTA_PATH in the next cell).
# NOTE(review): save_vocabulary presumably writes vocab.json / merges.txt, which
# config.TOKENIZER reads — confirm for the installed transformers version.
# NOTE: the global `config` bound here is shadowed by `class config` in the next cell.
save_path = './input/roberta-base'
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
config = RobertaConfig.from_pretrained('roberta-base')
tokenizer.save_vocabulary(save_path)
model.save_pretrained(save_path)
config.save_pretrained(save_path)

In [None]:
class config:
    """Training hyper-parameters and paths shared by the cells below."""
    FOLD = 0  # rows with kfold == FOLD form the validation split
    LEARNING_RATE = 0.2 * 3e-5  # base LR; run() multiplies it by the replica count
    MAX_LEN = 192  # process_data right-pads to this length (it does not truncate)
    TRAIN_BATCH_SIZE = 16  # per-process DataLoader batch size
    VALID_BATCH_SIZE = 8  # per-process DataLoader batch size
    EPOCHS = 3
    TRAINING_FILE = "./tweet-sentiment/train_folds.csv"
    ROBERTA_PATH = "./input/roberta-base"
    # Built when the class body executes, so vocab.json and merges.txt must
    # already exist under ROBERTA_PATH at that point.
    TOKENIZER = tokenizers.ByteLevelBPETokenizer(
        vocab_file=f"{ROBERTA_PATH}/vocab.json", 
        merges_file=f"{ROBERTA_PATH}/merges.txt", 
        lowercase=True,
        add_prefix_space=True
    )

# Data processing

In [None]:
def process_data(tweet, selected_text, sentiment, tokenizer, max_len):
    """Tokenize one tweet and turn its selected span into token-level targets.

    The tweet and the selected text are whitespace-normalized and given a
    leading space. The first occurrence of the selected text in the tweet is
    marked character by character. Every token whose character offsets overlap
    the marked characters becomes part of the target span.

    Args:
        tweet: raw tweet text (converted with ``str``).
        selected_text: raw target span (converted with ``str``).
        sentiment: one of 'positive', 'negative', 'neutral'.
        tokenizer: object whose ``encode(text)`` returns an encoding with
            ``.ids`` and ``.offsets``.
        max_len: length the sequences are right-padded to. Longer sequences
            are NOT truncated.

    Returns:
        dict with keys 'ids', 'mask', 'token_type_ids', 'targets_start',
        'targets_end', 'orig_tweet', 'orig_selected', 'sentiment' and
        'offsets'. The target indices are shifted by 4 to account for the
        tokens prepended before the tweet.

    Raises:
        ValueError: if the selected text is empty or cannot be located in the
            tweet, so no token can be marked as a target.
        KeyError: if ``sentiment`` is not one of the three known labels.
    """
    tweet = " " + " ".join(str(tweet).split())
    selected_text = " " + " ".join(str(selected_text).split())

    # Search for the span without its leading space. str.find returns the first
    # match, exactly like scanning candidate start characters in order.
    target = selected_text[1:]
    char_targets = [0] * len(tweet)
    idx0 = tweet.find(target) if target else -1
    if idx0 != -1:
        for ct in range(idx0, idx0 + len(target)):
            char_targets[ct] = 1

    tok_tweet = tokenizer.encode(tweet)
    input_ids_orig = tok_tweet.ids
    tweet_offsets = tok_tweet.offsets

    # A token belongs to the span if it covers at least one target character.
    target_idx = [
        j
        for j, (offset1, offset2) in enumerate(tweet_offsets)
        if sum(char_targets[offset1:offset2]) > 0
    ]
    if not target_idx:
        raise ValueError(
            f"selected_text {selected_text!r} could not be located in tweet {tweet!r}"
        )

    targets_start = target_idx[0]
    targets_end = target_idx[-1]

    # NOTE(review): presumably the BPE ids of " positive" / " negative" /
    # " neutral" in the roberta-base vocabulary — confirm against vocab.json.
    sentiment_id = {
        'positive': 1313,
        'negative': 2430,
        'neutral': 7974
    }

    # Layout: [0, sentiment, 2, 2] + tweet tokens + [2]. The ids 0 and 2 are
    # presumably RoBERTa's <s> and </s>.
    input_ids = [0] + [sentiment_id[sentiment]] + [2] + [2] + input_ids_orig + [2]
    token_type_ids = [0, 0, 0, 0] + [0] * (len(input_ids_orig) + 1)
    mask = [1] * len(token_type_ids)
    tweet_offsets = [(0, 0)] * 4 + tweet_offsets + [(0, 0)]
    targets_start += 4
    targets_end += 4

    # Right-pad with id 1 (presumably RoBERTa's <pad>) and a zero attention mask.
    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + ([1] * padding_length)
        mask = mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([0] * padding_length)
        tweet_offsets = tweet_offsets + ([(0, 0)] * padding_length)

    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'targets_start': targets_start,
        'targets_end': targets_end,
        'orig_tweet': tweet,
        'orig_selected': selected_text,
        'sentiment': sentiment,
        'offsets': tweet_offsets
    }

# Data loader

In [None]:
class TweetDataset:
    """Map-style dataset of tokenized tweets with their selected-span targets.

    Each item is built by ``process_data`` with ``config.TOKENIZER`` and
    ``config.MAX_LEN``. The numeric fields are returned as ``torch.long``
    tensors; the raw strings are passed through unchanged.
    """

    def __init__(self, tweet, sentiment, selected_text):
        self.tweet, self.sentiment, self.selected_text = tweet, sentiment, selected_text
        self.tokenizer = config.TOKENIZER
        self.max_len = config.MAX_LEN

    def __len__(self):
        return len(self.tweet)

    def __getitem__(self, item):
        sample = process_data(
            self.tweet[item],
            self.selected_text[item],
            self.sentiment[item],
            self.tokenizer,
            self.max_len,
        )

        def as_long(key):
            return torch.tensor(sample[key], dtype=torch.long)

        return {
            'ids': as_long('ids'),
            'mask': as_long('mask'),
            'token_type_ids': as_long('token_type_ids'),
            'targets_start': as_long('targets_start'),
            'targets_end': as_long('targets_end'),
            'orig_tweet': sample["orig_tweet"],
            'orig_selected': sample["orig_selected"],
            'sentiment': sample["sentiment"],
            'offsets': as_long('offsets'),
        }


In [None]:
class TweetModel(transformers.BertPreTrainedModel):
    """RoBERTa backbone with a linear head that predicts span start/end logits.

    The last two hidden-state layers are concatenated, passed through dropout
    and a linear layer that produces one start logit and one end logit per token.
    """

    def __init__(self, conf):
        super().__init__(conf)
        self.roberta = transformers.RobertaModel.from_pretrained(config.ROBERTA_PATH, config=conf)
        self.drop_out = nn.Dropout(0.1)
        # Two concatenated hidden-state layers -> 2 * hidden_size features
        # (768 * 2 for roberta-base); read from the config instead of hard-coding.
        self.l0 = nn.Linear(conf.hidden_size * 2, 2)
        torch.nn.init.normal_(self.l0.weight, std=0.02)

    def forward(self, ids, mask, token_type_ids):
        """Return ``(start_logits, end_logits)``, one logit per input token.

        NOTE(review): assumes ``conf.output_hidden_states`` is True and that the
        backbone returns a 3-tuple whose last element holds the per-layer hidden
        states (tuple-style output of older transformers versions) — confirm for
        the installed version.
        """
        _, _, out = self.roberta(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids
        )

        # Concatenate the last two layers along the feature dimension.
        out = torch.cat((out[-1], out[-2]), dim=-1)
        out = self.drop_out(out)
        logits = self.l0(out)

        start_logits, end_logits = logits.split(1, dim=-1)

        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        return start_logits, end_logits

In [None]:
def loss_fn(start_logits, end_logits, start_positions, end_positions):
    """Sum of the cross-entropy losses for the start and end position predictions."""
    criterion = nn.CrossEntropyLoss()
    pairs = ((start_logits, start_positions), (end_logits, end_positions))
    start_term, end_term = (criterion(logits, labels) for logits, labels in pairs)
    return start_term + end_term

In [None]:
def train_fn(data_loader, model, optimizer, device, num_batches, scheduler=None):
    """Run one training epoch on an XLA device.

    Args:
        data_loader: iterable of batches as produced by ``TweetDataset``.
        model: model returning ``(start_logits, end_logits)``.
        optimizer: optimizer stepped through ``xm.optimizer_step``.
        device: XLA device the batch tensors are moved to.
        num_batches: expected batch count, used only as the progress-bar total.
        scheduler: optional LR scheduler stepped once per batch; skipped if None.
    """
    model.train()
    tk0 = tqdm(data_loader, total=num_batches, desc="Training", disable=not xm.is_master_ordinal())
    for d in tk0:
        ids = d["ids"].to(device, dtype=torch.long)
        token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
        mask = d["mask"].to(device, dtype=torch.long)
        targets_start = d["targets_start"].to(device, dtype=torch.long)
        targets_end = d["targets_end"].to(device, dtype=torch.long)

        model.zero_grad()
        outputs_start, outputs_end = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids,
        )
        loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)
        loss.backward()
        xm.optimizer_step(optimizer)
        # scheduler defaults to None, so only step it when one was supplied.
        if scheduler is not None:
            scheduler.step()
        # NOTE(review): mesh_reduce runs on every step, which synchronizes all
        # replicas just to update the progress bar.
        print_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
        tk0.set_postfix(loss=print_loss.item())

In [None]:
def calculate_jaccard_score(
    original_tweet,
    target_string,
    sentiment_val,
    idx_start,
    idx_end,
    offsets,
    verbose=False):
    """Decode a predicted token span back to text and score it against the target.

    The characters of tokens ``idx_start..idx_end`` (inclusive) are joined, with a
    space inserted wherever the next token's offset starts after the current
    token ends. An end index before the start is clamped to the start. Tweets with
    fewer than two words are predicted verbatim. ``sentiment_val`` and ``verbose``
    are accepted but unused.

    Returns:
        ``(jaccard_score, predicted_text)``.
    """
    if idx_end < idx_start:
        idx_end = idx_start

    pieces = []
    for ix in range(idx_start, idx_end + 1):
        char_start, char_end = offsets[ix][0], offsets[ix][1]
        pieces.append(original_tweet[char_start:char_end])
        has_next = (ix + 1) < len(offsets)
        if has_next and char_end < offsets[ix + 1][0]:
            pieces.append(" ")
    filtered_output = "".join(pieces)

    if len(original_tweet.split()) < 2:
        filtered_output = original_tweet

    return jaccard(target_string.strip(), filtered_output.strip()), filtered_output


def eval_fn(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` and return its Jaccard average.

    For each batch, the argmax of the softmaxed start/end logits is decoded with
    ``calculate_jaccard_score``. The batch mean is fed to an ``AverageMeter``
    weighted by batch size, and that meter's ``avg`` is returned. The loss is
    tracked in a second meter that is not returned.
    """
    model.eval()
    losses = AverageMeter()
    jaccards = AverageMeter()

    with torch.no_grad():
        for batch in data_loader:
            sentiments = batch["sentiment"]
            selected = batch["orig_selected"]
            tweets = batch["orig_tweet"]
            offsets = batch["offsets"].cpu().numpy()

            ids = batch["ids"].to(device, dtype=torch.long)
            token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
            mask = batch["mask"].to(device, dtype=torch.long)
            targets_start = batch["targets_start"].to(device, dtype=torch.long)
            targets_end = batch["targets_end"].to(device, dtype=torch.long)

            start_logits, end_logits = model(
                ids=ids,
                mask=mask,
                token_type_ids=token_type_ids
            )
            loss = loss_fn(start_logits, end_logits, targets_start, targets_end)
            start_probs = torch.softmax(start_logits, dim=1).cpu().detach().numpy()
            end_probs = torch.softmax(end_logits, dim=1).cpu().detach().numpy()

            batch_scores = []
            for row, tweet in enumerate(tweets):
                score, _ = calculate_jaccard_score(
                    original_tweet=tweet,
                    target_string=selected[row],
                    sentiment_val=sentiments[row],
                    idx_start=np.argmax(start_probs[row, :]),
                    idx_end=np.argmax(end_probs[row, :]),
                    offsets=offsets[row]
                )
                batch_scores.append(score)

            batch_size = ids.size(0)
            jaccards.update(np.mean(batch_scores), batch_size)
            losses.update(loss.item(), batch_size)

    return jaccards.avg

In [None]:
# RoBERTa config with hidden states enabled: TweetModel.forward unpacks the
# hidden states from the backbone output.
model_config = transformers.RobertaConfig.from_pretrained(config.ROBERTA_PATH)
model_config.output_hidden_states = True
# Built once at module level, before xmp.spawn(start_method='fork') launches the
# workers; each worker's run() moves this instance onto its own XLA device.
MX = TweetModel(conf=model_config)

dfx = pd.read_csv(config.TRAINING_FILE)

# Rows whose kfold equals config.FOLD are held out for validation.
df_train = dfx[dfx.kfold != config.FOLD].reset_index(drop=True)
df_valid = dfx[dfx.kfold == config.FOLD].reset_index(drop=True)

# Training

In [None]:
def run():
    """Train and validate on one XLA device; launched once per process by ``_mp_fn``.

    Builds distributed train/valid loaders over ``df_train``/``df_valid``, trains
    for ``config.EPOCHS`` epochs with AdamW and a linear LR schedule, and saves
    the weights whenever the replica-reduced validation Jaccard improves.

    Returns:
        The best reduced validation Jaccard seen.
    """
    device = xm.xla_device()
    model = MX.to(device)

    train_dataset = TweetDataset(
        tweet=df_train.text.values,
        sentiment=df_train.sentiment.values,
        selected_text=df_train.selected_text.values
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=2
    )

    valid_dataset = TweetDataset(
        tweet=df_valid.text.values,
        sentiment=df_valid.sentiment.values,
        selected_text=df_valid.selected_text.values
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
      valid_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=False
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=1
    )

    # Biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = [
        "bias",
        "LayerNorm.bias",
        "LayerNorm.weight"
    ]
    optimizer_parameters = [
        {
            'params': [
                p for n, p in param_optimizer if not any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.001
        },
        {
            'params': [
                p for n, p in param_optimizer if any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.0
        },
    ]
    # Each replica sees roughly 1/world_size of the data, so it takes that share of
    # the total steps; the LR is scaled up by world_size to compensate.
    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE / xm.xrt_world_size() * config.EPOCHS
    )
    optimizer = AdamW(
        optimizer_parameters,
        lr=config.LEARNING_RATE * xm.xrt_world_size()
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    best_jac = 0
    num_batches = int(len(df_train) / (config.TRAIN_BATCH_SIZE * xm.xrt_world_size()))

    xm.master_print("Training is Starting....")

    for epoch in range(config.EPOCHS):
        # DistributedSampler derives its shuffle from the epoch number; without
        # set_epoch every epoch would replay the same ordering.
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(
            para_loader.per_device_loader(device),
            model,
            optimizer,
            device,
            num_batches,
            scheduler
        )

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        jac = eval_fn(
            para_loader.per_device_loader(device),
            model,
            device
        )
        jac = xm.mesh_reduce('jac_reduce', jac, reduce_fn)
        xm.master_print(f'Epoch={epoch}, Jaccard={jac}')
        if jac > best_jac:
            xm.master_print("Model Improved!!! Saving Model")
            xm.save(model.state_dict(), f"model_{config.FOLD}.bin")
            best_jac = jac

    return best_jac

In [None]:
def _mp_fn(rank, flags):
    """Per-process entry point handed to ``xmp.spawn``.

    ``rank`` and ``flags`` are accepted to match the spawn call but are unused.
    The process sets the default tensor type to ``torch.FloatTensor`` and then
    trains via ``run()``.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    run()

In [None]:
# Launch _mp_fn in 8 processes (presumably one per TPU core); FLAGS is passed
# through but unused. start_method='fork' lets the workers inherit module-level
# globals such as MX, df_train and df_valid.
FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

Training is Starting....


HBox(children=(FloatProgress(value=0.0, description='Training', max=171.0, style=ProgressStyle(description_wid…