<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down.

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From a3e00c43755bcd21fe29ab730b0c5e7d47d147c5 Mon Sep 17 00:00:00 2001
From: guohongzilong <guohongzilong@huawei.com>
Date: Sat, 3 Apr 2021 19:10:22 +0800
Subject: [PATCH] modify code structure

---
 .../main/java/com/huawei/flclient/LiteTrain.java   | 119 ++---
 .../main/java/com/huawei/flclient/NativeTrain.java |  46 ++
 .../lite/flclient/src/main/native/CMakeLists.txt   |  18 +-
 .../lite/flclient/src/main/native/bert_train.cpp   | 197 ++++++++
 .../lite/flclient/src/main/native/bert_train.h     |  26 ++
 .../lite/flclient/src/main/native/data_prepare.cpp | 517 ---------------------
 .../src/main/native/dataset/data_prepare.cpp       | 517 +++++++++++++++++++++
 .../lite/flclient/src/main/native/lenet_train.cpp  | 167 +++++++
 .../lite/flclient/src/main/native/lenet_train.h    |  27 ++
 .../lite/flclient/src/main/native/lite_train.cpp   | 287 ------------
 .../lite/flclient/src/main/native/lite_train.h     |  35 --
 .../flclient/src/main/native/lite_train_jni.cpp    | 131 ++++--
 mindspore/lite/flclient/src/main/native/util.cpp   | 100 ++++
 mindspore/lite/flclient/src/main/native/util.h     |  33 ++
 14 files changed, 1257 insertions(+), 963 deletions(-)
 create mode 100644 mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java
 create mode 100644 mindspore/lite/flclient/src/main/native/bert_train.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/bert_train.h
 delete mode 100644 mindspore/lite/flclient/src/main/native/data_prepare.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/dataset/data_prepare.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/lenet_train.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/lenet_train.h
 delete mode 100644 mindspore/lite/flclient/src/main/native/lite_train.cpp
 delete mode 100644 mindspore/lite/flclient/src/main/native/lite_train.h
 create mode 100644 mindspore/lite/flclient/src/main/native/util.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/util.h

diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
index 5f9a8de..6c8cfd2 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
@@ -27,15 +27,14 @@ import java.util.Iterator;
 import java.util.Map;
 import java.util.TreeMap;
 
-public  class LiteTrain {
-    static {
-        System.loadLibrary("fl");
-    }
-    private static HashMap featureMaps;
+public class LiteTrain {
     private static LiteTrain train;
+    /** Native session handle from {@link NativeTrain#createSession}; 0 until {@link #init} obtains one and after {@link #free}. */
+    private long sessionPtr;
 
     private LiteTrain() {
     }
+
     public static synchronized LiteTrain getInstance() {
         if (train == null) {
             train = new LiteTrain();
@@ -43,9 +42,9 @@ public  class LiteTrain {
         return train;
     }
 
-    public FlatBufferBuilder FeatureMapBuilder(String modelName){
+    public FlatBufferBuilder FeatureMapBuilder(String modelName) {
         FlatBufferBuilder builder = new FlatBufferBuilder();
-        int[] fmOffsets = getSeralizeFeaturesMap(modelName,builder);
+        int[] fmOffsets = getSeralizeFeaturesMap(builder);
         int fmOffset = FeatureMapList.createFeatureMapVector(builder, fmOffsets);
         RequestUpdateModel.startRequestUpdateModel(builder);
         RequestUpdateModel.addFlName(builder, 0);
@@ -86,60 +85,98 @@ public  class LiteTrain {
         }
         return map;
     }
+
+    /**
+     * Creates the native training session for {@code modelPath}.
+     *
+     * @param modelPath model file path passed to {@link NativeTrain#createSession}
+     * @return 0 on success, -1 when the native side returns a 0 (null) session handle
+     */
+    public int init(String modelPath) {
+        // todo add msconfig
+        sessionPtr = NativeTrain.createSession(modelPath, 0L);
+        return sessionPtr == 0L ? -1 : 0;
+    }
 
+    /**
+     * Sets the data set used by subsequent native calls.
+     *
+     * @param fileSet input path passed to {@link NativeTrain#setInput}
+     * @return native status
+     */
+    public int setInput(String fileSet) {
+        return NativeTrain.setInput(fileSet);
+    }
+
+    /**
+     * Trains the current session.
+     *
+     * @param batch_size   batch size passed to the native side
+     * @param epoches      epoch count passed to the native side
+     * @param earlyStopMod early-stop mode passed to the native side
+     * @return native status
+     */
+    public int train(int batch_size, int epoches, int earlyStopMod) {
+        // NativeTrain.train takes (session, batch_size, epoches, earlyStopMod); batch_size must be forwarded.
+        return NativeTrain.train(sessionPtr, batch_size, epoches, earlyStopMod);
+    }
+
+    /**
+     * Runs inference with the current session.
+     *
+     * @return value returned by {@link NativeTrain#infer}
+     */
+    public float infer() {
+        return NativeTrain.infer(sessionPtr);
+    }
 
+    /**
+     * Returns the inference labels reported by the native session.
+     *
+     * @return label array from {@link NativeTrain#getInferLables}
+     */
+    public int[] getInferLabels() {
+        return NativeTrain.getInferLables(sessionPtr);
+    }
+
+    /**
+     * Gets the features map of the training model.
+     *
+     * @return feature name to values map
+     */
+    public Map<String, float[]> getFeaturesMap() {
+        return NativeTrain.getFeaturesMap(sessionPtr);
+    }
+
+    /**
+     * Serializes the features map of the training model into {@code builder}.
+     *
+     * @param builder FlatBufferBuilder receiving the serialized features
+     * @return offsets of the serialized features inside {@code builder}
+     */
+    public int[] getSeralizeFeaturesMap(FlatBufferBuilder builder) {
+        return NativeTrain.getSeralizeFeaturesMap(sessionPtr, builder);
+    }
+
+    /**
+     * Updates the features map of the training model.
+     *
+     * @param featureMaps new feature values
+     * @return native status
+     */
+    public int updateFeatures(ArrayList<FeatureMap> featureMaps) {
+        return NativeTrain.updateFeatures(sessionPtr, featureMaps);
+    }
+
+    /**
+     * Frees the native session and clears the handle so it is not passed to native code again.
+     *
+     * @return native status
+     */
+    public int free() {
+        int ret = NativeTrain.free(sessionPtr);
+        sessionPtr = 0L;
+        return ret;
+    }
 
-    /**
-     * set the Inference set or Train set
-     *
-     * @param fileSet   input binary file path which format is NHWC
-     * @param batch_num binary file batch num
-     * @return
-     */
-    native int setInput(String fileSet, int batch_num);
-
-    /**
-     * inference
-     *
-     * @return status
-     */
-    public native float inference(String modelName,int batch_num);
-
-    /**
-     * train
-     *
-     * @return status
-     */
-    public native int train(String modelName, int batch_num,int iterations);
-
-    /**
-     * get the features map of training model
-     *
-     * @param builder FlatBufferBuilder
-     * @return features offset
-     */
-   native int[] getSeralizeFeaturesMap(String modelName, FlatBufferBuilder builder);
-
-         /**
-             * get the features map of training model
-             *
-             * @param modelName model name
-             * @return features map
-             */
-   native Map<String,float[]> getFeaturesMap(String modelName);
-
-    /**
-     * update the features map of training model
-     *
-     * @param featureMaps
-     * @return status
-     */
-    native int updateFeatures(String modelName,ArrayList<FeatureMap> featureMaps);
-
-    /**
-     * free Inference or Train runtime memory resource
-     *
-     * @return status
-     */
-    native int free();
 }
diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java
new file mode 100644
index 0000000..b0dc4d0
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.flclient;
+
+import com.google.flatbuffers.FlatBufferBuilder;
+import mindspore.schema.FeatureMap;
+import java.util.ArrayList;
+import java.util.Map;
+
+/**
+ * JNI entry points of the native {@code fl} library, which is loaded when this class is initialized.
+ * <p>
+ * Every method except {@link #setInput} and {@link #createSession} takes a session handle
+ * returned by {@link #createSession}.
+ */
+public final class NativeTrain {
+    static {
+        System.loadLibrary("fl");
+    }
+
+    private NativeTrain() {
+        // Holder for static native methods only.
+    }
+
+    /** Sets the input data set; returns the native status. */
+    static native int setInput(String fileSet);
+
+    /** Creates a native session for {@code modelPath} and returns its handle. */
+    static native long createSession(String modelPath, long msConfigPtr);
+
+    /** Runs inference on the session and returns the native result. */
+    static native float infer(long sessionPtr);
+
+    /** Returns the inference labels of the session. */
+    static native int[] getInferLables(long sessionPtr);
+
+    /** Trains the session; returns the native status. */
+    static native int train(long sessionPtr, int batch_size, int epoches, int earlyStopMod);
+
+    /** Serializes the session's features map into {@code builder}; returns the feature offsets. */
+    static native int[] getSeralizeFeaturesMap(long sessionPtr, FlatBufferBuilder builder);
+
+    /** Returns the features map of the session's model. */
+    static native Map<String, float[]> getFeaturesMap(long sessionPtr);
+
+    /** Updates the session's features from {@code featureMaps}; returns the native status. */
+    static native int updateFeatures(long sessionPtr, ArrayList<FeatureMap> featureMaps);
+
+    /** Frees the native session behind {@code sessionPtr}; returns the native status. */
+    static native int free(long sessionPtr);
+}
diff --git a/mindspore/lite/flclient/src/main/native/CMakeLists.txt b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
index 5171696..89bc06f 100644
--- a/mindspore/lite/flclient/src/main/native/CMakeLists.txt
+++ b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
@@ -20,6 +20,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/linux)
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/dataset)
 
 
 include_directories(${LITE_DIR}) ## lite include
@@ -28,17 +29,24 @@ include_directories(${TOP_DIR}/mindspore/core/) ## core include
 include_directories(${LITE_DIR}/build) ## flatbuffers
 
 set(OP_SRC
-        lite_train.cpp
-        data_prepare.cpp
-            )
-
-set(SRC_FILES
         lite_train_jni.cpp
-        )
+        util.cpp
+        bert_train.cpp
+        lenet_train.cpp
+        dataset/CustomizedTokenizer.cc
+            )
+# NOTE(review): this patch adds dataset/data_prepare.cpp but not dataset/CustomizedTokenizer.cc
+# (nor test_train.cc below); confirm both exist in the tree.
 find_library(log-lib glog)
 
-add_library(fl SHARED ${SRC_FILES} ${OP_SRC})
-target_link_libraries(fl mindspore-lite  glog)
+add_library(fl SHARED ${OP_SRC})
+# libfl is loaded from Java via System.loadLibrary("fl"), so it must link its own runtime
+# dependencies; otherwise it is built with unresolved mindspore-lite/glog symbols.
+target_link_libraries(fl mindspore-lite glog)
+
 link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib/)
 
 install(TARGETS fl LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
+# "test" is a reserved target name (CMake policy CMP0037), so use a distinct name.
+add_executable(fl_train_test test_train.cc ${OP_SRC})
+target_link_libraries(fl_train_test mindspore-lite glog)
diff --git a/mindspore/lite/flclient/src/main/native/bert_train.cpp b/mindspore/lite/flclient/src/main/native/bert_train.cpp
new file mode 100644
index 0000000..d5fca9a
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/bert_train.cpp
@@ -0,0 +1,246 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "bert_train.h"
+#include "util.h"
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <string>
+#include <vector>
+#include "include/context.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "dataset/CustomizedTokenizer.h"
+#include <climits>
+
+// Tokenized training samples produced by SetBertInputs() and released by FreeBertInput().
+// Sample i occupies [i * MAX_SEQ_LENGTH, (i + 1) * MAX_SEQ_LENGTH) of the three token arrays
+// and index i of lable_ids.
+static int *lable_ids = nullptr;
+static int *input_ids = nullptr;
+static int *input_mask = nullptr;
+static int *token_type_ids = nullptr;
+static int batch_num = 0;
+// Number of samples held by the arrays above (a whole number of BATCH_SIZE batches).
+static int total_size = 0;
+
+// NOTE(review): the tokenizer formerly in data_prepare.cpp sized its output arrays with
+// MAX_SEQ_LENGTH 256; if dataset/CustomizedTokenizer.h still does, the 32-int stack buffers in
+// SetBertInputs() are too small for tokenize(). Confirm both values agree.
+#define MAX_SEQ_LENGTH 32
+#define BATCH_SIZE 16
+#define LABEL_CLASS 107
+
+// Copies batch `batch_idx` of the loaded samples into the session's inputs (tensor 0: input
+// mask, 1: input ids, 2: token type ids, 3: label ids) and returns that batch's labels.
+// The batch size is the leading dimension of input tensor 0.
+// NOTE(review): assumes tensors 0-2 hold batch_size * MAX_SEQ_LENGTH ints and tensor 3 batch_size ints.
+std::vector<int> FillBertInputData(mindspore::session::TrainSession *train_session, int batch_idx) {
+  auto inputs = train_session->GetInputs();
+  int batch_size = inputs[0]->shape()[0];
+
+  auto model_input_mask = reinterpret_cast<int *>(inputs.at(0)->MutableData());
+  auto model_input_ids = reinterpret_cast<int *>(inputs.at(1)->MutableData());
+  auto model_token_id = reinterpret_cast<int *>(inputs.at(2)->MutableData());
+  auto model_label_ids = reinterpret_cast<int *>(inputs.at(3)->MutableData());
+  std::vector<int> labels_vec(batch_size);
+  for (int i = 0; i < batch_size; i++) {
+    std::memcpy(model_input_mask + i * MAX_SEQ_LENGTH,
+                input_mask + batch_idx * inputs[0]->ElementsNum() + i * MAX_SEQ_LENGTH, MAX_SEQ_LENGTH * sizeof(int));
+    std::memcpy(model_input_ids + i * MAX_SEQ_LENGTH,
+                input_ids + batch_idx * inputs[1]->ElementsNum() + i * MAX_SEQ_LENGTH, MAX_SEQ_LENGTH * sizeof(int));
+    std::memcpy(model_token_id + i * MAX_SEQ_LENGTH,
+                token_type_ids + batch_idx * inputs[2]->ElementsNum() + i * MAX_SEQ_LENGTH,
+                MAX_SEQ_LENGTH * sizeof(int));
+    model_label_ids[i] = lable_ids[batch_idx * batch_size + i];
+    labels_vec[i] = lable_ids[batch_idx * batch_size + i];
+  }
+  return labels_vec;
+}
+
+// Evaluates the model on the first loaded batch and returns its accuracy, or RET_ERROR (as a
+// float) when SetBertInputs() has not loaded at least one full model batch.
+float InferBert(TrainSession *session) {
+  auto inputs = session->GetInputs();
+  if (inputs.size() < 4 || inputs[0]->shape().empty() || total_size < inputs[0]->shape()[0]) {
+    std::cout << "bert inputs not ready, call SetBertInputs first" << std::endl;
+    return static_cast<float>(mindspore::lite::RET_ERROR);
+  }
+  session->Eval();
+  auto labels = FillBertInputData(session, 0);
+  // Run the forward pass before reading the outputs, as TrainBert does.
+  // NOTE(review): assumes CalculateAccuracy (util.cpp) only reads the output tensors.
+  session->RunGraph(nullptr, nullptr);
+  auto infer_acc = CalculateAccuracy(session, labels, LABEL_CLASS);
+  std::cout << "inference acc is:" << infer_acc << std::endl;
+  return infer_acc;
+}
+
+// Trains for `epoches` passes over the loaded samples, printing per-batch loss and per-epoch
+// mean loss/accuracy, then saves the model to `ms_file`. Batches are cut with the model's own
+// batch size (leading dimension of input tensor 0) so they never read past the loaded data; a
+// differing `batch_size` argument is reported and ignored.
+// Returns RET_OK, or RET_ERROR on invalid arguments, unexpected inputs, or no loaded batch.
+int TrainBert(TrainSession *session, const std::string &ms_file, int batch_size, int epoches) {
+  if (epoches <= 0 || batch_size <= 0) {
+    std::cout << "error batch_size or epoch!, batch_size:" << batch_size << ", epoch:" << epoches << std::endl;
+    return mindspore::lite::RET_ERROR;
+  }
+  auto inputs = session->GetInputs();
+  if (inputs.size() < 4 || inputs[0]->shape().empty() || inputs[0]->shape()[0] <= 0) {
+    std::cout << "unexpected bert model inputs" << std::endl;
+    return mindspore::lite::RET_ERROR;
+  }
+  int model_batch_size = inputs[0]->shape()[0];
+  if (batch_size != model_batch_size) {
+    std::cout << "batch_size " << batch_size << " differs from model batch size " << model_batch_size
+              << ", using the model batch size" << std::endl;
+  }
+  batch_num = total_size / model_batch_size;
+  if (batch_num == 0) {
+    std::cout << "no full batch of train data, call SetBertInputs first" << std::endl;
+    return mindspore::lite::RET_ERROR;
+  }
+  std::cout << "total epoches :" << epoches << ",batch_num:" << batch_num << std::endl;
+  for (int j = 0; j < epoches; ++j) {
+    float sum_loss_per_epoch = 0.0f;
+    float sum_acc_per_epoch = 0.0f;
+    for (int k = 0; k < batch_num; ++k) {
+      session->Train();
+      auto lables = FillBertInputData(session, k);
+      session->RunGraph(nullptr, nullptr);
+      auto loss = GetLoss(session);
+      sum_loss_per_epoch += loss;
+      std::cout << "batch:" << k << ",loss:" << loss << std::endl;
+      sum_acc_per_epoch += CalculateAccuracy(session, lables, LABEL_CLASS);
+    }
+    std::cout << "epoch "
+              << "[" << j << "]"
+              << ",mean Loss " << sum_loss_per_epoch / batch_num << ",train acc " << sum_acc_per_epoch / batch_num
+              << std::endl;
+  }
+  session->SaveToFile(ms_file);
+  return mindspore::lite::RET_OK;
+}
+
+// Appends each non-empty line of `file` to `train_data`, dropping trailing '\r' (CRLF files).
+// Prints a message and appends nothing when the file cannot be opened.
+void ReadTxt(const std::string &file, std::vector<std::string> *train_data) {
+  std::ifstream fin(file);
+  if (!fin.is_open()) {
+    std::cout << "open file failed: " << file << std::endl;
+    return;
+  }
+  std::string line;
+  while (std::getline(fin, line)) {
+    size_t endpos = line.find_last_not_of('\r');
+    if (endpos == std::string::npos) {
+      continue;  // empty, or nothing but '\r'
+    }
+    line.erase(endpos + 1);
+    train_data->push_back(line);
+  }
+}
+
+// Loads `train_file` ("<label>\t<sentence>" per line), tokenizes each sentence with the
+// vocabulary in `vocab_file`, and maps each label to its 0-based line index in `labels_file`.
+// Samples that do not fill a whole BATCH_SIZE batch are dropped and previously loaded data is
+// released first. Returns the number of samples loaded, or -1 on error.
+int SetBertInputs(const std::string &train_file, const std::string &vocab_file, const std::string &labels_file) {
+  if (train_file.empty()) {
+    std::cout << "files empty";
+    return -1;
+  }
+  FreeBertInput();
+  std::vector<std::string> train_data;
+  ReadTxt(train_file, &train_data);
+  std::vector<std::string> labels;
+  ReadTxt(labels_file, &labels);
+  std::map<std::string, int> labels_map;
+  for (size_t i = 0; i < labels.size(); i++) {
+    labels_map[labels[i]] = static_cast<int>(i);
+  }
+  int total_batch_num = static_cast<int>(train_data.size()) / BATCH_SIZE;
+  if (total_batch_num == 0) {
+    std::cout << "train data size less than one batch,not support now";
+    return -1;
+  }
+  int train_size = total_batch_num * BATCH_SIZE;
+  input_ids = new (std::nothrow) int[train_size * MAX_SEQ_LENGTH];
+  input_mask = new (std::nothrow) int[train_size * MAX_SEQ_LENGTH];
+  token_type_ids = new (std::nothrow) int[train_size * MAX_SEQ_LENGTH];
+  lable_ids = new (std::nothrow) int[train_size];
+  if (input_ids == nullptr || input_mask == nullptr || token_type_ids == nullptr || lable_ids == nullptr) {
+    std::cout << "malloc input buffers failed" << std::endl;
+    FreeBertInput();
+    return -1;
+  }
+
+  CustomizedTokenizer customized_tokenizer;
+  bool do_lower_case = true;
+  customized_tokenizer.init(vocab_file, do_lower_case);
+
+  int s_input_ids[MAX_SEQ_LENGTH];
+  int s_attention_mask[MAX_SEQ_LENGTH];
+  int s_token_type_ids[MAX_SEQ_LENGTH];
+  int seq_length;
+  for (int i = 0; i < train_size; i++) {
+    // Split the line on tabs; strtok writes into this local copy of the line.
+    std::vector<std::string> dataset_tuple;
+    char *token = strtok(train_data[i].data(), "\t");
+    while (token != nullptr) {
+      dataset_tuple.push_back(token);
+      token = strtok(nullptr, "\t");
+    }
+    if (dataset_tuple.size() != 2) {
+      std::cout << "train data error,must 2 word" << std::endl;
+      if (dataset_tuple.size() < 2) {
+        // No sentence column: dataset_tuple[1] below would be out of bounds.
+        FreeBertInput();
+        return -1;
+      }
+    }
+    customized_tokenizer.tokenize(dataset_tuple[1], s_input_ids, s_attention_mask, s_token_type_ids, seq_length);
+    // memcpy sizes are in bytes: copy MAX_SEQ_LENGTH ints per array.
+    std::memcpy(input_ids + i * MAX_SEQ_LENGTH, s_input_ids, MAX_SEQ_LENGTH * sizeof(int));
+    std::memcpy(input_mask + i * MAX_SEQ_LENGTH, s_attention_mask, MAX_SEQ_LENGTH * sizeof(int));
+    std::memcpy(token_type_ids + i * MAX_SEQ_LENGTH, s_token_type_ids, MAX_SEQ_LENGTH * sizeof(int));
+    auto label = labels_map.find(dataset_tuple[0]);
+    if (label == labels_map.end()) {
+      std::cout << "unknown label " << dataset_tuple[0] << ", using class 0" << std::endl;
+      lable_ids[i] = 0;
+    } else {
+      lable_ids[i] = label->second;
+    }
+  }
+  total_size = train_size;
+  return train_size;
+}
+
+// Releases the buffers allocated by SetBertInputs(); safe to call when nothing is loaded.
+void FreeBertInput() {
+  delete[] input_ids;
+  delete[] input_mask;
+  delete[] token_type_ids;
+  delete[] lable_ids;
+  input_ids = nullptr;
+  input_mask = nullptr;
+  token_type_ids = nullptr;
+  lable_ids = nullptr;
+  total_size = 0;
+  batch_num = 0;
+}
diff --git a/mindspore/lite/flclient/src/main/native/bert_train.h b/mindspore/lite/flclient/src/main/native/bert_train.h
new file mode 100644
index 0000000..f222be2
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/bert_train.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MSLITE_FL_BERT_TRAIN_H
+#define MSLITE_FL_BERT_TRAIN_H
+
+#include <string>
+#include "include/train/train_session.h"
+
+// Loads "<label>\t<sentence>" lines from train_file, tokenizes each sentence with vocab_file and maps
+// each label to its 0-based index among the non-empty lines of labels_file. Only whole batches are
+// kept. Returns the number of samples loaded, or -1 on error.
+int SetBertInputs(const std::string &train_file, const std::string &vocab_file, const std::string &labels_file);
+
+// Releases the sample buffers allocated by SetBertInputs().
+void FreeBertInput();
+
+// Trains for `epoches` epochs over the loaded samples, then saves the model to ms_file.
+// Returns mindspore::lite::RET_OK on success, RET_ERROR otherwise.
+int TrainBert(mindspore::session::TrainSession *session, const std::string &ms_file, int batch_size, int epoches);
+
+// Returns the model's accuracy on the first loaded batch.
+float InferBert(mindspore::session::TrainSession *session);
+
+#endif  // MSLITE_FL_BERT_TRAIN_H
diff --git a/mindspore/lite/flclient/src/main/native/data_prepare.cpp b/mindspore/lite/flclient/src/main/native/data_prepare.cpp
deleted file mode 100644
index c8a1b5e..0000000
--- a/mindspore/lite/flclient/src/main/native/data_prepare.cpp
+++ /dev/null
@@ -1,517 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-#include <fstream>
-#include <iostream>
-#include <cstdio>
-#include <vector>
-#include <map>
-#include <ctime>
-//#include <cstring>
-//#include <algorithm>
-//#include <cstdlib>
-using namespace std;
-#define MAX_SEQ_LENGTH 256
-
-class CustomizedTokenizer
-{
- public:
-  CustomizedTokenizer();
-  ~CustomizedTokenizer();
-
-  void init(const string &vocab_file, bool do_lower_case);
-  void tokenize(const string &text, string output_tokens[MAX_SEQ_LENGTH], int &seq_length);
-  void tokenize(const string &text, int input_ids[MAX_SEQ_LENGTH],
-                int attention_mask[MAX_SEQ_LENGTH], int token_type_ids[MAX_SEQ_LENGTH], int &seq_length);
-
-//private:
-  string _text;
-  string _tokens[MAX_SEQ_LENGTH];
-  int _tokens_length = 0;
-  int _text_length = 0;
-
-  map<string, int> _vocab;
-  bool _do_lower_case = true;
-  int _max_input_chars_per_word = 100;
-
-  static void _lower_token(const string &token, string &new_token);
-  void _split_text();
-  void _clean_text();
-  void _fixed_matching(int &pos, string &token);
-  void _load_vocab(const string &vocab_file);
-};
-
-CustomizedTokenizer::CustomizedTokenizer() = default;
-
-CustomizedTokenizer::~CustomizedTokenizer() {
-  _vocab.clear();
-}
-
-void CustomizedTokenizer::init(const string &vocab_file, bool do_lower_case) {
-  _do_lower_case = do_lower_case;
-  _load_vocab(vocab_file);
-}
-
-void CustomizedTokenizer::tokenize(const string &text, string output_tokens[MAX_SEQ_LENGTH], int &seq_length) {
-//  clock_t startTime;
-//  double time_cost = 0.0;
-  _text = text;
-  _split_text();
-
-  int output_tokens_pos = 0;
-
-//  startTime = clock();
-  for (int i = 0; i < _tokens_length; ++i) {
-    int length = _tokens[i].length();
-    if (length > _max_input_chars_per_word) {
-      output_tokens[output_tokens_pos] = "[UNK]";
-      output_tokens_pos++;
-      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
-        _tokens_length = 0;
-        output_tokens[MAX_SEQ_LENGTH - 1] = "[SEP]";
-        return;
-      }
-      continue;
-    }
-
-    bool is_bad = false;
-    int start = 0;
-    vector<string> sub_tokens;
-    while (start < length) {
-      int end = length;
-      string cur_substr;
-      while (start < end) {
-        string substr = _tokens[i].substr(start, end - start);
-        if (start > 0) {
-          substr.insert(0,"##");
-        }
-        if (_vocab.find(substr) != _vocab.end()) {
-          cur_substr = substr;
-          break;
-        }
-        end--;
-      }
-      if (cur_substr.empty()) {
-        is_bad = true;
-        break;
-      }
-      sub_tokens.emplace_back(cur_substr);
-      start = end;
-    }
-    if (is_bad) {
-      output_tokens[output_tokens_pos] = "[UNK]";
-      output_tokens_pos++;
-      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
-        _tokens_length = 0;
-        output_tokens[MAX_SEQ_LENGTH - 1] = "[SEP]";
-        return;
-      }
-    } else {
-      for (const string& sub_token: sub_tokens) {
-        output_tokens[output_tokens_pos] = sub_token;
-        output_tokens_pos++;
-        if (MAX_SEQ_LENGTH <= output_tokens_pos) {
-          _tokens_length = 0;
-          output_tokens[MAX_SEQ_LENGTH - 1] = "[SEP]";
-          return;
-        }
-      }
-    }
-  }
-
-  output_tokens[output_tokens_pos] = "[SEP]";
-  for (int i = output_tokens_pos + 1; i < MAX_SEQ_LENGTH; ++i) {
-    output_tokens[i] = "[PAD]";
-  }
-//  time_cost += (double)(clock() - startTime) / CLOCKS_PER_SEC;
-  _tokens_length = 0;
-
-//  cout << "Local run time is: " << time_cost << endl;
-}
-
-void CustomizedTokenizer::tokenize(const string &text, int input_ids[MAX_SEQ_LENGTH],
-                                   int attention_mask[MAX_SEQ_LENGTH], int token_type_ids[MAX_SEQ_LENGTH],
-                                   int &seq_length) {
-//  clock_t startTime;
-//  double time_cost = 0.0;
-  _text = text;
-  _split_text();
-
-  int output_tokens_pos = 0;
-
-//  startTime = clock();
-  for (int i = 0; i < _tokens_length; ++i) {
-    int length = _tokens[i].length();
-    if (length > _max_input_chars_per_word) {
-      input_ids[output_tokens_pos] = _vocab["[UNK]"];
-      attention_mask[output_tokens_pos] = 1;
-      token_type_ids[output_tokens_pos] = 0;
-      output_tokens_pos++;
-      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
-        _tokens_length = 0;
-        input_ids[MAX_SEQ_LENGTH - 1] = _vocab["[SEP]"];
-        attention_mask[MAX_SEQ_LENGTH - 1] = 1;
-        token_type_ids[MAX_SEQ_LENGTH - 1] = 0;
-        return;
-      }
-      continue;
-    }
-
-    bool is_bad = false;
-    int start = 0;
-    vector<string> sub_tokens;
-    while (start < length) {
-      int end = length;
-      string cur_substr;
-      while (start < end) {
-        string substr = _tokens[i].substr(start, end - start);
-        if (start > 0) {
-          substr.insert(0,"##");
-        }
-        if (_vocab.find(substr) != _vocab.end()) {
-          cur_substr = substr;
-          break;
-        }
-        end--;
-      }
-      if (cur_substr.empty()) {
-        is_bad = true;
-        break;
-      }
-      sub_tokens.emplace_back(cur_substr);
-      start = end;
-    }
-    if (is_bad) {
-      input_ids[output_tokens_pos] = _vocab["[UNK]"];
-      attention_mask[output_tokens_pos] = 1;
-      token_type_ids[output_tokens_pos] = 0;
-      output_tokens_pos++;
-      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
-        _tokens_length = 0;
-        input_ids[MAX_SEQ_LENGTH - 1] = _vocab["[SEP]"];
-        attention_mask[MAX_SEQ_LENGTH - 1] = 1;
-        token_type_ids[MAX_SEQ_LENGTH - 1] = 0;
-        return;
-      }
-    } else {
-      for (const string& sub_token: sub_tokens) {
-        input_ids[output_tokens_pos] = _vocab[sub_token];
-        attention_mask[output_tokens_pos] = 1;
-        token_type_ids[output_tokens_pos] = 0;
-        output_tokens_pos++;
-        if (MAX_SEQ_LENGTH <= output_tokens_pos) {
-          _tokens_length = 0;
-          input_ids[MAX_SEQ_LENGTH - 1] = _vocab["[SEP]"];
-          attention_mask[MAX_SEQ_LENGTH - 1] = 1;
-          token_type_ids[MAX_SEQ_LENGTH - 1] = 0;
-          return;
-        }
-      }
-    }
-  }
-//  time_cost += (double)(clock() - startTime) / CLOCKS_PER_SEC;
-  input_ids[output_tokens_pos] = _vocab["[SEP]"];
-  attention_mask[output_tokens_pos] = 1;
-  token_type_ids[output_tokens_pos] = 0;
-  for (int i = output_tokens_pos + 1; i < MAX_SEQ_LENGTH; ++i) {
-    input_ids[i] = 0;
-    attention_mask[i] = 0;
-    token_type_ids[i] = 0;
-  }
-  _tokens_length = 0;
-
-//  cout << "Local run time is: " << time_cost << endl;
-}
-
-void CustomizedTokenizer::_lower_token(const string &token, string &new_token) {
-  int length = token.length();
-  if (token[0] == '[') {
-    if (length == 5) {
-      if (token[4] == ']') {
-        if (token[1] == 'U' && token[2] == 'N' && token[3] == 'K') {
-          new_token = token;
-          return;
-        }
-        if (token[1] == 'S' && token[2] == 'E' && token[3] == 'P') {
-          new_token = token;
-          return;
-        }
-        if (token[1] == 'P' && token[2] == 'A' && token[3] == 'D') {
-          new_token = token;
-          return;
-        }
-        if (token[1] == 'C' && token[2] == 'L' && token[3] == 'S') {
-          new_token = token;
-          return;
-        }
-      }
-    }
-    if (length == 6) {
-      if (token[1] == 'M' && token[2] == 'A' && token[3] == 'S' && token[4] == 'K' && token[5] == ']') {
-        new_token = token;
-        return;
-      }
-    }
-  }
-
-  for (char ch: token) {
-    if (ch <= 90 && ch >= 65) {
-      new_token += char(ch + 32);
-    } else {
-      new_token += ch;
-    }
-  }
-}
-
-void CustomizedTokenizer::_clean_text() {
-  int pos;
-  pos = _text.find("\u00A0");
-  while (pos > 0) {
-    _text = _text.replace(pos, 2, " ");
-    pos = _text.find("\u00A0");
-  }
-  pos = _text.find("\u2800");
-  while (pos > 0) {
-    _text = _text.replace(pos, 3, " ");
-    pos = _text.find("\u2800");
-  }
-  pos = _text.find("\u3000");
-  while (pos > 0) {
-    _text = _text.replace(pos, 3, " ");
-    pos = _text.find("\u3000");
-  }
-  pos = _text.find("\ufeff");
-  while (pos > 0) {
-    _text = _text.replace(pos, 3, "");
-    pos = _text.find("\ufeff");
-  }
-  pos = _text.find("\ue312");
-  while (pos > 0) {
-    _text = _text.replace(pos, 3, "");
-    pos = _text.find("\ue312");
-  }
-}
-
-void CustomizedTokenizer::_load_vocab(const string &vocab_file) {
-  int index = 0;
-  fstream fin;
-  fin.open(vocab_file, ios::in);
-  if (fin.is_open()) {
-    string basicString;
-    while (!fin.eof()) {
-      getline(fin, basicString, '\n');
-      if (!basicString.empty()) {
-        _vocab[basicString] = index;
-      }
-      index++;
-    }
-    fin.close();
-  }
-}
-
-void CustomizedTokenizer::_fixed_matching(int &pos, string &token) {
-  token += _text[pos];
-  int rest_length = _text_length - pos;
-  if (rest_length < 2) {
-    pos++;
-    return;
-  }
-  if (_text[pos] == char(-16)) {
-    if (rest_length < 4) {
-      pos++;
-      return;
-    }
-    if (_text[pos+1] < char(0) && _text[pos+2] < char(0) && _text[pos+3] < char(0)) {
-      token += _text[pos+1];
-      token += _text[pos+2];
-      token += _text[pos+3];
-      pos += 4;
-      return;
-    }
-    pos++;
-    return;
-  }
-  if (_text[pos] >= char(-62) && _text[pos] <= char(-37)) {
-    if (_text[pos+1] < char(0)) {
-      token += _text[pos+1];
-      pos += 2;
-      return;
-    }
-    pos++;
-    return;
-  }
-  if (rest_length < 3) {
-    pos++;
-    return;
-  }
-  if (_text[pos+1] < char(0) && _text[pos+2] < char(0)) {
-    token += _text[pos+1];
-    token += _text[pos+2];
-    pos += 3;
-    return;
-  }
-  pos++;
-}
-
-void CustomizedTokenizer::_split_text() {
-  // Performs invalid character removal and whitespace cleanup on text.
-  _clean_text();
-  _text_length = _text.length();
-  int pos = 0;
-  string token;
-  _tokens[_tokens_length++] = "[CLS]";
-  while (pos < _text_length) {
-    if (_tokens_length == MAX_SEQ_LENGTH) {
-      break;
-    }
-    if (_text[pos] == ' ') {
-      if (!token.empty()) {
-        if (token[0] >= char(0) && _do_lower_case) {
-          string new_token;
-          _lower_token(token, new_token);
-          _tokens[_tokens_length++] = new_token;
-        } else {
-          _tokens[_tokens_length++] = token;
-        }
-        token.clear();
-      }
-      pos++;
-      continue;
-    }
-    // We treat all non-letter/number ASCII as punctuation.
-    // Characters such as "^", "$", and "`" are not in the Unicode
-    // Punctuation class but we treat them as punctuation anyways, for
-    // consistency.
-    if ((_text[pos] >= char(33) && _text[pos] <= char(47)) ||
-      (_text[pos] >= char(58) && _text[pos] <= char(64)) ||
-      (_text[pos] >= char(91) && _text[pos] <= char(96)) ||
-      (_text[pos] >= char(123) && _text[pos] <= char(126))
-      ) {
-      if (!token.empty()) {
-        if (token[0] >= char(0) && _do_lower_case) {
-          string new_token;
-          _lower_token(token, new_token);
-          _tokens[_tokens_length++] = new_token;
-        } else {
-          _tokens[_tokens_length++] = token;
-        }
-        token.clear();
-      }
-      if (_text[pos] == char(91)) {
-        if (pos < _text_length - 4) {
-          if (_text[pos+1] == 'S' && _text[pos+2] == 'E' && _text[pos+3] == 'P' && _text[pos+4] == ']') {
-            _tokens[_tokens_length++] = "[SEP]";
-            pos += 5;
-            continue;
-          }
-          if (_text[pos+1] == 'U' && _text[pos+2] == 'N' && _text[pos+3] == 'K' && _text[pos+4] == ']') {
-            _tokens[_tokens_length++] = "[UNK]";
-            pos += 5;
-            continue;
-          }
-          if (_text[pos+1] == 'P' && _text[pos+2] == 'A' && _text[pos+3] == 'D' && _text[pos+4] == ']') {
-            _tokens[_tokens_length++] = "[PAD]";
-            pos += 5;
-            continue;
-          }
-          if (_text[pos+1] == 'C' && _text[pos+2] == 'L' && _text[pos+3] == 'S' && _text[pos+4] == ']') {
-            _tokens[_tokens_length++] = "[CLS]";
-            pos += 5;
-            continue;
-          }
-        }
-        if (pos < _text_length - 5) {
-          if (_text[pos+1] == 'M' && _text[pos+2] == 'A' && _text[pos+3] == 'S' && _text[pos+4] == 'K' &&
-            _text[pos+5] == ']') {
-            _tokens[_tokens_length++] = "[MASK]";
-            pos += 6;
-            continue;
-          }
-        }
-      }
-      string temp = {_text[pos]};
-      _tokens[_tokens_length++] = temp;
-      pos++;
-      continue;
-    }
-    if (_text[pos] < char(0)) {
-      if (!token.empty()) {
-        if (token[0] >= char(0) && _do_lower_case) {
-          string new_token;
-          _lower_token(token, new_token);
-          _tokens[_tokens_length++] = new_token;
-        } else {
-          _tokens[_tokens_length++] = token;
-        }
-        token.clear();
-      }
-      _fixed_matching(pos, token);
-      _tokens[_tokens_length++] = token;
-      token.clear();
-    } else {
-      token += _text[pos];
-      pos++;
-    }
-  }
-  if (!token.empty()) {
-    if (token[0] >= char(0) && _do_lower_case) {
-      string new_token;
-      _lower_token(token, new_token);
-      _tokens[_tokens_length++] = new_token;
-    } else {
-      _tokens[_tokens_length++] = token;
-    }
-  }
-}
-