<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <vector>
#include <memory>
#include "src/gllo/common/gllo_utils.h"
#include "src/ir/primitive_t_value.h"
#include "frontend/operator/ops.h"

// Shared-ownership handle to a lite::PrimitiveTValue.
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {
namespace{
// Position, inside a CNode's input list, of the input that is tested against a primitive.
constexpr auto kAnfPrimitiveIndex = 0;
// Tells whether `node` is a CNode whose input at kAnfPrimitiveIndex is the
// primitive `primitive_type`. Any node that is not a CNode yields false.
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    auto as_cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(as_cnode);
    return IsPrimitive(as_cnode->input(kAnfPrimitiveIndex), primitive_type);
  }
  return false;
}

// Reports whether `node` should be treated as a real kernel.
// A CNode is considered virtual (not real) when its first input is one of the
// ImageSummary, ScalarSummary, TensorSummary, HistogramSummary, MakeTuple,
// StateSetItem, Depend, TupleGetItem, ControlDepend, Return or Partial primitives.
// A CNode with an empty input list is reported through MS_LOG(EXCEPTION).
bool IsRealKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // Nodes that are not CNodes (parameters, value nodes, ...) are reported as real here;
  // callers that want to exclude them have to test isa<CNode>() themselves.
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    // Stream-style logging does not expand printf placeholders, so the node
    // description is appended directly instead of going through "%s".
    MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << node->DebugString();
  }
  const auto &input = cnode->inputs()[0];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
      IsPrimitive(input, prim::kPrimTensorSummary) ||
      IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
      IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
      IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
      IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  return !is_virtual_node;
}

// Wraps a scalar (int, float, bool) or a ValuePtr held by `sexp` into a value node.
// Returns nullptr when `sexp` holds none of those types.
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  ValueNodePtr wrapped = nullptr;
  if (utils::isa<int>(sexp)) {
    wrapped = NewValueNode(utils::cast<int>(sexp));
  } else if (utils::isa<float>(sexp)) {
    wrapped = NewValueNode(utils::cast<float>(sexp));
  } else if (utils::isa<bool>(sexp)) {
    wrapped = NewValueNode(utils::cast<bool>(sexp));
  } else if (utils::isa<ValuePtr>(sexp)) {
    wrapped = NewValueNode(utils::cast<ValuePtr>(sexp));
  }
  return wrapped;
}

// Builds a CNode from `input_nodes`, owned either by a FuncGraph or by a Var,
// depending on what `graph` holds. Returns nullptr for any other kind of `graph`.
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  CNodePtr created = nullptr;
  if (utils::isa<FuncGraphPtr>(graph)) {
    created = std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  } else if (utils::isa<VarPtr>(graph)) {
    created = std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  }
  return created;
}

// Builds a VarNode from the Var held by `sexp`.
// When `graph` holds a Var the node gets no func graph; when it holds a FuncGraph
// the node is attached to that graph. Any other `graph` is logged as an error
// and nullptr is returned.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  const bool graph_is_var = utils::isa<VarPtr>(graph);
  if (!graph_is_var && !utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
    return nullptr;
  }
  if (graph_is_var) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}

// Converts a vector s-expression into a CNode whose inputs are the converted elements.
// In multigraph mode with a Var as `graph`, every element is converted against a
// fresh Var named "G" and the resulting CNode is bound to the Var in `graph`.
// Otherwise the elements are converted against `graph` itself.
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  const auto &elements = utils::cast<VectorRef>(sexp);
  std::vector<AnfNodePtr> children;
  if (!(multigraph && utils::isa<VarPtr>(graph))) {
    for (auto &element : elements) {
      children.push_back(SexpToNode(element, graph, primitive_vars, multigraph));
    }
    return CreateCNodeWithGraph(children, graph);
  }

  for (auto &element : elements) {
    children.push_back(SexpToNode(element, std::make_shared<Var>("G"), primitive_vars, true));
  }
  VarPtr owner_var = utils::cast<VarPtr>(graph);
  return std::make_shared<CNode>(children, owner_var);
}
}  // namespace

bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);

      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);

      return a_prim->name() == b_prim->name();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      if (utils::isa<lite::PrimitiveTValue>(a_value_ptr) && utils::isa<lite::PrimitiveTValue>(b_value_ptr)) {
        auto a_obj = (lite::PrimitiveTValue *)(a_value_ptr.get());
        auto b_obj = (lite::PrimitiveTValue *)(b_value_ptr.get());
        return (*a_obj) == (*b_obj);
      } else {
        return (*a_value_ptr) == (*b_value_ptr);
      }
    }
  }
  if (a.m_ptr->isa<lite::PrimitiveTValue>() && b.m_ptr->isa<lite::PrimitiveTValue>()) {
    auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>();
    auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>();
    return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
  }

  return a == b;
}

// Type-level comparison of two references: any two CNodes are considered to have
// the same type; otherwise the dynamic types reported by type() must be equal.
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  const bool both_are_cnodes = utils::isa<CNode>(a) && utils::isa<CNode>(b);
  return both_are_cnodes || a.type() == b.type();
}

// Converts an s-expression into an ANF node.
//  - VectorRef: converted element-wise into a CNode (see HandleSexpVector).
//  - Var with a primitive: recorded in `primitive_vars` and turned into a value
//    node of that primitive.
//  - Var without a primitive: turned into a VarNode.
//  - AnfNodePtr: returned unchanged.
//  - Anything else: wrapped into a value node; a value that cannot be wrapped is
//    reported through MS_LOG(EXCEPTION).
// `primitive_vars` must not be null.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  MS_EXCEPTION_IF_NULL(primitive_vars);
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto pattern_var = utils::cast<VarPtr>(sexp);
    MS_EXCEPTION_IF_NULL(pattern_var);
    if (!pattern_var->primitive()) {
      return CreateVarNodeWithSexp(sexp, graph);
    }
    (*primitive_vars)[pattern_var->primitive()] = pattern_var;
    return NewValueNode(pattern_var->primitive());
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto literal_node = CreateValueNodeWithSexp(sexp);
  if (literal_node == nullptr) {
    MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  }
  return literal_node;
}


// Tells whether `node` is a CNode that counts as a real kernel.
// Parameters and value nodes are never real CNode kernels; a Return node always
// is; every other CNode is classified by IsRealKernel.
bool IsRealCNodeKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  return CheckPrimitiveType(node, prim::kPrimReturn) || IsRealKernel(node);
}
// Tells whether `node` is a graph kernel: a real CNode kernel whose first input
// is a FuncGraph value node carrying the FUNC_GRAPH_ATTR_GRAPH_KERNEL attribute.
bool IsGraphKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsRealCNodeKernel(node)) {
    auto kernel_cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(kernel_cnode);
    auto callee = kernel_cnode->input(kAnfPrimitiveIndex);
    // Only a CNode that calls a func graph can be a graph kernel.
    if (IsValueNode<FuncGraph>(callee)) {
      auto sub_graph = GetValueNode<FuncGraphPtr>(callee);
      MS_EXCEPTION_IF_NULL(sub_graph);
      return sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
    }
  }
  return false;
}


void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
  if (graph == nullptr) {
    MS_LOG(EXCEPTION) << "The graph is null.";
  }
}

void CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(EXCEPTION) << "The AnfNode is null.";
  }
}

void CheckIfCNodeIsNull(const CNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(EXCEPTION) << "The CNode is null.";
  }
}

void CheckIfVarIsNull(const VarPtr &var) {
  if (var == nullptr) {
    MS_LOG(EXCEPTION) << "The Var is null.";
  }
}

void CheckIfNodeIsParam(const AnfNodePtr &node) {
  if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
    MS_LOG(EXCEPTION) << "The Node is not param.";
  }
}

void CheckInputSize(const CNodePtr &node, const int size) {
  if (node->inputs().size() != size) {
    MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  }
}

void CheckLeastInputSize(const CNodePtr &node, const int size) {
  if (node->inputs().size() < size) {
    MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  }
}

// Adds a new parameter to `func_graph` that represents a bias tensor.
// The parameter's abstract is a 1-D tensor of `kernel_num` elements whose element
// type is taken from `weight_tensor`; its default value points at `bias_data`
// and its size is set to kernel_num * sizeof(float).
// NOTE(review): `func_graph` and `weight_tensor` are dereferenced without a null
// check - callers presumably guarantee they are valid; confirm.
AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                          const ParamValueLitePtr &weight_tensor) {
  auto bias_param = func_graph->add_parameter();
  MS_ASSERT(bias_param != nullptr);
  std::vector<int> bias_shape = {kernel_num};
  auto bias_abstract =
    std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), bias_shape);
  bias_param->set_abstract(bias_abstract);

  ParamValueLitePtr bias_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(bias_value != nullptr);
  bias_value->set_tensor_addr(bias_data);
  bias_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t));
  bias_param->set_default_param(bias_value);
  return bias_param;
}

// Returns the schema primitive type behind `n`.
// `n` must hold a CNode (its first input is inspected) or a value node; anything
// else is reported through MS_LOG(EXCEPTION). When the value is not a
// PrimitiveTValue the result is schema::PrimitiveType_NONE.
// Null intermediate pointers are reported with MS_EXCEPTION_IF_NULL rather than
// MS_ASSERT, so the checks do not depend on how MS_ASSERT is configured.
schema::PrimitiveType GetCNodeType(const BaseRef &n) {
  ValueNodePtr value_node;
  if (utils::isa<CNodePtr>(n)) {
    auto in = utils::cast<CNodePtr>(n);
    MS_EXCEPTION_IF_NULL(in);
    auto first_input = in->input(0);
    MS_EXCEPTION_IF_NULL(first_input);
    value_node = first_input->cast<ValueNodePtr>();
  } else if (utils::isa<ValueNodePtr>(n)) {
    value_node = utils::cast<ValueNodePtr>(n);
  } else {
    MS_LOG(EXCEPTION) << "only value node or cnode has type";
    return schema::PrimitiveType_NONE;
  }
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (utils::isa<PrimitiveTValuePtr>(value)) {
    auto primitive = value->cast<PrimitiveTValuePtr>();
    MS_EXCEPTION_IF_NULL(primitive);
    auto primitive_t = primitive->GetPrimitiveT();
    MS_EXCEPTION_IF_NULL(primitive_t);
    return primitive_t->value.type;
  }
  return schema::PrimitiveType_NONE;
}

// True when `n` holds a Parameter node.
bool IsParamNode(const BaseRef &n) {
  const bool holds_parameter = utils::isa<ParameterPtr>(n);
  return holds_parameter;
}

// True when `n` holds a CNode or value node whose primitive type is Conv2D or
// DepthwiseConv2D; any other kind of reference yields false.
bool IsConvNode(const BaseRef &n) {
  if (!utils::isa<CNodePtr>(n) && !utils::isa<ValueNodePtr>(n)) {
    return false;
  }
  const auto op_type = opt::GetCNodeType(n);
  return op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D;
}
}  // namespace opt
}  // namespace mindspore


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <vector>
#include <memory>
#include "src/gllo/common/gllo_utils.h"
#include "src/ir/primitive_t_value.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace {
// Position, inside a CNode's input list, of the input that is tested against a primitive.
constexpr auto kAnfPrimitiveIndex = 0;
// Tells whether `node` is a CNode whose input at kAnfPrimitiveIndex is the
// primitive `primitive_type`. Any node that is not a CNode yields false.
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    auto as_cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(as_cnode);
    return IsPrimitive(as_cnode->input(kAnfPrimitiveIndex), primitive_type);
  }
  return false;
}

// Reports whether `node` should be treated as a real kernel.
// A CNode is considered virtual (not real) when its first input is one of the
// ImageSummary, ScalarSummary, TensorSummary, HistogramSummary, MakeTuple,
// StateSetItem, Depend, TupleGetItem, ControlDepend, Return or Partial primitives.
// A CNode with an empty input list is reported through MS_LOG(EXCEPTION).
bool IsRealKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // Nodes that are not CNodes (parameters, value nodes, ...) are reported as real here;
  // callers that want to exclude them have to test isa<CNode>() themselves.
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    // Stream-style logging does not expand printf placeholders, so the node
    // description is appended directly instead of going through "%s".
    MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << node->DebugString();
  }
  const auto &input = cnode->inputs()[0];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
      IsPrimitive(input, prim::kPrimTensorSummary) ||
      IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
      IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
      IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
      IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  return !is_virtual_node;
}

// Wraps a scalar (int, float, bool) or a ValuePtr held by `sexp` into a value node.
// Returns nullptr when `sexp` holds none of those types.
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  ValueNodePtr wrapped = nullptr;
  if (utils::isa<int>(sexp)) {
    wrapped = NewValueNode(utils::cast<int>(sexp));
  } else if (utils::isa<float>(sexp)) {
    wrapped = NewValueNode(utils::cast<float>(sexp));
  } else if (utils::isa<bool>(sexp)) {
    wrapped = NewValueNode(utils::cast<bool>(sexp));
  } else if (utils::isa<ValuePtr>(sexp)) {
    wrapped = NewValueNode(utils::cast<ValuePtr>(sexp));
  }
  return wrapped;
}

// Builds a CNode from `input_nodes`, owned either by a FuncGraph or by a Var,
// depending on what `graph` holds. Returns nullptr for any other kind of `graph`.
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  CNodePtr created = nullptr;
  if (utils::isa<FuncGraphPtr>(graph)) {
    created = std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  } else if (utils::isa<VarPtr>(graph)) {
    created = std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  }
  return created;
}

// Builds a VarNode from the Var held by `sexp`.
// When `graph` holds a Var the node gets no func graph; when it holds a FuncGraph
// the node is attached to that graph. Any other `graph` is logged as an error
// and nullptr is returned.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  const bool graph_is_var = utils::isa<VarPtr>(graph);
  if (!graph_is_var && !utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
    return nullptr;
  }
  if (graph_is_var) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}

// Converts a vector s-expression into a CNode whose inputs are the converted elements.
// In multigraph mode with a Var as `graph`, every element is converted against a
// fresh Var named "G" and the resulting CNode is bound to the Var in `graph`.
// Otherwise the elements are converted against `graph` itself.
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  const auto &elements = utils::cast<VectorRef>(sexp);
  std::vector<AnfNodePtr> children;
  if (!(multigraph && utils::isa<VarPtr>(graph))) {
    for (auto &element : elements) {
      children.push_back(SexpToNode(element, graph, primitive_vars, multigraph));
    }
    return CreateCNodeWithGraph(children, graph);
  }

  for (auto &element : elements) {
    children.push_back(SexpToNode(element, std::make_shared<Var>("G"), primitive_vars, true));
  }
  VarPtr owner_var = utils::cast<VarPtr>(graph);
  return std::make_shared<CNode>(children, owner_var);
}
}  // namespace

bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);

      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);

      return a_prim->name() == b_prim->name();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      if (utils::isa<lite::PrimitiveTValue>(a_value_ptr) && utils::isa<lite::PrimitiveTValue>(b_value_ptr)) {
        auto a_obj = (lite::PrimitiveTValue *) (a_value_ptr.get());
        auto b_obj = (lite::PrimitiveTValue *) (b_value_ptr.get());
        return (*a_obj) == (*b_obj);
      } else {
        return (*a_value_ptr) == (*b_value_ptr);
      }
    }
  }
  if (a.m_ptr->isa<lite::PrimitiveTValue>() && b.m_ptr->isa<lite::PrimitiveTValue>()) {
    auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>();
    auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>();
    return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
  }

  return a == b;
}

// Type-level comparison of two references: any two CNodes are considered to have
// the same type; otherwise the dynamic types reported by type() must be equal.
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  const bool both_are_cnodes = utils::isa<CNode>(a) && utils::isa<CNode>(b);
  return both_are_cnodes || a.type() == b.type();
}

// Converts an s-expression into an ANF node.
//  - VectorRef: converted element-wise into a CNode (see HandleSexpVector).
//  - Var with a primitive: recorded in `primitive_vars` and turned into a value
//    node of that primitive.
//  - Var without a primitive: turned into a VarNode.
//  - AnfNodePtr: returned unchanged.
//  - Anything else: wrapped into a value node; a value that cannot be wrapped is
//    reported through MS_LOG(EXCEPTION).
// `primitive_vars` must not be null.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  MS_EXCEPTION_IF_NULL(primitive_vars);
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto pattern_var = utils::cast<VarPtr>(sexp);
    MS_EXCEPTION_IF_NULL(pattern_var);
    if (!pattern_var->primitive()) {
      return CreateVarNodeWithSexp(sexp, graph);
    }
    (*primitive_vars)[pattern_var->primitive()] = pattern_var;
    return NewValueNode(pattern_var->primitive());
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto literal_node = CreateValueNodeWithSexp(sexp);
  if (literal_node == nullptr) {
    MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  }
  return literal_node;
}

// Tells whether `node` is a CNode that counts as a real kernel.
// Parameters and value nodes are never real CNode kernels; a Return node always
// is; every other CNode is classified by IsRealKernel.
bool IsRealCNodeKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  return CheckPrimitiveType(node, prim::kPrimReturn) || IsRealKernel(node);
}
// Tells whether `node` is a graph kernel: a real CNode kernel whose first input
// is a FuncGraph value node carrying the FUNC_GRAPH_ATTR_GRAPH_KERNEL attribute.
bool IsGraphKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsRealCNodeKernel(node)) {
    auto kernel_cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(kernel_cnode);
    auto callee = kernel_cnode->input(kAnfPrimitiveIndex);
    // Only a CNode that calls a func graph can be a graph kernel.
    if (IsValueNode<FuncGraph>(callee)) {
      auto sub_graph = GetValueNode<FuncGraphPtr>(callee);
      MS_EXCEPTION_IF_NULL(sub_graph);
      return sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
    }
  }
  return false;
}

void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
  if (graph == nullptr) {
    MS_LOG(EXCEPTION) << "The graph is null.";
  }
}

void CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(EXCEPTION) << "The AnfNode is null.";
  }
}

void CheckIfCNodeIsNull(const CNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(EXCEPTION) << "The CNode is null.";
  }
}

void CheckIfVarIsNull(const VarPtr &var) {
  if (var == nullptr) {
    MS_LOG(EXCEPTION) << "The Var is null.";
  }
}

void CheckIfNodeIsParam(const AnfNodePtr &node) {
  if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
    MS_LOG(EXCEPTION) << "The Node is not param.";
  }
}

void CheckInputSize(const CNodePtr &node, const int size) {
  if (node->inputs().size() != size) {
    MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  }
}

void CheckLeastInputSize(const CNodePtr &node, const int size) {
  if (node->inputs().size() < size) {
    MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  }
}

// Adds a new parameter to `func_graph` that represents a bias tensor.
// The parameter's abstract is a 1-D tensor of `kernel_num` elements whose element
// type is taken from `weight_tensor`; its default value points at `bias_data`
// and its size is set to kernel_num * sizeof(float).
// NOTE(review): `func_graph` and `weight_tensor` are dereferenced without a null
// check - callers presumably guarantee they are valid; confirm.
AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                          const ParamValueLitePtr &weight_tensor) {
  auto bias_param = func_graph->add_parameter();
  MS_ASSERT(bias_param != nullptr);
  std::vector<int> bias_shape = {kernel_num};
  auto bias_abstract =
    std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), bias_shape);
  bias_param->set_abstract(bias_abstract);

  ParamValueLitePtr bias_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(bias_value != nullptr);
  bias_value->set_tensor_addr(bias_data);
  bias_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t));
  bias_param->set_default_param(bias_value);
  return bias_param;
}

// Returns the schema primitive type behind `n`.
// `n` must hold a CNode (its first input is inspected) or a value node; anything
// else is reported through MS_LOG(EXCEPTION). When the value is not a
// PrimitiveTValue the result is schema::PrimitiveType_NONE.
// Null intermediate pointers are reported with MS_EXCEPTION_IF_NULL rather than
// MS_ASSERT, so the checks do not depend on how MS_ASSERT is configured.
schema::PrimitiveType GetCNodeType(const BaseRef &n) {
  ValueNodePtr value_node;
  if (utils::isa<CNodePtr>(n)) {
    auto in = utils::cast<CNodePtr>(n);
    MS_EXCEPTION_IF_NULL(in);
    auto first_input = in->input(0);
    MS_EXCEPTION_IF_NULL(first_input);
    value_node = first_input->cast<ValueNodePtr>();
  } else if (utils::isa<ValueNodePtr>(n)) {
    value_node = utils::cast<ValueNodePtr>(n);
  } else {
    MS_LOG(EXCEPTION) << "only value node or cnode has type";
    return schema::PrimitiveType_NONE;
  }
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (utils::isa<PrimitiveTValuePtr>(value)) {
    auto primitive = value->cast<PrimitiveTValuePtr>();
    MS_EXCEPTION_IF_NULL(primitive);
    auto primitive_t = primitive->GetPrimitiveT();
    MS_EXCEPTION_IF_NULL(primitive_t);
    return primitive_t->value.type;
  }
  return schema::PrimitiveType_NONE;
}

// True when `n` holds a Parameter node.
bool IsParamNode(const BaseRef &n) {
  const bool holds_parameter = utils::isa<ParameterPtr>(n);
  return holds_parameter;
}

// True when `n` holds a CNode or value node whose primitive type is Conv2D or
// DepthwiseConv2D; any other kind of reference yields false.
bool IsConvNode(const BaseRef &n) {
  if (!utils::isa<CNodePtr>(n) && !utils::isa<ValueNodePtr>(n)) {
    return false;
  }
  const auto op_type = opt::GetCNodeType(n);
  return op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D;
}
}  // namespace opt
}  // namespace mindspore


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_

#include <mindspore/lite/src/ir/primitive_t_value.h>
#include <memory>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "mindspore/ccsrc/backend/optimizer/common/pattern_engine.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"

// Shared-ownership handle to a lite::PrimitiveTValue.
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {

// Equality predicate over two references (value nodes, primitives or arbitrary BaseRefs).
bool AnfEqual(const BaseRef &a, const BaseRef &b);

// Type-level comparison of two references; any two CNodes compare as the same type.
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);

// Converts an s-expression into an ANF node. `primitive_vars` must not be null;
// it receives the Vars that carry a primitive. `multigraph` defaults to false.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                      bool multigraph = false);

// True when `node` is a real CNode kernel whose first input is a func graph
// carrying the graph-kernel attribute.
bool IsGraphKernel(const AnfNodePtr &node);

// True when `node` is a CNode that counts as a real kernel (Return included).
bool IsRealCNodeKernel(const AnfNodePtr &node);

// Reports a null func graph through MS_LOG(EXCEPTION).
void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);

// Reports a null ANF node through MS_LOG(EXCEPTION).
void CheckIfAnfNodeIsNull(const AnfNodePtr &node);

// Reports a null CNode through MS_LOG(EXCEPTION).
void CheckIfCNodeIsNull(const CNodePtr &node);

// Reports a null Var through MS_LOG(EXCEPTION).
void CheckIfVarIsNull(const VarPtr &var);

// Reports through MS_LOG(EXCEPTION) when `node` does not have exactly `size` inputs.
void CheckInputSize(const CNodePtr &node, int size);

// Reports through MS_LOG(EXCEPTION) when a non-null `node` is not a Parameter.
void CheckIfNodeIsParam(const AnfNodePtr &node);

// Reports through MS_LOG(EXCEPTION) when `node` has fewer than `size` inputs.
void CheckLeastInputSize(const CNodePtr &node, int size);

// Adds a parameter to `func_graph` describing a bias tensor of `kernel_num`
// elements backed by `bias_data`; the element type is taken from `weight_tensor`.
AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                          const ParamValueLitePtr &weight_tensor);

// Returns the schema primitive type of a CNode or value node, or
// schema::PrimitiveType_NONE when its value is not a PrimitiveTValue.
schema::PrimitiveType GetCNodeType(const BaseRef &node);

// True when `n` holds a Parameter node.
bool IsParamNode(const BaseRef &n);

// True when `n` is a CNode or value node of type Conv2D or DepthwiseConv2D.
bool IsConvNode(const BaseRef &n);
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h"

#include <unordered_set>
#include <deque>
#include <algorithm>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "mindspore/lite/src/gllo/common/gllo_utils.h"

namespace mindspore {
namespace opt {
// Applies the per-node Run(func_graph, node) overload to the nodes reachable from
// the graph output, using a work queue seeded with func_graph->output().
// Returns true when the per-node overload returned a non-null node at least once.
// Requires `func_graph` and its manager to be non-null.
bool NodePass::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(func_graph);

  std::unordered_set<AnfNodePtr> seen_node;
  std::deque<AnfNodePtr> todo{func_graph->output()};
  bool changes = false;
  while (!todo.empty()) {
    AnfNodePtr node = todo.front();
    todo.pop_front();
    // Skip nodes already handled and nodes the manager does not track.
    if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
      continue;
    }
    (void)seen_node.insert(node);
    AnfNodePtr new_node = Run(func_graph, node);
    // Any non-null result counts as a change, even when it is the same node.
    bool change = (new_node != nullptr);
    if (new_node != nullptr && new_node != node) {
      (void)manager->Replace(node, new_node);
      (void)seen_node.erase(node);
    } else if (new_node == nullptr) {
      // Nothing was produced: keep traversing from the original node.
      new_node = node;
    }
    if (new_node && IsValueNode<FuncGraph>(new_node)) {
      // Descend into the referenced func graph through its output.
      auto const_func_graph = GetValueNode<FuncGraphPtr>(new_node);
      MS_EXCEPTION_IF_NULL(const_func_graph);
      todo.push_back(const_func_graph->output());
    } else if (new_node && new_node->isa<CNode>()) {
      if (IsGraphKernel(new_node)) {
        todo.push_back(new_node);
      }
      auto cnode = new_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      // Bind by reference: the inputs are only read while being queued, so a
      // copy of the whole vector per visited node is unnecessary.
      const auto &inputs = cnode->inputs();
      (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
    }
    changes = changes || change;
  }
  return changes;
}
}  // namespace opt
}  // namespace mindspore



In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/gllo/common/optimizer.h"

#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include <initializer_list>

#include "mindspore/ccsrc/backend/optimizer/common/pass_manager.h"
#include "ir/manager.h"

namespace mindspore {
namespace opt {
// Creates a pattern-based node pass called `name`.
// The pattern engine is set up with a DefaultVisitor and uses AnfEqual /
// CNodeTypeEqual as its two comparison callbacks; the primitive-var map starts
// empty. `multigraph` is stored and later forwarded to SexpToNode by Build().
PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
    : NodePass(name),
      multigraph_(multigraph),
      pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
                                    std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
                                    std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
      primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}

// Default pattern: a reference built from a single unnamed Var.
// NOTE(review): concrete passes presumably provide their own pattern - confirm
// against the class declaration.
const BaseRef PatternProcessPass::DefinePattern() const {
  VarPtr placeholder = std::make_shared<Var>();
  return BaseRef({placeholder});
}

void PatternProcessPass::Build() {
  VarPtr fg = std::make_shared<Var>("RootG");
  BaseRef pattern = std::move(DefinePattern());
  pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
}

// Matches `node` against the pass pattern (built on first use) and, on a
// non-empty match, returns the result of Process(); otherwise returns nullptr.
AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  // The pattern is built lazily the first time a node is processed.
  if (pattern_ == nullptr) {
    Build();
  }

  auto initial_equiv = std::make_shared<Equiv>();
  MS_EXCEPTION_IF_NULL(primitive_vars_);
  EquivPtr matched = pattern_engine_.Match(pattern_, node, *primitive_vars_, initial_equiv);
  if (matched == nullptr || matched->empty()) {
    return nullptr;
  }
  return Process(func_graph, node, matched);
}

bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  VarPtr fg = std::make_shared<Var>("RootG");
  auto empty_equiv = std::make_shared<Equiv>();
  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
  EquivPtr another_equiv =
    child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
                                *child_primitive_vars_, empty_equiv);
  if (another_equiv != nullptr && !another_equiv->empty()) {
    return IsShareNodes(equiv, another_equiv);
  }
  return false;
}

void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
  if (pass_manager != nullptr) {
    pass_managers_.push_back(pass_manager);
  }
}

// Runs every registered pass manager over `func_graph` and returns the same graph pointer.
//
// A single round is executed when `run_only_once` is true or when exactly one pass manager is registered;
// otherwise rounds repeat until a full round reports no change. If the graph has no manager yet, one is
// created with Manage(func_graph, false) and attached. After optimization the manager is told to keep only
// this graph as root and the graph is topologically sorted from its return node (result discarded).
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
  MS_EXCEPTION_IF_NULL(func_graph);
  run_only_once_ = (pass_managers_.size() == 1) || run_only_once;

  auto graph_manager = func_graph->manager();
  if (graph_manager == nullptr) {
    graph_manager = Manage(func_graph, false);
    func_graph->set_manager(graph_manager);
  }

  bool round_changed = false;
  do {
    round_changed = false;
    for (const PassManagerPtr &pm : pass_managers_) {
      if (pm == nullptr) {
        continue;
      }
      if (pm->Run(func_graph)) {
        round_changed = true;
      }
    }
  } while (round_changed && !run_only_once_);

  std::vector<FuncGraphPtr> roots;
  roots.push_back(func_graph);
  graph_manager->KeepRoots(roots);
  (void)TopoSort(func_graph->get_return());
  return func_graph;
}
}  // namespace opt
}  // namespace mindspore



In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_

#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/graph_utils.h"
#include "src/common/utils.h"

#include "mindspore/ccsrc/backend/optimizer/common/pass_manager.h"
#include "mindspore/ccsrc/backend/optimizer/common/pattern_engine.h"
#include "mindspore/lite/src/gllo/common/gllo_utils.h"

namespace mindspore {
namespace opt {
using PatternListType = std::initializer_list<BaseRef>;

// Node pass driven by a declarative pattern: Run() matches the pattern against a node and, when the match is
// non-empty, forwards the node and the match result to Process().
class PatternProcessPass : public NodePass {
 public:
  // `name` is forwarded to NodePass; `multigraph` is passed on to SexpToNode when the pattern is built.
  explicit PatternProcessPass(const std::string &name = "", bool multigraph = true);
  ~PatternProcessPass() override = default;
  // Invoked by Run() with the graph, the matched node and the match result; its return value becomes Run()'s
  // return value.
  virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
  // Pattern to match. The base implementation returns a single fresh Var wrapped in a BaseRef.
  virtual const BaseRef DefinePattern() const;
  // Returns Process()'s result on a non-empty match, nullptr otherwise.
  AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;

 private:
  // Converts DefinePattern() into pattern_; called by Run() while pattern_ is still null.
  void Build();

  // Built pattern; null until the first Run().
  AnfNodePtr pattern_ = nullptr;
  bool multigraph_ = true;
  PatternEngine pattern_engine_;
  // Primitive variables collected while building the pattern and consulted during matching.
  PrimitiveVarMapPtr primitive_vars_;
};

// PatternProcessPass extended with a second pattern. It owns a separate pattern engine and primitive-variable
// map for that second pattern, so matching it does not touch the base class's state.
class MultipleOutputPatternProcessPass : public PatternProcessPass {
 public:
  // Forwards `name` and `multigraph` to PatternProcessPass and sets up the child engine with a DefaultVisitor
  // plus AnfEqual / CNodeTypeEqual as comparison callbacks.
  explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
      : PatternProcessPass(name, multigraph),
        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
  ~MultipleOutputPatternProcessPass() override = default;
  // The second pattern, matched by MatchAnotherPattern().
  virtual BaseRef DefineAnotherPattern() const = 0;
  // check two patterns whether share the same nodes or not
  virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0;

 protected:
  // Matches DefineAnotherPattern() against `node`; on a non-empty match returns IsShareNodes(equiv, <that match>),
  // otherwise false.
  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
  PatternEngine child_pattern_engine_;
  PrimitiveVarMapPtr child_primitive_vars_;
};

// Holds an ordered list of pass managers and runs them over a function graph.
class GraphOptimizer {
 public:
  explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
  virtual ~GraphOptimizer() = default;

  // Appends a pass manager; null pointers are ignored.
  void AddPassManager(const PassManagerPtr &pass_manager);
  // Runs all pass managers over `func_graph` and returns it. With run_only_once == false the managers are
  // re-run until a round reports no change, unless exactly one manager is registered, in which case a single
  // round is always used.
  FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true);

 private:
  const std::string name_ = "graph_optimizer";
  std::vector<PassManagerPtr> pass_managers_{};
  // Effective single-round flag, recomputed on every Optimize() call.
  bool run_only_once_ = true;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_



In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "mindspore/ccsrc/backend/optimizer/common/pass_manager.h"

#include <sys/time.h>

#include <algorithm>
#include <chrono>
#include <deque>
#include <string>
#include <unordered_set>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
// Read-only view of the registered passes, in the order they were added.
const std::vector<PassPtr> &PassManager::Passes() const {
  return passes_;
}

void PassManager::AddPass(const PassPtr &pass) {
  if (pass != nullptr) {
    passes_.push_back(pass);
  }
}

// Runs each non-null pass in `passes` once over `func_graph`, in order, and logs the time every pass took.
//
// Returns true if at least one pass reported a change, false otherwise (including when `func_graph` is null).
// Null entries in `passes` are skipped and do not advance the index used in the log line.
bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
  if (func_graph == nullptr) {
    return false;
  }
  bool changed = false;
  size_t num = 0;
  for (const auto &pass : passes) {
    if (pass == nullptr) {
      continue;
    }
    // steady_clock is monotonic, so the measured interval cannot be skewed by wall-clock adjustments, and it
    // is available on every platform, so no platform-specific timing code is needed.
    const auto start_time = std::chrono::steady_clock::now();
    if (pass->Run(func_graph)) {
      MS_LOG(DEBUG) << "Run pass and find change";
      changed = true;
    }
    const auto end_time = std::chrono::steady_clock::now();
    const auto cost = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
    MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us";
    num++;
  }
  return changed;
}

// Runs the registered passes over `func_graph`.
// At least one round is always executed. When run_only_once_ is false, further rounds follow for as long as
// the previous round reported a change. Returns true if any round changed the graph.
bool PassManager::Run(const FuncGraphPtr &func_graph) const {
  bool any_change = false;
  bool round_changed = false;
  do {
    round_changed = Run(func_graph, passes_);
    any_change = any_change || round_changed;
  } while (round_changed && !run_only_once_);
  return any_change;
}
}  // namespace opt
}  // namespace mindspore


#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

training

In [None]:
def run():
    """Train the global model ``MX`` on this process's XLA device.

    Builds distributed samplers and loaders over ``df_train`` / ``df_valid``,
    sets up AdamW (weight decay disabled for bias and LayerNorm parameters)
    with a linear schedule, then for each epoch trains, evaluates, reduces the
    Jaccard score across processes with ``reduce_fn`` and saves the model
    weights to ``model_{config.FOLD}.bin`` whenever the score improves.

    Reads the module-level names ``MX``, ``df_train``, ``df_valid``, ``config``,
    ``train_fn``, ``eval_fn`` and ``reduce_fn``. Returns ``None``.
    """
    device = xm.xla_device()
    model = MX.to(device)

    train_dataset = TweetDataset(
        tweet=df_train.text.values,
        sentiment=df_train.sentiment.values,
        selected_text=df_train.selected_text.values
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=2
    )

    valid_dataset = TweetDataset(
        tweet=df_valid.text.values,
        sentiment=df_valid.sentiment.values,
        selected_text=df_valid.selected_text.values
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=1
    )

    # Parameters whose name contains any of these substrings get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = [
        "bias",
        "LayerNorm.bias",
        "LayerNorm.weight"
    ]
    optimizer_parameters = [
        {
            'params': [
                p for n, p in param_optimizer if not any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.001
        },
        {
            'params': [
                p for n, p in param_optimizer if any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.0
        },
    ]

    # Scheduler steps per process: the data is split across all replicas.
    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE / xm.xrt_world_size() * config.EPOCHS
    )
    # The learning rate is scaled by the number of replicas.
    optimizer = AdamW(
        optimizer_parameters,
        lr=config.LEARNING_RATE * xm.xrt_world_size()
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    best_jac = 0
    # NOTE(review): `es` is created but never consulted below, so no early
    # stopping actually happens -- confirm whether that is intended.
    es = EarlyStopping(patience=2, mode="max")
    num_batches = int(len(df_train) / (config.TRAIN_BATCH_SIZE * xm.xrt_world_size()))

    xm.master_print("Training is Starting....")

    for epoch in range(config.EPOCHS):
        # Without set_epoch the DistributedSampler reuses the same shuffle
        # order every epoch; it must be called before the loader is iterated.
        train_sampler.set_epoch(epoch)

        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(
            para_loader.per_device_loader(device),
            model,
            optimizer,
            device,
            num_batches,
            scheduler
        )

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        jac = eval_fn(
            para_loader.per_device_loader(device),
            model,
            device
        )
        # Combine the per-process scores into one value.
        jac = xm.mesh_reduce('jac_reduce', jac, reduce_fn)
        xm.master_print(f'Epoch={epoch}, Jaccard={jac}')
        if jac > best_jac:
            xm.master_print("Model Improved!!! Saving Model")
            xm.save(model.state_dict(), f"model_{config.FOLD}.bin")
            best_jac = jac

In [None]:
def _mp_fn(rank, flags):
    """Per-process entry point: set the default tensor type to float32 and run training.

    ``rank`` and ``flags`` are accepted but not used.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    run()

In [None]:
# No flags are defined; an empty dict is passed through to every worker.
FLAGS = {}

# Launch 8 forked worker processes, each entering _mp_fn.
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=8,
    start_method='fork',
)

Training is Starting....


HBox(children=(FloatProgress(value=0.0, description='Training', max=171.0, style=ProgressStyle(description_wid…