<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <jni.h>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "include/errorcode.h"
#include "include/train/train_session.h"
#include "util.h"
#include "lenet_train.h"
#include "bert_train.h"
#include "src/common/log_adapter.h"

// Logcat helper (tag "MSJNI"); in the visible code it is only referenced from commented-out lines.
#define MS_PRINT(format, ...) __android_log_print(ANDROID_LOG_INFO, "MSJNI", format, ##__VA_ARGS__)

// FlatBufferBuilder passed into getSeralizeFeaturesMap and read by CreateFeatureMap.
// NOTE(review): this is the `builder` argument itself, i.e. a JNI local reference that is only
// valid during that native call.
static jobject fbb;
// Method ID of FlatBufferBuilder.createString(CharSequence), resolved in getSeralizeFeaturesMap.
static jmethodID create_string_char;
// Global reference to the HashMap built by getFeaturesMap; deleted in free().
static jobject jmap;
// Global reference to the model path given to createSession; its text selects LeNet vs. BERT helpers.
static jstring model_path;

/**
 * Converts a Java string into a newly allocated, NUL-terminated C string.
 *
 * The bytes come from String.getBytes("GB2312"), so the result is GB2312-encoded, not UTF-8.
 *
 * @param env  JNI environment of the calling thread.
 * @param jstr string to convert; may be null.
 * @return a new[] buffer that the caller must release with delete[], or nullptr when jstr is
 *         null, empty, or cannot be encoded.
 */
char *JstringToChar(JNIEnv *env, jstring jstr) {
  if (jstr == nullptr) {
    return nullptr;
  }
  char *rtn = nullptr;
  jclass clsstring = env->FindClass("java/lang/String");
  jstring strencode = env->NewStringUTF("GB2312");
  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
  auto barr = static_cast<jbyteArray>(env->CallObjectMethod(jstr, mid, strencode));
  if (env->ExceptionCheck()) {
    // e.g. UnsupportedEncodingException: clear it so the caller may keep issuing JNI calls.
    env->ExceptionClear();
    MS_LOG(ERROR) << "String.getBytes failed";
  } else if (barr != nullptr) {
    const jsize alen = env->GetArrayLength(barr);
    if (alen > 0) {
      rtn = new char[alen + 1];
      // Copy straight into the result; nothing needs to be written back to the Java array.
      env->GetByteArrayRegion(barr, 0, alen, reinterpret_cast<jbyte *>(rtn));
      rtn[alen] = '\0';
    }
  }
  // Used by many entry points: free local references eagerly instead of letting them pile up.
  if (barr != nullptr) {
    env->DeleteLocalRef(barr);
  }
  env->DeleteLocalRef(strencode);
  env->DeleteLocalRef(clsstring);
  return rtn;
}

/**
 * Appends one FeatureMap table (name + float data) to the FlatBufferBuilder held in `fbb`.
 *
 * Precondition: `fbb` and `create_string_char` were set by getSeralizeFeaturesMap during the
 * current native call.
 *
 * @param env  JNI environment of the calling thread.
 * @param name NUL-terminated feature name.
 * @param data pointer to `size` floats.
 * @param size number of floats to serialize.
 * @return the builder offset returned by FeatureMap.createFeatureMap.
 */
extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, size_t size) {
  jstring name1 = env->NewStringUTF(name);
  jint name_offset = env->CallIntMethod(fbb, create_string_char, name1);
  env->DeleteLocalRef(name1);
  // 1. copy the values into a Java float[]
  const auto len = static_cast<jsize>(size);
  jfloatArray ret = env->NewFloatArray(len);
  env->SetFloatArrayRegion(ret, 0, len, data);
  // 2. get methodid createDataVector
  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
  jmethodID createDataVector =
    env->GetStaticMethodID(fm_cls, "createDataVector", "(Lcom/google/flatbuffers/FlatBufferBuilder;[F)I");
  // 3. calc data offset, then the table offset that references name + data
  jint data_offset = env->CallStaticIntMethod(fm_cls, createDataVector, fbb, ret);
  jmethodID createFeatureMap =
    env->GetStaticMethodID(fm_cls, "createFeatureMap", "(Lcom/google/flatbuffers/FlatBufferBuilder;II)I");
  jint fm_offset = env->CallStaticIntMethod(fm_cls, createFeatureMap, fbb, name_offset, data_offset);
  // Invoked once per feature inside a loop: drop local references so they do not accumulate.
  env->DeleteLocalRef(ret);
  env->DeleteLocalRef(fm_cls);
  return fm_offset;
}

/**
 * Trains the session, dispatching on the model path stored by createSession: paths containing
 * "lenet" go to TrainLenet, all others to TrainBert.
 *
 * @param session_ptr     TrainSession pointer previously returned to Java as a jlong.
 * @param batch_size      forwarded to the model-specific trainer.
 * @param epoches         forwarded to the model-specific trainer.
 * @param early_stop_type not used by this function.
 * @return the trainer's status, or RET_ERROR when no model path is available.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_train(JNIEnv *env, jobject thiz,
                                                                             jlong session_ptr,
                                                                             jint batch_size, jint epoches,
                                                                             jint early_stop_type) {
  if (model_path == nullptr) {
    MS_LOG(ERROR) << "model path is not set, call createSession first";
    return mindspore::lite::RET_ERROR;
  }
  // Convert once; JstringToChar returns a new[] buffer (or nullptr), freed when this returns.
  std::unique_ptr<char[]> path(JstringToChar(env, model_path));
  if (path == nullptr) {
    // Building a std::string from nullptr would be undefined behavior.
    MS_LOG(ERROR) << "model path is empty";
    return mindspore::lite::RET_ERROR;
  }
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  const std::string model_name(path.get());
  if (model_name.find("lenet") != std::string::npos) {
    return TrainLenet(session, path.get(), batch_size, epoches);
  }
  return TrainBert(session, path.get(), batch_size, epoches);
}

extern "C" jlong JNICALL Java_com_huawei_flclient_NativeTrain_createSession(JNIEnv *env, jclass, jstring ms_file,
                                                                            jlong) {
  model_path = (jstring)env->NewGlobalRef(ms_file);
  return reinterpret_cast<jlong>(CreateSession(JstringToChar(env, ms_file)));
}
/**
 * Copies the contents of a direct java.nio buffer into a new native buffer.
 *
 * @param env         JNI environment of the calling thread.
 * @param modelBuffer direct buffer holding the serialized model.
 * @return a new[] buffer of GetDirectBufferCapacity(modelBuffer) bytes owned by the caller, or
 *         nullptr when modelBuffer is not a direct buffer.
 */
char *CreateLocalModelBuffer(JNIEnv *env, jobject modelBuffer) {
  auto *model_addr = static_cast<jbyte *>(env->GetDirectBufferAddress(modelBuffer));
  const jlong model_len = env->GetDirectBufferCapacity(modelBuffer);
  // For non-direct buffers JNI returns a NULL address and a capacity of -1.
  if (model_addr == nullptr || model_len < 0) {
    return nullptr;
  }
  // size_t, not int: avoids truncating capacities above INT_MAX.
  const auto byte_count = static_cast<size_t>(model_len);
  char *buffer = new char[byte_count];
  memcpy(buffer, model_addr, byte_count);
  return buffer;
}

extern "C" jlong JNICALL Java_com_huawei_flclient_NativeTrain_createSessionFromBuffer(JNIEnv *env, jclass,jobject model_buffer,jint num_thread) {

  if (nullptr == model_buffer) {
//    MS_PRINT("error, buffer is nullptr!");
    return (jlong) nullptr;
  }
  jlong bufferLen = env->GetDirectBufferCapacity(model_buffer);
  if (0 == bufferLen) {
//    MS_PRINT("error, bufferLen is 0!");
    return (jlong) nullptr;
  }

  char *modelBuffer = CreateLocalModelBuffer(env, model_buffer);
  if (modelBuffer == nullptr) {
//    MS_PRINT("modelBuffer create failed!");
    return (jlong) nullptr;
  }
  return reinterpret_cast<jlong>(CreateSession(modelBuffer,bufferLen));
}

/**
 * Returns the session's features as a java.util.HashMap from String to float[].
 *
 * The map lives in the global reference `jmap`: the first call creates it, and later calls refresh
 * the cached float[] values in place. A value is replaced by a new array only when it is missing or
 * its length differs from the feature's element count. Keys are built with new String(bytes, "GB2312").
 *
 * @param session_ptr TrainSession pointer previously returned to Java as a jlong.
 * @return the cached map, or NULL when the features cannot be read or HashMap cannot be found.
 */
extern "C" JNIEXPORT jobject JNICALL Java_com_huawei_flclient_NativeTrain_getFeaturesMap(JNIEnv *env, jclass,
                                                                                         jlong session_ptr) {
  TrainFeatureParam **train_features = nullptr;
  int feature_size = 0;
  auto status =
    GetFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), &train_features, &feature_size);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "get features failed:";
    return NULL;
  }
  // Frees the per-feature allocations handed out by GetFeatures.
  // NOTE(review): `name` is released with scalar delete and the outer train_features array is never
  // freed; confirm both against how GetFeatures allocates them.
  auto release_features = [&]() {
    for (int i = 0; i < feature_size; i++) {
      delete train_features[i]->name;
      free(train_features[i]->data);
      delete train_features[i];
    }
  };

  jclass jmapClass = env->FindClass("java/util/HashMap");
  if (jmapClass == NULL) {
    release_features();
    return NULL;
  }
  jclass strClass = env->FindClass("java/lang/String");
  jmethodID ctorID = env->GetMethodID(strClass, "<init>", "([BLjava/lang/String;)V");
  jstring encoding = env->NewStringUTF("GB2312");
  jmethodID mid = env->GetMethodID(jmapClass, "<init>", "()V");
  jmethodID putMethod = env->GetMethodID(jmapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
  jmethodID getMethod = env->GetMethodID(jmapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
  bool map_exist = true;
  if (jmap == nullptr) {
    // `mid` is the no-argument constructor, so no constructor arguments are passed.
    jobject local_map = env->NewObject(jmapClass, mid);
    jmap = env->NewGlobalRef(local_map);
    env->DeleteLocalRef(local_map);
    map_exist = false;
  }
  for (int i = 0; i < feature_size; i++) {
    TrainFeatureParam *feature = train_features[i];
    const auto name_len = static_cast<jsize>(strlen(feature->name));
    const auto elenums = static_cast<jsize>(feature->elenums);
    jbyteArray bytes = env->NewByteArray(name_len);
    env->SetByteArrayRegion(bytes, 0, name_len, reinterpret_cast<const jbyte *>(feature->name));
    auto key = static_cast<jstring>(env->NewObject(strClass, ctorID, bytes, encoding));

    jfloatArray feature_data = nullptr;
    if (map_exist) {
      feature_data = static_cast<jfloatArray>(env->CallObjectMethod(jmap, getMethod, key));
      // Writing elenums values into a shorter cached array would overrun it.
      if (feature_data != nullptr && env->GetArrayLength(feature_data) != elenums) {
        env->DeleteLocalRef(feature_data);
        feature_data = nullptr;
      }
    }
    if (feature_data == nullptr) {
      // First call, key not cached yet, or length changed.
      feature_data = env->NewFloatArray(elenums);
    }
    if (feature_data == nullptr) {
      // NewFloatArray failed and left an OutOfMemoryError pending: stop making JNI calls.
      MS_LOG(ERROR) << "create null feature data";
      env->DeleteLocalRef(key);
      env->DeleteLocalRef(bytes);
      break;
    }
    env->SetFloatArrayRegion(feature_data, 0, elenums, reinterpret_cast<const jfloat *>(feature->data));
    jobject previous = env->CallObjectMethod(jmap, putMethod, key, feature_data);
    // Release every per-iteration local reference so large feature sets cannot exhaust the table.
    if (previous != nullptr) {
      env->DeleteLocalRef(previous);
    }
    env->DeleteLocalRef(feature_data);
    env->DeleteLocalRef(key);
    env->DeleteLocalRef(bytes);
  }
  env->DeleteLocalRef(encoding);
  env->DeleteLocalRef(strClass);
  env->DeleteLocalRef(jmapClass);
  release_features();
  return jmap;
}

/**
 * Serializes every session feature into `builder` as a FeatureMap table.
 *
 * @param session_ptr TrainSession pointer previously returned to Java as a jlong.
 * @param builder     com.google.flatbuffers.FlatBufferBuilder receiving the tables.
 * @return the builder offsets of the created FeatureMap tables (empty if features cannot be read).
 */
extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_NativeTrain_getSeralizeFeaturesMap(JNIEnv *env,
                                                                                                   jobject thiz,
                                                                                                   jlong session_ptr,
                                                                                                   jobject builder) {
  // CreateFeatureMap reads the builder and its createString method through these globals.
  fbb = builder;
  jclass fb_clazz = env->GetObjectClass(builder);
  create_string_char = env->GetMethodID(fb_clazz, "createString", "(Ljava/lang/CharSequence;)I");
  env->DeleteLocalRef(fb_clazz);
  TrainFeatureParam **train_features = nullptr;
  int feature_size = 0;
  auto status =
    GetFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), &train_features, &feature_size);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "get features failed:";
    fbb = nullptr;
    return env->NewIntArray(0);
  }
  // Gather the offsets natively, then copy them into the Java array in a single call.
  std::vector<jint> offsets;
  if (feature_size > 0) {
    offsets.reserve(static_cast<size_t>(feature_size));
  }
  for (int i = 0; i < feature_size; i++) {
    TrainFeatureParam *feature = train_features[i];
    offsets.push_back(
      CreateFeatureMap(env, feature->name, reinterpret_cast<float *>(feature->data), feature->elenums));
    MS_LOG(INFO) << "upload feature:"
                 << ", name:" << feature->name << ", elenums:" << feature->elenums;
  }
  const auto count = static_cast<jsize>(offsets.size());
  jintArray ret = env->NewIntArray(count);
  if (ret != nullptr && count > 0) {
    env->SetIntArrayRegion(ret, 0, count, offsets.data());
  }
  // NOTE(review): `name` is released with scalar delete and the outer train_features array is never
  // freed; confirm both against how GetFeatures allocates them.
  for (int i = 0; i < feature_size; i++) {
    delete train_features[i]->name;
    free(train_features[i]->data);
    delete train_features[i];
  }
  // `builder` is a local reference that dies with this call; do not leave it reachable via fbb.
  fbb = nullptr;
  return ret;
}

/**
 * Copies a java.util.List of mindspore.schema.FeatureMap objects into native TrainFeatureParam
 * entries and passes them, with the stored model path, to UpdateFeatures.
 *
 * Names come from weightFullname() encoded as GB2312; values are read one by one through data(int).
 *
 * @param session_ptr TrainSession pointer previously returned to Java as a jlong.
 * @param features    list of FeatureMap objects.
 * @return UpdateFeatures' status, or RET_ERROR when the input cannot be converted.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_updateFeatures(JNIEnv *env, jclass,
                                                                                      jlong session_ptr,
                                                                                      jobject features) {
  if (model_path == nullptr) {
    MS_LOG(ERROR) << "model path is not set, call createSession first";
    return mindspore::lite::RET_ERROR;
  }
  jclass arr_cls = env->GetObjectClass(features);
  jmethodID size_method = env->GetMethodID(arr_cls, "size", "()I");
  jmethodID get_method = env->GetMethodID(arr_cls, "get", "(I)Ljava/lang/Object;");

  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
  jmethodID weight_name_method = env->GetMethodID(fm_cls, "weightFullname", "()Ljava/lang/String;");
  jmethodID data_length_method = env->GetMethodID(fm_cls, "dataLength", "()I");
  jmethodID data_method = env->GetMethodID(fm_cls, "data", "(I)F");
  jclass clsstring = env->FindClass("java/lang/String");
  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
  // The charset name is the same for every feature, so create it once.
  jstring strencode = env->NewStringUTF("GB2312");
  const int size = env->CallIntMethod(features, size_method);

  // Frees the first `count` converted entries and the array itself (used on error paths only).
  auto release_params = [](TrainFeatureParam *params, int count) {
    for (int k = 0; k < count; ++k) {
      delete[] params[k].name;
      free(params[k].data);
    }
    free(params);
  };

  // transform FeatureMap to TrainFeatureParam
  auto *features_param = static_cast<TrainFeatureParam *>(malloc(size * sizeof(TrainFeatureParam)));
  if (size > 0 && features_param == nullptr) {
    MS_LOG(ERROR) << "malloc features param failed";
    return mindspore::lite::RET_ERROR;
  }
  for (int i = 0; i < size; ++i) {
    TrainFeatureParam *param = features_param + i;
    jobject feature = env->CallObjectMethod(features, get_method, i);
    auto weight_full_name =
      feature == nullptr ? nullptr : static_cast<jstring>(env->CallObjectMethod(feature, weight_name_method));
    auto barr = weight_full_name == nullptr
                  ? nullptr
                  : static_cast<jbyteArray>(env->CallObjectMethod(weight_full_name, mid, strencode));
    if (barr == nullptr) {
      MS_LOG(ERROR) << "name is nullptr";
      release_params(features_param, i);
      return mindspore::lite::RET_ERROR;
    }
    // set feature_param name: copy the encoded bytes into a NUL-terminated new[] buffer
    char *name = nullptr;
    const jsize alen = env->GetArrayLength(barr);
    if (alen > 0) {
      name = new char[alen + 1];
      env->GetByteArrayRegion(barr, 0, alen, reinterpret_cast<jbyte *>(name));
      name[alen] = '\0';
    }
    const int data_length = env->CallIntMethod(feature, data_length_method);
    // Zero-initialized storage for the feature values.
    auto *data =
      static_cast<float *>(calloc(data_length > 0 ? static_cast<size_t>(data_length) : 0, sizeof(float)));
    if (data_length > 0 && data == nullptr) {
      MS_LOG(ERROR) << "malloc feature data failed";
      delete[] name;
      release_params(features_param, i);
      return mindspore::lite::RET_ERROR;
    }
    for (int j = 0; j < data_length; ++j) {
      data[j] = env->CallFloatMethod(feature, data_method, j);
    }
    param->name = name;
    param->data = data;
    param->elenums = data_length;
    param->type = mindspore::kNumberTypeFloat32;
    // Streaming a null char* is undefined behavior, so print "" for an empty name.
    MS_LOG(INFO) << "get feature:" << (name != nullptr ? name : "") << ",elenums:" << param->elenums;
    // Release per-feature local references so long lists cannot exhaust the local reference table.
    env->DeleteLocalRef(barr);
    env->DeleteLocalRef(weight_full_name);
    env->DeleteLocalRef(feature);
  }
  env->DeleteLocalRef(strencode);
  // JstringToChar returns a new[] buffer (or nullptr); it is freed when this function returns.
  std::unique_ptr<char[]> path(JstringToChar(env, model_path));
  // NOTE(review): features_param is not freed here; presumably UpdateFeatures takes ownership — confirm.
  return UpdateFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), path.get(),
                        features_param, size);
}

/**
 * Splits a comma-separated file list and hands it to the matching input loader.
 *
 * Every ',' starts a new field and empty fields are kept, e.g. "a,b," -> {"a", "b", ""}.
 * Two fields go to SetLenetInputs, three to SetBertInputs.
 *
 * @param files comma-separated file list; null is treated as an empty string.
 * @return the loader's status, or -1 for any other field count.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_setInput(JNIEnv *env, jobject, jstring files) {
  // JstringToChar returns a new[] buffer (or nullptr for empty input); own it so it is freed.
  std::unique_ptr<char[]> raw(files == nullptr ? nullptr : JstringToChar(env, files));
  const std::string input_files = raw ? raw.get() : "";

  std::vector<std::string> res;
  std::string::size_type start = 0;
  for (auto pos = input_files.find(','); pos != std::string::npos; pos = input_files.find(',', start)) {
    res.push_back(input_files.substr(start, pos - start));
    start = pos + 1;
  }
  // Whatever follows the last comma (possibly empty) is the final field.
  res.push_back(input_files.substr(start));

  if (res.size() == 2) {
    return SetLenetInputs(res[0], res[1]);
  }
  if (res.size() == 3) {
    return SetBertInputs(res[0], res[1], res[2]);
  }
  MS_LOG(ERROR) << "input files error";
  return -1;
}

/**
 * Evaluates the session with the model-specific helper: InferLenet when the stored model path
 * contains "lenet", InferBert otherwise (including when no path is stored).
 *
 * @param session_ptr TrainSession pointer previously returned to Java as a jlong.
 * @return the helper's result.
 */
extern "C" JNIEXPORT jfloat JNICALL Java_com_huawei_flclient_NativeTrain_infer(JNIEnv *env, jclass, jlong session_ptr) {
  // Own the converted path (new[] or nullptr) and never build a std::string from nullptr.
  std::unique_ptr<char[]> raw(model_path == nullptr ? nullptr : JstringToChar(env, model_path));
  const std::string model_name = raw ? raw.get() : "";
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  if (model_name.find("lenet") != std::string::npos) {
    return InferLenet(session);
  }
  return InferBert(session);
}

/**
 * Returns the labels from the last inference as a Java int[], using the LeNet helper when the
 * stored model path contains "lenet" and the BERT helper otherwise.
 *
 * @param session_ptr TrainSession pointer previously returned to Java as a jlong.
 * @return a new int[] with the labels (NULL only if the array allocation fails).
 */
extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabels(JNIEnv *env, jclass,
                                                                                           jlong session_ptr) {
  std::unique_ptr<char[]> raw(model_path == nullptr ? nullptr : JstringToChar(env, model_path));
  const std::string model_name = raw ? raw.get() : "";
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  std::vector<int> infer_result;
  if (model_name.find("lenet") != std::string::npos) {
    infer_result = GetLenetInferRes(session);
  } else {
    infer_result = GetBertInferRes(session);
  }
  // Stage in a jint vector (freed automatically) and copy into the Java array in one call.
  const std::vector<jint> labels(infer_result.begin(), infer_result.end());
  const auto count = static_cast<jsize>(labels.size());
  jintArray jArray = env->NewIntArray(count);
  if (jArray != nullptr && count > 0) {
    env->SetIntArrayRegion(jArray, 0, count, labels.data());
  }
  return jArray;
}

/**
 * Runs inference on `input_str` with the vocabulary read from `vocab_file`.
 *
 * @param session_ptr TrainSession pointer previously returned to Java as a jlong.
 * @param input_str   text to classify.
 * @param vocab_file  path of the vocabulary file.
 * @return the label returned by infer().
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabel(JNIEnv *env, jclass,
                                                                                     jlong session_ptr,
                                                                                     jstring input_str,
                                                                                     jstring vocab_file) {
  // Both conversions return new[] buffers (or nullptr); own them so they are freed afterwards.
  std::unique_ptr<char[]> input(JstringToChar(env, input_str));
  std::unique_ptr<char[]> vocab(JstringToChar(env, vocab_file));
  return infer(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), input.get(), vocab.get());
}

/**
 * Runs inference on `input_str` with a vocabulary supplied as a Java String[].
 *
 * @param session_ptr TrainSession pointer previously returned to Java as a jlong.
 * @param input_str   text to classify.
 * @param vocab_array vocabulary entries (modified UTF-8); null elements become empty strings.
 * @return the label returned by inferFromVocabArr.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabelFromVocab(JNIEnv *env, jclass,
                                                                                              jlong session_ptr,
                                                                                              jstring input_str,
                                                                                              jobjectArray vocab_array) {
  const jsize size = env->GetArrayLength(vocab_array);
  // std::vector instead of a variable-length array: VLAs are a compiler extension, not standard C++.
  std::vector<std::string> c_vocab_array(static_cast<size_t>(size));
  for (jsize i = 0; i < size; i++) {
    auto jstr = static_cast<jstring>(env->GetObjectArrayElement(vocab_array, i));
    if (jstr == nullptr) {
      continue;
    }
    const jsize strLen = env->GetStringUTFLength(jstr);
    const char *charBuffer = env->GetStringUTFChars(jstr, nullptr);
    if (charBuffer != nullptr) {
      c_vocab_array[static_cast<size_t>(i)].assign(charBuffer, static_cast<size_t>(strLen));
      env->ReleaseStringUTFChars(jstr, charBuffer);
    }
    env->DeleteLocalRef(jstr);
  }
  std::unique_ptr<char[]> input(JstringToChar(env, input_str));
  // NOTE(review): only the array pointer is passed on; the callee receives no element count.
  return inferFromVocabArr(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), input.get(),
                           c_vocab_array.data());
}

/**
 * Releases native state in this order:
 *   1. the cached feature map;
 *   2. the stored model path;
 *   3. the session;
 *   4. the model-specific inputs — FreeLenetInput when the model path contains "lenet",
 *      FreeBertInput otherwise.
 *
 * @param session_ptr TrainSession pointer previously returned to Java, or 0.
 * @return always 0.
 */
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_free(JNIEnv *env, jclass, jlong session_ptr) {
  // Read the model name while model_path is still a valid reference; it is released below.
  std::unique_ptr<char[]> raw(model_path == nullptr ? nullptr : JstringToChar(env, model_path));
  const std::string model_name = raw ? raw.get() : "";

  if (jmap != nullptr) {
    env->DeleteGlobalRef(jmap);
    jmap = nullptr;
  }
  if (model_path != nullptr) {
    env->DeleteGlobalRef(model_path);
    // Reset so later calls cannot use or delete the stale reference.
    model_path = nullptr;
  }
  if (0 != session_ptr) {
    delete (reinterpret_cast<mindspore::session::TrainSession *>(session_ptr));
  }
  if (model_name.find("lenet") != std::string::npos) {
    FreeLenetInput();
  } else {
    FreeBertInput();
  }
  return 0;
}


In [None]:
/**
 * Configures the tokenizer from an in-memory vocabulary.
 *
 * @param vocab_array   vocabulary entries, forwarded unchanged to _load_vocabFromVocab.
 * @param do_lower_case stored in _do_lower_case.
 */
void CustomizedTokenizer::initFromVocab(string vocab_array[], bool do_lower_case) {
  this->_do_lower_case = do_lower_case;
  this->_load_vocabFromVocab(vocab_array);
}

In [None]:
/**
 * Fills `_vocab` with the mapping vocab_file[id] -> id.
 *
 * NOTE(review): the loop bound is vocab_file->length(), i.e. the character count of the FIRST
 * entry, not the number of entries in the array. No element count reaches this function, so it
 * reads too few entries, or past the end of the array, depending on that first word. Confirm,
 * and thread the real array size through initFromVocab.
 */
void CustomizedTokenizer::_load_vocabFromVocab(string vocab_file[]) {
  const string &first_entry = vocab_file[0];
  for (string::size_type id = 0; id < first_entry.length(); ++id) {
    _vocab[vocab_file[id]] = static_cast<int>(id);
  }
}

In [None]:
/**
 * Creates a TrainSession from an in-memory model with a single-threaded, unbound CPU context
 * and train mode disabled.
 *
 * @param model_buffer serialized model bytes.
 * @param buffen_len   size of model_buffer in bytes.
 * @return whatever TrainSession::CreateSession returns; the result is not checked here.
 */
TrainSession *CreateSession(char *model_buffer, size_t buffen_len) {
  mindspore::lite::Context context;
  // Do not bind threads to specific CPU cores; run with a single thread.
  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
  context.thread_num_ = 1;
  const bool train_mode = false;
  return mindspore::session::TrainSession::CreateSession(model_buffer, buffen_len, &context, train_mode);
}