<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <jni.h>
#include <cstring>
#include "include/errorcode.h"
#include "include/train/train_session.h"
#include "util.h"
#include "lenet_train.h"
#include "bert_train.h"
#include "src/common/log_adapter.h"

// Process-wide JNI state shared by the native entry points in this file.
// FlatBufferBuilder handed to getSeralizeFeaturesMap and consumed by CreateFeatureMap.
static jobject fbb = nullptr;
// Method id of FlatBufferBuilder.createString(CharSequence), looked up in getSeralizeFeaturesMap.
static jmethodID create_string_char = nullptr;
// Global reference to the HashMap built and reused by getFeaturesMap.
static jobject jmap = nullptr;
// Model file path recorded by createSession and read by the other entry points.
static jstring model_path = nullptr;

/**
 * Converts a Java string into a newly allocated, NUL-terminated C string.
 *
 * The bytes come from String.getBytes("GB2312"), so non-ASCII characters are GB2312-encoded,
 * not UTF-8.
 *
 * @param env  JNI environment of the calling thread.
 * @param jstr Java string to convert.
 * @return A buffer allocated with new[] that the caller owns (release with delete[]). An empty
 *         Java string yields an empty C string; nullptr is returned only when jstr is null or
 *         the conversion failed (any Java exception is left pending).
 */
char *JstringToChar(JNIEnv *env, jstring jstr) {
  if (jstr == nullptr) {
    return nullptr;
  }
  jclass clsstring = env->FindClass("java/lang/String");
  jstring strencode = env->NewStringUTF("GB2312");
  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
  auto barr = static_cast<jbyteArray>(env->CallObjectMethod(jstr, mid, strencode));
  // Release lookup references right away so repeated calls do not fill the local ref table.
  env->DeleteLocalRef(strencode);
  env->DeleteLocalRef(clsstring);
  if (barr == nullptr) {
    return nullptr;
  }
  jsize alen = env->GetArrayLength(barr);
  jbyte *ba = env->GetByteArrayElements(barr, nullptr);
  if (ba == nullptr) {
    env->DeleteLocalRef(barr);
    return nullptr;
  }
  // Always allocate (even for an empty string) so callers can build std::string from the result.
  char *rtn = new char[alen + 1];
  memcpy(rtn, ba, alen);
  rtn[alen] = '\0';
  // The bytes were only read, so JNI_ABORT skips copying them back to the Java array.
  env->ReleaseByteArrayElements(barr, ba, JNI_ABORT);
  env->DeleteLocalRef(barr);
  return rtn;
}

/**
 * Serializes one feature (name + float data) into the FlatBufferBuilder held in `fbb` and
 * returns the offset of the resulting mindspore.schema.FeatureMap table.
 *
 * Requires `fbb` and `create_string_char` to be set (getSeralizeFeaturesMap does this before
 * calling). Every local reference created here is deleted before returning, so the function
 * can be called once per feature in a loop without exhausting the JNI local reference table.
 *
 * @param name Feature name, written with FlatBufferBuilder.createString.
 * @param data Pointer to `size` floats.
 * @param size Number of floats in `data`.
 * @return The offset returned by FeatureMap.createFeatureMap.
 */
extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, size_t size) {
  jstring name1 = env->NewStringUTF(name);
  jint name_offset = env->CallIntMethod(fbb, create_string_char, name1);
  env->DeleteLocalRef(name1);
  // 1. copy the data into a Java float[] of the same length
  auto length = static_cast<jsize>(size);
  jfloatArray ret = env->NewFloatArray(length);
  env->SetFloatArrayRegion(ret, 0, length, data);
  // 2. get methodid createDataVector
  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
  jmethodID createDataVector =
    env->GetStaticMethodID(fm_cls, "createDataVector", "(Lcom/google/flatbuffers/FlatBufferBuilder;[F)I");
  // 3. calc data offset
  jint data_offset = env->CallStaticIntMethod(fm_cls, createDataVector, fbb, ret);
  env->DeleteLocalRef(ret);
  jmethodID createFeatureMap =
    env->GetStaticMethodID(fm_cls, "createFeatureMap", "(Lcom/google/flatbuffers/FlatBufferBuilder;II)I");
  jint fm_offset = env->CallStaticIntMethod(fm_cls, createFeatureMap, fbb, name_offset, data_offset);
  env->DeleteLocalRef(fm_cls);
  return fm_offset;
}

/**
 * JNI entry for NativeTrain.train: trains the session for `epoches` epochs.
 *
 * The LeNet pipeline is used when the model path recorded by createSession contains "lenet";
 * otherwise the BERT pipeline is used. The model path is also forwarded as the path argument
 * of TrainLenet/TrainBert (TrainLenet saves the trained model there).
 * `early_stop_type` is currently unused.
 *
 * @return The status of TrainLenet/TrainBert, or mindspore::lite::RET_ERROR when no model
 *         path is available.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_train(JNIEnv *env, jobject thiz,
                                                                             jlong session_ptr,
                                                                             jint batch_size, jint epoches,
                                                                             jint early_stop_type) {
  // Convert once; JstringToChar hands ownership of the buffer to us.
  char *path = JstringToChar(env, model_path);
  if (path == nullptr) {
    MS_LOG(ERROR) << "model path is not set";
    return mindspore::lite::RET_ERROR;
  }
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  const std::string model_name(path);
  jint ret;
  if (model_name.find("lenet") != std::string::npos) {
    ret = TrainLenet(session, path, batch_size, epoches);
  } else {
    ret = TrainBert(session, path, batch_size, epoches);
  }
  delete[] path;
  return ret;
}

extern "C" jlong JNICALL Java_com_huawei_flclient_NativeTrain_createSession(JNIEnv *env, jclass, jstring ms_file,
                                                                            jlong) {
  model_path = ms_file;
  return reinterpret_cast<jlong>(CreateSession(JstringToChar(env, ms_file)));
}

/**
 * JNI entry for NativeTrain.getFeaturesMap: copies every trainable feature of the session into
 * a java.util.HashMap<String, float[]> keyed by feature name.
 *
 * Keys are decoded from the native names with the GB2312 charset. The map is created on the
 * first call and cached in the global reference `jmap`. Later calls reuse an existing float[]
 * when it has exactly the feature's length; otherwise a new array is stored for that key.
 *
 * @return The cached map, or NULL if the features could not be read or HashMap is unavailable.
 */
extern "C" JNIEXPORT jobject JNICALL Java_com_huawei_flclient_NativeTrain_getFeaturesMap(JNIEnv *env, jclass,
                                                                                         jlong session_ptr) {
  TrainFeatureParam **train_features = nullptr;
  int feature_size = 0;
  auto status =
    GetFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), &train_features, &feature_size);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "get features failed:";
    return NULL;
  }
  // Frees the per-feature buffers returned by GetFeatures.
  auto release_features = [&]() {
    for (int i = 0; i < feature_size; i++) {
      // NOTE(review): scalar delete is kept from the original; if GetFeatures allocates
      // `name` with new[], this must be delete[].
      delete train_features[i]->name;
      free(train_features[i]->data);
      delete train_features[i];
    }
  };

  jclass jmapClass = env->FindClass("java/util/HashMap");
  if (jmapClass == NULL) {
    release_features();
    return NULL;
  }
  jclass strClass = env->FindClass("java/lang/String");
  jmethodID ctorID = env->GetMethodID(strClass, "<init>", "([BLjava/lang/String;)V");
  jstring encoding = env->NewStringUTF("GB2312");
  jmethodID mid = env->GetMethodID(jmapClass, "<init>", "()V");
  jmethodID putMethod = env->GetMethodID(jmapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
  jmethodID getMethod = env->GetMethodID(jmapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
  bool map_exist = true;
  if (jmap == nullptr) {
    // HashMap() takes no arguments, matching the "()V" signature looked up above.
    jobject local_map = env->NewObject(jmapClass, mid);
    jmap = env->NewGlobalRef(local_map);
    env->DeleteLocalRef(local_map);
    map_exist = false;
  }
  for (int i = 0; i < feature_size; i++) {
    const TrainFeatureParam *feature = train_features[i];
    auto name_len = static_cast<jsize>(strlen(feature->name));
    jbyteArray bytes = env->NewByteArray(name_len);
    env->SetByteArrayRegion(bytes, 0, name_len, reinterpret_cast<const jbyte *>(feature->name));
    auto key = static_cast<jstring>(env->NewObject(strClass, ctorID, bytes, encoding));
    env->DeleteLocalRef(bytes);

    jfloatArray feature_data = nullptr;
    if (map_exist) {
      feature_data = static_cast<jfloatArray>(env->CallObjectMethod(jmap, getMethod, key));
    }
    // Writing elenums floats into a missing or differently sized array would dereference
    // null or overflow it, so fall back to a freshly allocated array of the right size.
    if (feature_data == nullptr || env->GetArrayLength(feature_data) != feature->elenums) {
      if (map_exist) {
        MS_LOG(WARNING) << "no reusable array for feature " << feature->name << ", allocating a new one";
      }
      if (feature_data != nullptr) {
        env->DeleteLocalRef(feature_data);
      }
      feature_data = env->NewFloatArray(feature->elenums);
    }
    env->SetFloatArrayRegion(feature_data, 0, feature->elenums, reinterpret_cast<const jfloat *>(feature->data));
    jobject previous = env->CallObjectMethod(jmap, putMethod, key, feature_data);
    // Delete every per-iteration reference so many features cannot overflow the local ref table.
    if (previous != nullptr) {
      env->DeleteLocalRef(previous);
    }
    env->DeleteLocalRef(feature_data);
    env->DeleteLocalRef(key);
  }
  env->DeleteLocalRef(encoding);
  env->DeleteLocalRef(strClass);
  env->DeleteLocalRef(jmapClass);
  // NOTE(review): the train_features array itself is not freed (as before); its allocator
  // is not visible here.
  release_features();
  return jmap;
}

/**
 * JNI entry for NativeTrain.getSeralizeFeaturesMap: serializes every trainable feature of the
 * session into `builder` (a com.google.flatbuffers.FlatBufferBuilder) via CreateFeatureMap.
 *
 * @return An int[] of FeatureMap offsets, one per feature in GetFeatures order, or an empty
 *         array if the features could not be read.
 */
extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_NativeTrain_getSeralizeFeaturesMap(JNIEnv *env,
                                                                                                   jobject thiz,
                                                                                                   jlong session_ptr,
                                                                                                   jobject builder) {
  // CreateFeatureMap reads the builder and its createString method through these globals.
  fbb = builder;
  jclass fb_clazz = env->GetObjectClass(builder);
  create_string_char = env->GetMethodID(fb_clazz, "createString", "(Ljava/lang/CharSequence;)I");
  env->DeleteLocalRef(fb_clazz);
  TrainFeatureParam **train_features = nullptr;
  int feature_size = 0;
  auto status =
    GetFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), &train_features, &feature_size);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "get features failed:";
    fbb = nullptr;
    return env->NewIntArray(0);
  }
  // Collect the offsets natively and copy them in one call, instead of keeping the int[]
  // elements checked out while CreateFeatureMap calls back into Java.
  std::vector<jint> offsets;
  offsets.reserve(feature_size > 0 ? static_cast<size_t>(feature_size) : 0);
  for (int i = 0; i < feature_size; i++) {
    offsets.push_back(CreateFeatureMap(env, train_features[i]->name, reinterpret_cast<float *>(train_features[i]->data),
                                       train_features[i]->elenums));
    MS_LOG(INFO) << "upload feature:"
                 << ", name:" << train_features[i]->name << ", elenums:" << train_features[i]->elenums;
  }
  const auto count = static_cast<jsize>(offsets.size());
  jintArray ret = env->NewIntArray(count);
  if (ret != nullptr && count > 0) {
    env->SetIntArrayRegion(ret, 0, count, offsets.data());
  }
  for (int i = 0; i < feature_size; i++) {
    // NOTE(review): scalar delete is kept from the original; if GetFeatures allocates
    // `name` with new[], this must be delete[].
    delete train_features[i]->name;
    free(train_features[i]->data);
    delete train_features[i];
  }
  // `builder` is a local reference that becomes invalid once this method returns.
  fbb = nullptr;
  return ret;
}

/**
 * JNI entry for NativeTrain.updateFeatures: converts `features` into a malloc-allocated
 * TrainFeatureParam array and passes it to UpdateFeatures together with the model path
 * recorded by createSession. `features` is a java.util.List of mindspore.schema.FeatureMap.
 *
 * Each entry receives:
 *   - a new[]-allocated NUL-terminated name (weightFullname() encoded as GB2312, or nullptr
 *     for an empty name);
 *   - a malloc-family float buffer filled element by element from data(j);
 *   - the element count;
 *   - the type kNumberTypeFloat32.
 *
 * @return The status from UpdateFeatures, or mindspore::lite::RET_ERROR if a feature could
 *         not be converted or no model path is set.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_updateFeatures(JNIEnv *env, jclass,
                                                                                      jlong session_ptr,
                                                                                      jobject features) {
  jclass arr_cls = env->GetObjectClass(features);
  jmethodID size_method = env->GetMethodID(arr_cls, "size", "()I");
  jmethodID get_method = env->GetMethodID(arr_cls, "get", "(I)Ljava/lang/Object;");

  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
  jmethodID weight_name_method = env->GetMethodID(fm_cls, "weightFullname", "()Ljava/lang/String;");
  jmethodID data_length_method = env->GetMethodID(fm_cls, "dataLength", "()I");
  jmethodID data_method = env->GetMethodID(fm_cls, "data", "(I)F");
  jclass clsstring = env->FindClass("java/lang/String");
  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
  // The charset name is the same for every feature, so one reference is created up front.
  jstring strencode = env->NewStringUTF("GB2312");
  int size = env->CallIntMethod(features, size_method);

  // Per-feature references are deleted inside the loop; these lookup references are
  // released once at the end, so a long list cannot exhaust the JNI local ref table.
  auto release_lookup_refs = [&]() {
    env->DeleteLocalRef(strencode);
    env->DeleteLocalRef(clsstring);
    env->DeleteLocalRef(fm_cls);
    env->DeleteLocalRef(arr_cls);
  };

  // transform FeatureMap to TrainFeatureParam
  auto *features_param = static_cast<TrainFeatureParam *>(
    malloc(static_cast<size_t>(size > 0 ? size : 0) * sizeof(TrainFeatureParam)));
  if (size > 0 && features_param == nullptr) {
    MS_LOG(ERROR) << "malloc features param failed";
    release_lookup_refs();
    return mindspore::lite::RET_ERROR;
  }
  // Frees the first `count` converted entries and the array; used only on error paths,
  // where UpdateFeatures never receives them.
  auto release_params = [&](int count) {
    for (int k = 0; k < count; ++k) {
      delete[] features_param[k].name;
      free(features_param[k].data);
    }
    free(features_param);
  };

  for (int i = 0; i < size; ++i) {
    TrainFeatureParam *param = features_param + i;
    jobject feature = env->CallObjectMethod(features, get_method, i);
    // set feature_param name
    auto weight_full_name = static_cast<jstring>(env->CallObjectMethod(feature, weight_name_method));
    auto barr = static_cast<jbyteArray>(env->CallObjectMethod(weight_full_name, mid, strencode));
    env->DeleteLocalRef(weight_full_name);
    if (barr == nullptr) {
      MS_LOG(ERROR) << "get bytes of feature name failed";
      env->DeleteLocalRef(feature);
      release_params(i);
      release_lookup_refs();
      return mindspore::lite::RET_ERROR;
    }
    jsize alen = env->GetArrayLength(barr);
    jbyte *ba = env->GetByteArrayElements(barr, nullptr);
    if (alen > 0 && ba == nullptr) {
      MS_LOG(ERROR) << "name is nullptr";
      env->DeleteLocalRef(barr);
      env->DeleteLocalRef(feature);
      release_params(i);
      release_lookup_refs();
      return mindspore::lite::RET_ERROR;
    }
    char *name = nullptr;
    if (alen > 0) {
      name = new char[alen + 1];
      memcpy(name, ba, alen);
      name[alen] = 0;
    }
    if (ba != nullptr) {
      // The bytes were only read; JNI_ABORT skips copying them back.
      env->ReleaseByteArrayElements(barr, ba, JNI_ABORT);
    }
    env->DeleteLocalRef(barr);
    param->name = name;

    // set feature_param data (zero-initialized, then filled element by element)
    int data_length = env->CallIntMethod(feature, data_length_method);
    size_t element_count = data_length > 0 ? static_cast<size_t>(data_length) : 0;
    auto *data = static_cast<float *>(calloc(element_count, sizeof(float)));
    if (element_count > 0 && data == nullptr) {
      MS_LOG(ERROR) << "malloc feature data failed";
      delete[] name;
      env->DeleteLocalRef(feature);
      release_params(i);
      release_lookup_refs();
      return mindspore::lite::RET_ERROR;
    }
    for (int j = 0; j < data_length; ++j) {
      data[j] = env->CallFloatMethod(feature, data_method, j);
    }
    env->DeleteLocalRef(feature);
    param->data = data;
    param->elenums = data_length;
    param->type = mindspore::kNumberTypeFloat32;
    // Streaming a null char* is undefined behavior, so log an empty name instead.
    MS_LOG(INFO) << "get feature:" << (param->name != nullptr ? param->name : "") << ",elenums:" << param->elenums;
  }
  release_lookup_refs();

  char *path = JstringToChar(env, model_path);
  if (path == nullptr) {
    MS_LOG(ERROR) << "model path is not set";
    release_params(size);
    return mindspore::lite::RET_ERROR;
  }
  // NOTE(review): features_param and its buffers are handed to UpdateFeatures and not freed
  // here (as before); confirm which side owns them.
  jint ret = UpdateFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), path,
                            features_param, size);
  delete[] path;
  return ret;
}

/**
 * JNI entry for NativeTrain.setInput.
 *
 * `files` is a comma-separated list of input file paths. Empty fields are kept, so "a," gives
 * two paths. Two paths select the LeNet inputs (data, labels); three paths select the BERT
 * inputs.
 *
 * @return The result of SetLenetInputs/SetBertInputs, or -1 for any other number of paths.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_setInput(JNIEnv *env, jobject, jstring files) {
  char *raw = JstringToChar(env, files);
  const std::string input_files = raw != nullptr ? raw : "";
  delete[] raw;

  std::vector<std::string> res;
  std::string::size_type begin = 0;
  while (true) {
    const std::string::size_type comma = input_files.find(',', begin);
    if (comma == std::string::npos) {
      res.push_back(input_files.substr(begin));
      break;
    }
    res.push_back(input_files.substr(begin, comma - begin));
    begin = comma + 1;
  }

  if (res.size() == 2) {
    return SetLenetInputs(res[0], res[1]);
  }
  if (res.size() == 3) {
    // Forward all three paths in the order given.
    return SetBertInputs(res[0], res[1], res[2]);
  }
  MS_LOG(ERROR) << "input files error";
  return -1;
}

/**
 * JNI entry for NativeTrain.infer: returns the value produced by InferLenet (accuracy on
 * batch 0) or InferBert.
 *
 * The LeNet pipeline is chosen when the model path recorded by createSession contains
 * "lenet". A missing path is treated as an empty name, which selects the BERT pipeline.
 */
extern "C" JNIEXPORT jfloat JNICALL Java_com_huawei_flclient_NativeTrain_infer(JNIEnv *env, jclass, jlong session_ptr) {
  char *raw = JstringToChar(env, model_path);
  const std::string model_name = raw != nullptr ? raw : "";
  delete[] raw;
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  if (model_name.find("lenet") != std::string::npos) {
    return InferLenet(session);
  }
  return InferBert(session);
}

/**
 * JNI entry for NativeTrain.getInferLables: returns the inference result of
 * GetLenetInferRes/GetBertInferRes as a Java int[].
 *
 * The LeNet pipeline is chosen when the model path recorded by createSession contains "lenet".
 */
extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_NativeTrain_getInferLables(JNIEnv *env, jclass, jlong session_ptr) {
  char *raw = JstringToChar(env, model_path);
  const std::string model_name = raw != nullptr ? raw : "";
  delete[] raw;
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  std::vector<int> infer_result;
  if (model_name.find("lenet") != std::string::npos) {
    infer_result = GetLenetInferRes(session);
  } else {
    infer_result = GetBertInferRes(session);
  }
  const auto count = static_cast<jsize>(infer_result.size());
  jintArray jArray = env->NewIntArray(count);
  // The staging buffer is a vector, so it is released automatically.
  std::vector<jint> values(infer_result.begin(), infer_result.end());
  if (jArray != nullptr && count > 0) {
    env->SetIntArrayRegion(jArray, 0, count, values.data());
  }
  return jArray;
}

/**
 * JNI entry for NativeTrain.free: releases the cached feature map, deletes the session (when
 * non-zero) and frees the loaded LeNet or BERT inputs.
 *
 * The LeNet inputs are freed when the model path contains "lenet"; otherwise the BERT inputs
 * are freed.
 *
 * @return Always 0.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_free(JNIEnv *env, jclass, jlong session_ptr) {
  if (jmap != nullptr) {
    env->DeleteGlobalRef(jmap);
    jmap = nullptr;
  }
  if (session_ptr != 0) {
    delete reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  }
  char *raw = JstringToChar(env, model_path);
  const std::string model_name = raw != nullptr ? raw : "";
  delete[] raw;
  if (model_name.find("lenet") != std::string::npos) {
    FreeLenetInput();
  } else {
    FreeBertInput();
  }
  return 0;
}


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MSLITE_FL_LENET_TRAIN_H
#define MSLITE_FL_LENET_TRAIN_H

#include "include/train/train_session.h"
#include <string>
#include <vector>  // std::vector is used in GetLenetInferRes's return type.
using mindspore::session::TrainSession;

/// Loads the LeNet input data file and label file into memory for later training/inference.
/// @return The size in bytes of the input data file, or -1 if a file cannot be read.
int SetLenetInputs(const std::string &input_data, const std::string &label_data);

/// Fills the session inputs with batch 0 of the loaded data and returns the accuracy
/// computed by CalculateAccuracy.
float InferLenet(TrainSession *session);

/// Releases the buffers loaded by SetLenetInputs.
void FreeLenetInput();

/// Fills the session inputs with batch 0 of the loaded data and returns GetInferResult.
std::vector<int> GetLenetInferRes(TrainSession *session);

/// Trains for `epoches` epochs over the loaded data and saves the model to `save_path`.
/// @return mindspore::lite::RET_OK on success, mindspore::lite::RET_ERROR otherwise.
int TrainLenet(TrainSession *session, const std::string &save_path, int batch_size, int epoches);
#endif  // MSLITE_FL_LENET_TRAIN_H


In [None]:
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "lenet_train.h"
#include "util.h"
#include <cstring>
#include <fstream>
#include <iostream>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include <climits>

// Raw LeNet inputs loaded by SetLenetInputs: sample data (read as float) and labels (read as int).
static char *fl_lenet_I0 = nullptr;
static char *fl_lenet_I1 = nullptr;
// Byte size of the loaded sample data file.
static int input_size = 0;
// Number of batches per epoch, computed by TrainLenet.
static int batch_num = 0;
// Class count handed to the accuracy/result helpers.
#define LENET_LABEL_CLASS 62
/**
 * Copies batch `batch_idx` of the loaded LeNet data into the session inputs.
 *
 * Input 0 receives batch_size samples of float data taken from fl_lenet_I0. Input 1 is always
 * zeroed. When train_mode is non-zero, input 1 is then filled with one-hot labels built from
 * the int labels in fl_lenet_I1, and the raw label indices are returned (one per sample).
 * Otherwise an empty vector is returned.
 *
 * Assumes SetLenetInputs loaded enough data for the requested batch — TODO confirm that
 * callers guarantee this.
 */
std::vector<int> FillLenetInput(mindspore::session::TrainSession *train_session, int batch_idx, int train_mode = true) {
  std::vector<int> labels_vec;
  auto inputs = train_session->GetInputs();
  int batch_size = inputs[0]->shape()[0];
  int data_size = inputs[0]->ElementsNum() / batch_size;
  int num_classes = inputs[1]->shape()[1];
  auto *input_data = reinterpret_cast<float *>(inputs.at(0)->MutableData());
  auto *labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
  std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
  const float *batch_src = reinterpret_cast<const float *>(fl_lenet_I0) + batch_idx * inputs[0]->ElementsNum();
  const int *label_src = reinterpret_cast<const int *>(fl_lenet_I1);
  for (int i = 0; i < batch_size; i++) {
    std::memcpy(input_data + i * data_size, batch_src + i * data_size, data_size * sizeof(float));
    if (train_mode) {
      int label_idx = label_src[batch_idx * batch_size + i];
      // A label outside [0, num_classes) would write past this sample's one-hot row.
      if (label_idx >= 0 && label_idx < num_classes) {
        labels[i * num_classes + label_idx] = 1.0f;  // Model expects labels in onehot representation
      } else {
        MS_LOG(ERROR) << "label " << label_idx << " out of range [0, " << num_classes << ")";
      }
      labels_vec.push_back(label_idx);
    }
  }
  return labels_vec;
}


// Loads batch 0 of the LeNet data into the session and returns the accuracy computed by
// CalculateAccuracy over LENET_LABEL_CLASS classes; the value is also printed to stdout.
float InferLenet(TrainSession *session) {
  std::vector<int> batch_labels = FillLenetInput(session, 0);
  auto accuracy = CalculateAccuracy(session, batch_labels, LENET_LABEL_CLASS);
  std::cout << "inference acc is:" << accuracy << std::endl;
  return accuracy;
}

// Loads batch 0 of the LeNet sample data into the session without reading the label buffer
// (train_mode off) and returns GetInferResult over LENET_LABEL_CLASS classes.
std::vector<int> GetLenetInferRes(TrainSession *session) {
  const int train_mode = 0;
  static_cast<void>(FillLenetInput(session, 0, train_mode));
  return GetInferResult(session, LENET_LABEL_CLASS);
}

/**
 * Trains `session` for `epoches` epochs over every complete batch of the loaded LeNet data,
 * then saves the model to `save_path`.
 *
 * The number of batches per epoch is the input file size divided by the byte size of one
 * input tensor. The per-epoch mean loss and training accuracy are printed to stdout.
 *
 * @param batch_size Unused; the batch size comes from the session's input shape.
 * @return mindspore::lite::RET_OK on success, mindspore::lite::RET_ERROR if epoches is not
 *         positive or the loaded data does not contain one full batch.
 */
int TrainLenet(TrainSession *session, const std::string &save_path, int batch_size, int epoches) {
  if (epoches <= 0) {
    MS_LOG(ERROR) << "error epoch number, epoches:" << epoches;
    return mindspore::lite::RET_ERROR;
  }
  batch_num = input_size / (session->GetInputs()[0]->ElementsNum() * sizeof(float));
  if (batch_num <= 0) {
    // Without a full batch the per-epoch means below would divide by zero.
    MS_LOG(ERROR) << "input data holds no complete batch, input size:" << input_size;
    return mindspore::lite::RET_ERROR;
  }
  std::cout << "total train epoches :" << epoches << ",batch_num:" << batch_num << std::endl;
  for (int j = 0; j < epoches; ++j) {
    float sum_loss_per_epoch = 0.0f;
    float sum_acc_per_epoch = 0.0f;
    for (int k = 0; k < batch_num; ++k) {
      auto labels = FillLenetInput(session, k);
      session->Train();
      session->RunGraph(nullptr, nullptr);
      sum_loss_per_epoch += GetLoss(session);
      sum_acc_per_epoch += CalculateAccuracy(session, labels, LENET_LABEL_CLASS);
    }
    std::cout << "epoch " << "[" << j << "]" << ",mean Loss " << sum_loss_per_epoch / batch_num << ",train acc "
              << sum_acc_per_epoch / batch_num << std::endl;
  }
  session->SaveToFile(save_path);
  return mindspore::lite::RET_OK;
}

/**
 * Resolves `path` to an absolute path, using realpath on POSIX and _fullpath on Windows.
 *
 * @return The resolved path, or an empty string if `path` is null, is PATH_MAX characters or
 *         longer, or cannot be resolved.
 */
std::string RealPath(const char *path) {
  if (path == nullptr) {
    MS_LOG(ERROR) << "path is nullptr";
    return "";
  }
  if ((strlen(path)) >= PATH_MAX) {
    MS_LOG(ERROR) << "path is too long";
    return "";
  }
  // std::make_unique throws on allocation failure; it never returns nullptr.
  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
#ifdef _WIN32
  // Pass the real buffer size: a larger limit would let _fullpath write past the buffer.
  char *real_path = _fullpath(resolved_path.get(), path, PATH_MAX);
#else
  char *real_path = realpath(path, resolved_path.get());
#endif
  if (real_path == nullptr || strlen(real_path) == 0) {
    MS_LOG(ERROR) << "file path is not valid : " << path;
    return "";
  }
  return std::string(resolved_path.get());
}

char *ReadFile(const char *file, size_t *size) {
  if (file == nullptr) {
    MS_LOG(ERROR) << "file is nullptr";
    return nullptr;
  }
  //  MS_ASSERT(size != nullptr);
  std::string real_path = RealPath(file);
  std::ifstream ifs(real_path);
  if (!ifs.good()) {
    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
    return nullptr;
  }

  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "file: " << real_path << " open failed";
    return nullptr;
  }

  ifs.seekg(0, std::ios::end);
  *size = ifs.tellg();
  std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
  if (buf == nullptr) {
    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
    ifs.close();
    return nullptr;
  }
  ifs.seekg(0, std::ios::beg);
  ifs.read(buf.get(), *size);
  ifs.close();

  return buf.release();
}

/**
 * Loads the LeNet sample data file (`input_data`) and label file (`label_data`) into the
 * module-level buffers read by FillLenetInput and TrainLenet.
 *
 * If either file cannot be read, -1 is returned and the module-level buffers are left as
 * they were.
 * NOTE(review): buffers from an earlier successful call are overwritten without being freed.
 *
 * @return The size in bytes of the sample data file, or -1 on failure.
 */
int SetLenetInputs(const std::string &input_data, const std::string &label_data) {
  size_t input0_size = 0;
  char *data_buf = ReadFile(input_data.c_str(), &input0_size);
  if (data_buf == nullptr) {
    MS_LOG(ERROR) << "ReadFile return nullptr";
    return -1;
  }
  size_t input1_size = 0;
  char *label_buf = ReadFile(label_data.c_str(), &input1_size);
  if (label_buf == nullptr) {
    MS_LOG(ERROR) << "ReadFile return nullptr";
    // Don't leak (or half-install) the data buffer when the labels are unavailable.
    delete[] data_buf;
    return -1;
  }
  fl_lenet_I0 = data_buf;
  fl_lenet_I1 = label_buf;
  input_size = static_cast<int>(input0_size);
  return static_cast<int>(input0_size);
}
void FreeLenetInput() {
  delete fl_lenet_I0;
  delete fl_lenet_I1;
}

In [None]:
# Builds the `fl` shared library and a `test` executable from the same sources.
cmake_minimum_required(VERSION 3.14)

project(FederalLearning)

option(SUPPORT_GPU "if support gpu" off)
set(BUILD_LITE "on")
set(SUPPORT_TRAIN "on")
set(PLATFORM_ARM "on")

# NOTE(review): assumes this directory is six levels below the MindSpore source root — confirm.
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)
# Each version variable is re-assigned to its current value; the values are expected on the
# command line (e.g. -DMS_VERSION_MAJOR=...) and are forwarded as compile definitions below.
set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR})
set(MS_VERSION_MINOR ${MS_VERSION_MINOR})
set(MS_VERSION_REVISION ${MS_VERSION_REVISION})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")

include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/linux)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/dataset)


include_directories(${LITE_DIR}) ## lite include
include_directories(${TOP_DIR}) ## api include
include_directories(${TOP_DIR}/mindspore/core/) ## core include
include_directories(${LITE_DIR}/build) ## flatbuffers

# Sources shared by the `fl` library and the `test` executable.
set(OP_SRC
        lite_train_jni.cpp
        util.cpp
        bert_train.cpp
        lenet_train.cpp
        dataset/CustomizedTokenizer.cc
            )
# NOTE(review): the `log-lib` result variable is not referenced anywhere below.
find_library(log-lib glog)

# NOTE(review): no target_link_libraries() is declared for `fl`.
add_library(fl SHARED ${OP_SRC})

# link_directories() only affects targets created after it, so this applies to `test` but not `fl`.
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib/)

# Installs the `fl` library into lib/ inside the source tree.
install(TARGETS fl LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
# NOTE(review): CMake reserves the target name `test` when testing is enabled (policy CMP0037);
# consider renaming this executable.
add_executable(test test_train.cc ${OP_SRC} )
target_link_libraries(test mindspore-lite  glog)

In [None]:
//
// Created by meng on 3/30/21.
//
#include "CustomizedTokenizer.h"

// Defaulted constructor: no setup beyond member initialization.
CustomizedTokenizer::CustomizedTokenizer() = default;

// Empties the vocabulary map before the tokenizer is destroyed.
CustomizedTokenizer::~CustomizedTokenizer() {
  _vocab.clear();
}

void CustomizedTokenizer::init(const string &vocab_file, bool do_lower_case) {
  _do_lower_case = do_lower_case;
  _load_vocab(vocab_file);
}

// Tokenizes `text` into WordPiece token strings.
//
// Splitting:
//   - _split_text() fills _tokens; the first token is "[CLS]".
//   - Each token is matched greedily against _vocab, longest prefix first. Pieces after the
//     first are looked up with a "##" prefix.
//   - A token longer than _max_input_chars_per_word, or one that cannot be fully covered by
//     vocab pieces, becomes a single "[UNK]".
//
// Output:
//   - output_tokens receives the pieces, then "[SEP]", then "[PAD]" up to MAX_SEQ_LENGTH.
//   - If the pieces fill the array, the last slot is overwritten with "[SEP]" and the rest of
//     the input is dropped.
//   - _tokens_length is reset to 0 on every return.
//
// NOTE(review): seq_length is never written here — confirm callers do not rely on it.
// Candidates are cut on byte boundaries, so one may split a multi-byte UTF-8 character; such
// candidates simply fail the vocab lookup.
void CustomizedTokenizer::tokenize(const string &text, string output_tokens[MAX_SEQ_LENGTH], int &seq_length) {
//  clock_t startTime;
//  double time_cost = 0.0;
  _text = text;
  _split_text();

  int output_tokens_pos = 0;

//  startTime = clock();
  for (int i = 0; i < _tokens_length; ++i) {
    int length = _tokens[i].length();
    // Over-long words are not split; they map straight to [UNK].
    if (length > _max_input_chars_per_word) {
      output_tokens[output_tokens_pos] = "[UNK]";
      output_tokens_pos++;
      // Output full: terminate with [SEP] in the last slot and stop.
      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
        _tokens_length = 0;
        output_tokens[MAX_SEQ_LENGTH - 1] = "[SEP]";
        return;
      }
      continue;
    }

    bool is_bad = false;
    int start = 0;
    vector<string> sub_tokens;
    while (start < length) {
      int end = length;
      string cur_substr;
      // Longest-match-first: shrink the candidate from the right until it is in the vocab.
      while (start < end) {
        string substr = _tokens[i].substr(start, end - start);
        if (start > 0) {
          substr.insert(0,"##");
        }
        if (_vocab.find(substr) != _vocab.end()) {
          cur_substr = substr;
          break;
        }
        end--;
      }
      // No vocab piece starts here, so the whole word becomes [UNK].
      if (cur_substr.empty()) {
        is_bad = true;
        break;
      }
      sub_tokens.emplace_back(cur_substr);
      start = end;
    }
    if (is_bad) {
      output_tokens[output_tokens_pos] = "[UNK]";
      output_tokens_pos++;
      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
        _tokens_length = 0;
        output_tokens[MAX_SEQ_LENGTH - 1] = "[SEP]";
        return;
      }
    } else {
      for (const string& sub_token: sub_tokens) {
        output_tokens[output_tokens_pos] = sub_token;
        output_tokens_pos++;
        if (MAX_SEQ_LENGTH <= output_tokens_pos) {
          _tokens_length = 0;
          output_tokens[MAX_SEQ_LENGTH - 1] = "[SEP]";
          return;
        }
      }
    }
  }

  // Terminate the sequence and pad the remaining slots.
  output_tokens[output_tokens_pos] = "[SEP]";
  for (int i = output_tokens_pos + 1; i < MAX_SEQ_LENGTH; ++i) {
    output_tokens[i] = "[PAD]";
  }
//  time_cost += (double)(clock() - startTime) / CLOCKS_PER_SEC;
  _tokens_length = 0;

//  cout << "Local run time is: " << time_cost << endl;
}

// Tokenizes `text` into vocab ids using the same WordPiece procedure as the string overload.
//
// Per-token output:
//   - input_ids receives the _vocab id of each piece, then the id of "[SEP]". Over-long or
//     unmatched words use the id of "[UNK]".
//   - attention_mask is 1 for those positions.
//   - token_type_ids is 0 everywhere.
//
// Padding:
//   - Remaining positions up to MAX_SEQ_LENGTH get 0 in all three arrays.
//   - If the pieces fill the array, the last position is overwritten with the "[SEP]" id and
//     the rest of the input is dropped.
//   - _tokens_length is reset to 0 on every return.
//
// NOTE(review): seq_length is never written here — confirm callers do not rely on it.
// _vocab["..."] uses operator[], which inserts the key if it is missing from the vocab.
void CustomizedTokenizer::tokenize(const string &text, int input_ids[MAX_SEQ_LENGTH],
                                   int attention_mask[MAX_SEQ_LENGTH], int token_type_ids[MAX_SEQ_LENGTH],
                                   int &seq_length) {
//  clock_t startTime;
//  double time_cost = 0.0;
  _text = text;
  _split_text();

  int output_tokens_pos = 0;

//  startTime = clock();
  for (int i = 0; i < _tokens_length; ++i) {
    int length = _tokens[i].length();
    // Over-long words are not split; they map straight to [UNK].
    if (length > _max_input_chars_per_word) {
      input_ids[output_tokens_pos] = _vocab["[UNK]"];
      attention_mask[output_tokens_pos] = 1;
      token_type_ids[output_tokens_pos] = 0;
      output_tokens_pos++;
      // Output full: terminate with [SEP] in the last slot and stop.
      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
        _tokens_length = 0;
        input_ids[MAX_SEQ_LENGTH - 1] = _vocab["[SEP]"];
        attention_mask[MAX_SEQ_LENGTH - 1] = 1;
        token_type_ids[MAX_SEQ_LENGTH - 1] = 0;
        return;
      }
      continue;
    }

    bool is_bad = false;
    int start = 0;
    vector<string> sub_tokens;
    while (start < length) {
      int end = length;
      string cur_substr;
      // Longest-match-first: shrink the candidate from the right until it is in the vocab.
      while (start < end) {
        string substr = _tokens[i].substr(start, end - start);
        if (start > 0) {
          substr.insert(0,"##");
        }
        if (_vocab.find(substr) != _vocab.end()) {
          cur_substr = substr;
          break;
        }
        end--;
      }
      // No vocab piece starts here, so the whole word becomes [UNK].
      if (cur_substr.empty()) {
        is_bad = true;
        break;
      }
      sub_tokens.emplace_back(cur_substr);
      start = end;
    }
    if (is_bad) {
      input_ids[output_tokens_pos] = _vocab["[UNK]"];
      attention_mask[output_tokens_pos] = 1;
      token_type_ids[output_tokens_pos] = 0;
      output_tokens_pos++;
      if (MAX_SEQ_LENGTH <= output_tokens_pos) {
        _tokens_length = 0;
        input_ids[MAX_SEQ_LENGTH - 1] = _vocab["[SEP]"];
        attention_mask[MAX_SEQ_LENGTH - 1] = 1;
        token_type_ids[MAX_SEQ_LENGTH - 1] = 0;
        return;
      }
    } else {
      for (const string& sub_token: sub_tokens) {
        input_ids[output_tokens_pos] = _vocab[sub_token];
        attention_mask[output_tokens_pos] = 1;
        token_type_ids[output_tokens_pos] = 0;
        output_tokens_pos++;
        if (MAX_SEQ_LENGTH <= output_tokens_pos) {
          _tokens_length = 0;
          input_ids[MAX_SEQ_LENGTH - 1] = _vocab["[SEP]"];
          attention_mask[MAX_SEQ_LENGTH - 1] = 1;
          token_type_ids[MAX_SEQ_LENGTH - 1] = 0;
          return;
        }
      }
    }
  }
//  time_cost += (double)(clock() - startTime) / CLOCKS_PER_SEC;
  // Terminate the sequence, then zero-fill (id 0, mask 0) the remaining positions.
  input_ids[output_tokens_pos] = _vocab["[SEP]"];
  attention_mask[output_tokens_pos] = 1;
  token_type_ids[output_tokens_pos] = 0;
  for (int i = output_tokens_pos + 1; i < MAX_SEQ_LENGTH; ++i) {
    input_ids[i] = 0;
    attention_mask[i] = 0;
    token_type_ids[i] = 0;
  }
  _tokens_length = 0;

//  cout << "Local run time is: " << time_cost << endl;
}

void CustomizedTokenizer::_lower_token(const string &token, string &new_token) {
  int length = token.length();
  if (token[0] == '[') {
    if (length == 5) {
      if (token[4] == ']') {
        if (token[1] == 'U' && token[2] == 'N' && token[3] == 'K') {
          new_token = token;
          return;
        }
        if (token[1] == 'S' && token[2] == 'E' && token[3] == 'P') {
          new_token = token;
          return;
        }
        if (token[1] == 'P' && token[2] == 'A' && token[3] == 'D') {
          new_token = token;
          return;
        }
        if (token[1] == 'C' && token[2] == 'L' && token[3] == 'S') {
          new_token = token;
          return;
        }
      }
    }
    if (length == 6) {
      if (token[1] == 'M' && token[2] == 'A' && token[3] == 'S' && token[4] == 'K' && token[5] == ']') {
        new_token = token;
        return;
      }
    }
  }

  for (char ch: token) {
    if (ch <= 90 && ch >= 65) {
      new_token += char(ch + 32);
    } else {
      new_token += ch;
    }
  }
}

// Normalizes _text in place:
//   - U+00A0, U+2800 and U+3000 are replaced with a space;
//   - U+FEFF and U+E312 are removed.
// Every occurrence is handled, including one at position 0 (e.g. a leading byte-order mark).
// Needle lengths come from the encoded literals instead of hard-coded byte counts.
void CustomizedTokenizer::_clean_text() {
  struct Rule {
    const char *needle;
    const char *replacement;
  };
  static const Rule kRules[] = {
    {"\u00A0", " "}, {"\u2800", " "}, {"\u3000", " "}, {"\ufeff", ""}, {"\ue312", ""},
  };
  for (const Rule &rule : kRules) {
    const std::string needle(rule.needle);
    const std::string replacement(rule.replacement);
    std::string::size_type pos = _text.find(needle);
    while (pos != std::string::npos) {
      _text.replace(pos, needle.size(), replacement);
      // Resume after the inserted text so earlier parts of the string are not rescanned.
      pos = _text.find(needle, pos + replacement.size());
    }
  }
}

void CustomizedTokenizer::_load_vocab(const string &vocab_file) {
  int index = 0;
  fstream fin;
  fin.open(vocab_file, ios::in);
  if (fin.is_open()) {
    string basicString;
    while (!fin.eof()) {
      getline(fin, basicString, '\n');
      if (!basicString.empty()) {
        _vocab[basicString] = index;
      }
      index++;
    }
    fin.close();
  }
}

void CustomizedTokenizer::_fixed_matching(int &pos, string &token) {
  token += _text[pos];
  int rest_length = _text_length - pos;
  if (rest_length < 2) {
    pos++;
    return;
  }
  if (_text[pos] == char(-16)) {
    if (rest_length < 4) {
      pos++;
      return;
    }
    if (_text[pos+1] < char(0) && _text[pos+2] < char(0) && _text[pos+3] < char(0)) {
      token += _text[pos+1];
      token += _text[pos+2];
      token += _text[pos+3];
      pos += 4;
      return;
    }
    pos++;
    return;
  }
  if (_text[pos] >= char(-62) && _text[pos] <= char(-37)) {
    if (_text[pos+1] < char(0)) {
      token += _text[pos+1];
      pos += 2;
      return;
    }
    pos++;
    return;
  }
  if (rest_length < 3) {
    pos++;
    return;
  }
  if (_text[pos+1] < char(0) && _text[pos+2] < char(0)) {
    token += _text[pos+1];
    token += _text[pos+2];
    pos += 3;
    return;
  }
  pos++;
}

void CustomizedTokenizer::_split_text() {
  // Performs invalid character removal and whitespace cleanup on text.
  _clean_text();
  _text_length = _text.length();
  int pos = 0;
  string token;
  _tokens[_tokens_length++] = "[CLS]";
  while (pos < _text_length) {
    if (_tokens_length == MAX_SEQ_LENGTH) {
      break;
    }
    if (_text[pos] == ' ') {
      if (!token.empty()) {
        if (token[0] >= char(0) && _do_lower_case) {
          string new_token;
          _lower_token(token, new_token);
          _tokens[_tokens_length++] = new_token;
        } else {
          _tokens[_tokens_length++] = token;
        }
        token.clear();
      }
      pos++;
      continue;
    }
    // We treat all non-letter/number ASCII as punctuation.
    // Characters such as "^", "$", and "`" are not in the Unicode
    // Punctuation class but we treat them as punctuation anyways, for
    // consistency.
    if ((_text[pos] >= char(33) && _text[pos] <= char(47)) ||
      (_text[pos] >= char(58) && _text[pos] <= char(64)) ||
      (_text[pos] >= char(91) && _text[pos] <= char(96)) ||
      (_text[pos] >= char(123) && _text[pos] <= char(126))
      ) {
      if (!token.empty()) {
        if (token[0] >= char(0) && _do_lower_case) {
          string new_token;
          _lower_token(token, new_token);
          _tokens[_tokens_length++] = new_token;
        } else {
          _tokens[_tokens_length++] = token;
        }
        token.clear();
      }
      if (_text[pos] == char(91)) {
        if (pos < _text_length - 4) {
          if (_text[pos+1] == 'S' && _text[pos+2] == 'E' && _text[pos+3] == 'P' && _text[pos+4] == ']') {
            _tokens[_tokens_length++] = "[SEP]";
            pos += 5;
            continue;
          }
          if (_text[pos+1] == 'U' && _text[pos+2] == 'N' && _text[pos+3] == 'K' && _text[pos+4] == ']') {
            _tokens[_tokens_length++] = "[UNK]";
            pos += 5;
            continue;
          }
          if (_text[pos+1] == 'P' && _text[pos+2] == 'A' && _text[pos+3] == 'D' && _text[pos+4] == ']') {
            _tokens[_tokens_length++] = "[PAD]";
            pos += 5;
            continue;
          }
          if (_text[pos+1] == 'C' && _text[pos+2] == 'L' && _text[pos+3] == 'S' && _text[pos+4] == ']') {
            _tokens[_tokens_length++] = "[CLS]";
            pos += 5;
            continue;
          }
        }
        if (pos < _text_length - 5) {
          if (_text[pos+1] == 'M' && _text[pos+2] == 'A' && _text[pos+3] == 'S' && _text[pos+4] == 'K' &&
            _text[pos+5] == ']') {
            _tokens[_tokens_length++] = "[MASK]";
            pos += 6;
            continue;
          }
        }
      }
      string temp = {_text[pos]};
      _tokens[_tokens_length++] = temp;
      pos++;
      continue;
    }
    if (_text[pos] < char(0)) {
      if (!token.empty()) {
        if (token[0] >= char(0) && _do_lower_case) {
          string new_token;
          _lower_token(token, new_token);
          _tokens[_tokens_length++] = new_token;
        } else {
          _tokens[_tokens_length++] = token;
        }
        token.clear();
      }
      _fixed_matching(pos, token);
      _tokens[_tokens_length++] = token;
      token.clear();
    } else {
      token += _text[pos];
      pos++;
    }
  }
  if (!token.empty()) {
    if (token[0] >= char(0) && _do_lower_case) {
      string new_token;
      _lower_token(token, new_token);
      _tokens[_tokens_length++] = new_token;
    } else {
      _tokens[_tokens_length++] = token;
    }
  }
}

In [None]:
//
// Created by meng on 3/30/21.
//

#ifndef FEDERALLEARNING_MINDSPORE_LITE_FLCLIENT_SRC_MAIN_NATIVE_DATASET_CUSTOMIZEDTOKENIZER_H_
#define FEDERALLEARNING_MINDSPORE_LITE_FLCLIENT_SRC_MAIN_NATIVE_DATASET_CUSTOMIZEDTOKENIZER_H_
#include <fstream>
#include <iostream>
#include <cstdio>
#include <vector>
#include <map>
#include <ctime>
//#include <cstring>
//#include <algorithm>
//#include <cstdlib>
using namespace std;
#define MAX_SEQ_LENGTH 32
#define BATCH_SIZE 16
// Tokenizer state and operations: splits text into ASCII words, ASCII
// punctuation, bracketed special tokens and single non-ASCII characters,
// bounded by MAX_SEQ_LENGTH tokens.
class CustomizedTokenizer {
 public:
  CustomizedTokenizer();
  ~CustomizedTokenizer();

  // `vocab_file`: path of the vocabulary file; `do_lower_case`: whether ASCII
  // words are lower-cased during tokenization.
  void init(const std::string &vocab_file, bool do_lower_case);

  // Tokenizes `text` into token strings; `seq_length` receives the count.
  void tokenize(const std::string &text, std::string output_tokens[MAX_SEQ_LENGTH], int &seq_length);

  // Tokenizes `text` into MAX_SEQ_LENGTH-sized id / mask / type-id arrays;
  // `seq_length` receives the sequence length.
  void tokenize(const std::string &text, int input_ids[MAX_SEQ_LENGTH],
                int attention_mask[MAX_SEQ_LENGTH], int token_type_ids[MAX_SEQ_LENGTH],
                int &seq_length);

  // NOTE: the access specifier below is commented out, so every member that
  // follows is public.
  // private:
  std::string _text;                    // text currently being tokenized
  std::string _tokens[MAX_SEQ_LENGTH];  // token output buffer
  int _tokens_length{0};                // number of filled entries in _tokens
  int _text_length{0};                  // byte length of _text (set in _split_text)

  std::map<std::string, int> _vocab;    // token -> line index in the vocab file
  bool _do_lower_case{true};
  int _max_input_chars_per_word{100};

  static void _lower_token(const std::string &token, std::string &new_token);
  void _split_text();
  void _clean_text();
  void _fixed_matching(int &pos, std::string &token);
  void _load_vocab(const std::string &vocab_file);
};



#endif  // FEDERALLEARNING_MINDSPORE_LITE_FLCLIENT_SRC_MAIN_NATIVE_DATASET_CUSTOMIZEDTOKENIZER_H_
