<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down.

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From df8f267f67caf86040fccd1d80421173b6ba4f22 Mon Sep 17 00:00:00 2001
From: guohongzilong <guohongzilong@huawei.com>
Date: Fri, 26 Mar 2021 15:51:00 +0800
Subject: [PATCH] code clean and remove micro code

---
 .../main/java/com/huawei/flclient/LiteTrain.java   |  16 +-
 .../lite/flclient/src/main/native/CMakeLists.txt   |  54 +--
 .../src/main/native/com_huawei_flclient_Train.h    |  69 ---
 .../lite/flclient/src/main/native/data_prepare.cpp | 517 +++++++++++++++++++++
 .../flclient/src/main/native/lenet_train_jni.cpp   | 157 -------
 .../lite/flclient/src/main/native/lite_train.cpp   | 287 ++++++++++++
 .../lite/flclient/src/main/native/lite_train.h     |  35 ++
 .../flclient/src/main/native/lite_train_jni.cpp    | 219 +++++++++
 mindspore/lite/include/train/train_session.h       |  11 +
 mindspore/lite/nnacl/infer/arithmetic_grad_infer.c |   1 -
 mindspore/lite/nnacl/infer/maximum_grad_infer.c    |   1 +
 mindspore/lite/src/train/train_session.cc          |  57 ++-
 mindspore/lite/src/train/train_session.h           |   4 +
 13 files changed, 1148 insertions(+), 280 deletions(-)
 delete mode 100644 mindspore/lite/flclient/src/main/native/com_huawei_flclient_Train.h
 create mode 100644 mindspore/lite/flclient/src/main/native/data_prepare.cpp
 delete mode 100644 mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/lite_train.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/lite_train.h
 create mode 100644 mindspore/lite/flclient/src/main/native/lite_train_jni.cpp

diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
index 70ea94f..5f9a8de 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
@@ -31,7 +31,7 @@ public  class LiteTrain {
     static {
         System.loadLibrary("fl");
     }
-
+    private static HashMap featureMaps;  // NOTE(review): raw HashMap type and not referenced in the visible hunks; confirm it is needed and give it type parameters (presumably <String, float[]>).
     private static LiteTrain train;
 
     private LiteTrain() {
@@ -45,7 +45,7 @@ public  class LiteTrain {
 
     public FlatBufferBuilder FeatureMapBuilder(String modelName){
         FlatBufferBuilder builder = new FlatBufferBuilder();
-        int[] fmOffsets = getFeaturesMap(modelName,builder);
+        int[] fmOffsets = getSeralizeFeaturesMap(modelName,builder);
         int fmOffset = FeatureMapList.createFeatureMapVector(builder, fmOffsets);
         RequestUpdateModel.startRequestUpdateModel(builder);
         RequestUpdateModel.addFlName(builder, 0);
@@ -103,7 +103,7 @@ public  class LiteTrain {
      *
      * @return status
      */
-    public native int inference(String modelName,int batch_num,int test_nums);
+    public native float inference(String modelName,int batch_num);  // NOTE(review): now returns a float, but the javadoc above still says "@return status"; update it to describe the returned value.
 
     /**
      * train
@@ -118,7 +118,15 @@ public  class LiteTrain {
      * @param builder FlatBufferBuilder
      * @return features offset
      */
-    native int[] getFeaturesMap(String modelName,FlatBufferBuilder builder);
+   native int[] getSeralizeFeaturesMap(String modelName, FlatBufferBuilder builder);
+
+    /**
+     * Gets the features map of the training model.
+     * @param modelName model name
+     * @return features map (String key -> float[] data)
+     */
+   native Map<String,float[]> getFeaturesMap(String modelName);
 
     /**
      * update the features map of training model
diff --git a/mindspore/lite/flclient/src/main/native/CMakeLists.txt b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
index 1ee548e..81ea6d7 100644
--- a/mindspore/lite/flclient/src/main/native/CMakeLists.txt
+++ b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
@@ -20,63 +20,21 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/linux)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
-if(ENABLE_MICRO)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/runtime)
-endif()
+
+
 include_directories(${LITE_DIR}) ## lite include
 include_directories(${TOP_DIR}) ## api include
 include_directories(${TOP_DIR}/mindspore/core/) ## core include
 include_directories(${LITE_DIR}/build) ## flatbuffers
 
-if(ENABLE_MICRO)
 set(OP_SRC
-    src/nnacl/arithmetic_common.c
-    src/nnacl/common_func.c
-    src/nnacl/fp32/activation.c
-    src/nnacl/fp32/arithmetic.c
-    src/nnacl/fp32/common_func.c
-    src/nnacl/fp32/conv.c
-    src/nnacl/fp32/matmul.c
-    src/nnacl/fp32/softmax.c
-    src/nnacl/fp32_grad/activation_grad.c
-    src/nnacl/fp32_grad/gemm.c
-    src/nnacl/fp32_grad/pack_ext.c
-    src/nnacl/fp32_grad/pooling_grad.c
-    src/nnacl/int8/conv_int8.c
-    src/nnacl/int8/matmul_int8.c
-    src/nnacl/minimal_filtering_generator.c
-    src/nnacl/pack.c
-    src/nnacl/quantization/fixed_point.c
-    src/nnacl/reshape.c
-    src/nnacl/winograd_transform.c
-    src/nnacl/winograd_utils.c
-    src/runtime/kernel/fp32/max_pooling.c
-    src/runtime/kernel/fp32_grad/apply_momentum.c
-    src/runtime/kernel/fp32_grad/biasadd_grad.c
-    src/runtime/kernel/fp32_grad/compute_gradient.c
-    src/runtime/kernel/fp32_grad/conv_filter_grad.c
-    src/runtime/kernel/fp32_grad/conv_input_grad.c
-    src/runtime/kernel/fp32_grad/init_matrix.c
-    src/runtime/kernel/fp32_grad/sparse_softmax_cross_entropy_with_logist.c
-    src/runtime/load_input.c
-    src/fl_lenet.c
-    src/weight_files/fl_lenet_weight_epoch_0.c
-)
-else()
-    set(OP_SRC
-            src/lenet_train.cpp
+        lite_train.cpp
+        data_prepare.cpp
             )
-    endif()
-if(ENABLE_MICRO)
-    set(SRC_FILES
-            flearning.cpp)
-    else()
+
 set(SRC_FILES
-        lenet_train_jni.cpp
+        lite_train_jni.cpp
         )
-endif()
 find_library(log-lib glog)
 
 add_library(fl SHARED ${SRC_FILES} ${OP_SRC})
diff --git a/mindspore/lite/flclient/src/main/native/com_huawei_flclient_Train.h b/mindspore/lite/flclient/src/main/native/com_huawei_flclient_Train.h
deleted file mode 100644
index cd0345e..0000000
--- a/mindspore/lite/flclient/src/main/native/com_huawei_flclient_Train.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/* DO NOT EDIT THIS FILE - it is machine generated */
-#include <jni.h>
-/* Header for class com_huawei_flclient_Train */
-
-#ifndef _Included_com_huawei_flclient_Train
-#define _Included_com_huawei_flclient_Train
-#ifdef __cplusplus
-extern "C" {
-#endif
-/*
- * Class:     com_huawei_flclient_Train
- * Method:    setInput
- * Signature: (Ljava/lang/String;I)I
- */
-JNIEXPORT jint JNICALL Java_com_huawei_flclient_Train_setInput
-  (JNIEnv *, jobject, jstring, jint);
-
-/*
- * Class:     com_huawei_flclient_Train
- * Method:    prepare
- * Signature: ()I
- */
-JNIEXPORT jint JNICALL Java_com_huawei_flclient_Train_prepare
-  (JNIEnv *, jobject);
-
-/*
- * Class:     com_huawei_flclient_Train
- * Method:    Inference
- * Signature: ()I
- */
-JNIEXPORT jint JNICALL Java_com_huawei_flclient_Train_inference
-  (JNIEnv *, jobject);
-
-/*
- * Class:     com_huawei_flclient_Train
- * Method:    Train
- * Signature: (II)I
- */
-JNIEXPORT jint JNICALL Java_com_huawei_flclient_Train_train
-  (JNIEnv *, jobject, jint, jint);
-
-/*
- * Class:     com_huawei_flclient_Train
- * Method:    getFeaturesMap
- * Signature: (Lcom/google/flatbuffers/FlatBufferBuilder;)[I
- */
-JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_Train_getFeaturesMap
-  (JNIEnv *, jobject, jobject);
-
-/*
- * Class:     com_huawei_flclient_Train
- * Method:    updateFeatures
- * Signature: (Ljava/util/ArrayList;)I
- */
-JNIEXPORT jint JNICALL Java_com_huawei_flclient_Train_updateFeatures
-  (JNIEnv *, jobject, jobject);
-
-/*
- * Class:     com_huawei_flclient_Train
- * Method:    free
- * Signature: ()I
- */
-JNIEXPORT jint JNICALL Java_com_huawei_flclient_Train_free
-  (JNIEnv *, jobject);
-
-#ifdef __cplusplus
-}
-#endif
-#endif
diff --git a/mindspore/lite/flclient/src/main/native/data_prepare.cpp b/mindspore/lite/flclient/src/main/native/data_prepare.cpp
new file mode 100644
index 0000000..c8a1b5e
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/data_prepare.cpp
@@ -0,0 +1,517 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include <fstream>
+#include <iostream>
+#include <cstdio>
+#include <vector>
+#include <map>
+#include <ctime>
+//#include <cstring>
+//#include <algorithm>
+//#include <cstdlib>
+using namespace std;
+#define MAX_SEQ_LENGTH 256
+
+/**
+ * Sub-word tokenizer: text is split into pre-tokens, then each pre-token is split by greedy
+ * longest-match-first lookup in a vocabulary, with non-initial pieces prefixed by "##"
+ * (WordPiece-style). Output is always exactly MAX_SEQ_LENGTH positions:
+ * [CLS] ... [SEP] followed by padding.
+ *
+ * tokenize() writes the member scratch buffers (_text, _tokens, _tokens_length), so a single
+ * instance must not be used from several threads at once.
+ */
+class CustomizedTokenizer
+{
+ public:
+  CustomizedTokenizer();
+  ~CustomizedTokenizer();
+
+  // Loads the vocabulary file and stores the lower-casing flag; call before tokenize().
+  void init(const string &vocab_file, bool do_lower_case);
+  // Produces MAX_SEQ_LENGTH token strings ([CLS] ... [SEP] [PAD]...).
+  void tokenize(const string &text, string output_tokens[MAX_SEQ_LENGTH], int &seq_length);
+  // Produces MAX_SEQ_LENGTH vocabulary ids plus matching attention-mask / token-type arrays.
+  void tokenize(const string &text, int input_ids[MAX_SEQ_LENGTH],
+                int attention_mask[MAX_SEQ_LENGTH], int token_type_ids[MAX_SEQ_LENGTH], int &seq_length);
+
+// NOTE(review): the access specifier below is commented out, so every member is public.
+//private:
+  string _text;                    // text being tokenized (modified in place by _clean_text)
+  string _tokens[MAX_SEQ_LENGTH];  // pre-tokens from _split_text; [CLS] is always first
+  int _tokens_length = 0;          // number of valid entries in _tokens
+  int _text_length = 0;            // byte length of _text after cleaning
+
+  map<string, int> _vocab;         // token -> id, where id is the token's line index in the vocab file
+  bool _do_lower_case = true;      // lower-case ASCII words before vocabulary lookup
+  int _max_input_chars_per_word = 100;  // longer pre-tokens map straight to [UNK]
+
+  static void _lower_token(const string &token, string &new_token);
+  void _split_text();
+  void _clean_text();
+  void _fixed_matching(int &pos, string &token);
+  void _load_vocab(const string &vocab_file);
+};
+
+// Starts with an empty vocabulary; init() must be called before tokenize().
+CustomizedTokenizer::CustomizedTokenizer() = default;
+
+// Every member (std::map, std::string) releases its own storage, so nothing extra is needed.
+CustomizedTokenizer::~CustomizedTokenizer() = default;
+
+/**
+ * Prepares the tokenizer for use.
+ * @param vocab_file path of the vocabulary file (one token per line; the id is the line index).
+ * @param do_lower_case when true, ASCII words have 'A'-'Z' lower-cased before lookup.
+ */
+void CustomizedTokenizer::init(const string &vocab_file, bool do_lower_case) {
+  // The two steps are independent: _load_vocab never reads the lower-casing flag.
+  _load_vocab(vocab_file);
+  _do_lower_case = do_lower_case;
+}
+
+/**
+ * Tokenizes text into MAX_SEQ_LENGTH word-piece strings.
+ *
+ * Layout: [CLS] (emitted by _split_text), the word pieces, [SEP], then [PAD] up to
+ * MAX_SEQ_LENGTH. Pre-tokens longer than _max_input_chars_per_word, or with a remainder that
+ * has no vocabulary match, become a single [UNK]. If the pieces do not fit, the sequence is
+ * truncated and its last slot is forced to [SEP]. Every slot of output_tokens is written.
+ *
+ * @param text input text (bytes >= 0x80 are treated as UTF-8).
+ * @param output_tokens receives MAX_SEQ_LENGTH strings.
+ * @param seq_length receives the number of non-padding positions, [CLS] and [SEP] included
+ *        (it was previously never written).
+ */
+void CustomizedTokenizer::tokenize(const string &text, string output_tokens[MAX_SEQ_LENGTH], int &seq_length) {
+  _text = text;
+  _split_text();
+
+  int output_tokens_pos = 0;
+  // Stores one piece; returns true when the buffer just became full. In that case the last
+  // slot is overwritten with [SEP] so that a truncated sequence is still terminated.
+  auto append = [&](const string &piece) {
+    output_tokens[output_tokens_pos] = piece;
+    output_tokens_pos++;
+    if (output_tokens_pos < MAX_SEQ_LENGTH) {
+      return false;
+    }
+    output_tokens[MAX_SEQ_LENGTH - 1] = "[SEP]";
+    return true;
+  };
+
+  for (int i = 0; i < _tokens_length; ++i) {
+    const string &word = _tokens[i];
+    const int length = static_cast<int>(word.length());
+    vector<string> sub_tokens;
+    if (length > _max_input_chars_per_word) {
+      // Over-long pre-tokens are not split at all.
+      sub_tokens.emplace_back("[UNK]");
+    } else {
+      // Greedy longest-match-first split; non-initial pieces carry the "##" prefix.
+      int start = 0;
+      while (start < length) {
+        int end = length;
+        string cur_substr;
+        while (start < end) {
+          string substr = word.substr(start, end - start);
+          if (start > 0) {
+            substr.insert(0, "##");
+          }
+          if (_vocab.find(substr) != _vocab.end()) {
+            cur_substr = substr;
+            break;
+          }
+          end--;
+        }
+        if (cur_substr.empty()) {
+          // Some remainder has no vocabulary match: the whole pre-token becomes [UNK].
+          sub_tokens.assign(1, "[UNK]");
+          break;
+        }
+        sub_tokens.push_back(cur_substr);
+        start = end;
+      }
+    }
+
+    for (const string &piece : sub_tokens) {
+      if (append(piece)) {
+        _tokens_length = 0;
+        seq_length = MAX_SEQ_LENGTH;
+        return;
+      }
+    }
+  }
+
+  output_tokens[output_tokens_pos] = "[SEP]";
+  for (int i = output_tokens_pos + 1; i < MAX_SEQ_LENGTH; ++i) {
+    output_tokens[i] = "[PAD]";
+  }
+  seq_length = output_tokens_pos + 1;
+  // _split_text appends to _tokens, so reset it for the next call.
+  _tokens_length = 0;
+}
+
+/**
+ * Tokenizes text into fixed-size id, attention-mask and token-type arrays.
+ *
+ * Layout: [CLS] (emitted by _split_text), the word-piece ids, [SEP], then zero padding up to
+ * MAX_SEQ_LENGTH. Real positions get attention_mask = 1 and padding gets 0; every position
+ * gets token_type_ids = 0. If the pieces do not fit, the sequence is truncated and its last
+ * slot is forced to [SEP]. Ids are read with _vocab[...], so a token missing from the
+ * vocabulary (e.g. a special token) is inserted into it with id 0.
+ *
+ * @param text input text (bytes >= 0x80 are treated as UTF-8).
+ * @param input_ids receives MAX_SEQ_LENGTH vocabulary ids.
+ * @param attention_mask receives MAX_SEQ_LENGTH 0/1 flags.
+ * @param token_type_ids receives MAX_SEQ_LENGTH zeros.
+ * @param seq_length receives the number of positions whose attention_mask is 1
+ *        (it was previously never written).
+ */
+void CustomizedTokenizer::tokenize(const string &text, int input_ids[MAX_SEQ_LENGTH],
+                                   int attention_mask[MAX_SEQ_LENGTH], int token_type_ids[MAX_SEQ_LENGTH],
+                                   int &seq_length) {
+  _text = text;
+  _split_text();
+
+  int output_tokens_pos = 0;
+  // Writes one real position; returns true when the arrays just became full. In that case the
+  // last slot is overwritten with [SEP] so that a truncated sequence is still terminated.
+  auto append = [&](int id) {
+    input_ids[output_tokens_pos] = id;
+    attention_mask[output_tokens_pos] = 1;
+    token_type_ids[output_tokens_pos] = 0;
+    output_tokens_pos++;
+    if (output_tokens_pos < MAX_SEQ_LENGTH) {
+      return false;
+    }
+    input_ids[MAX_SEQ_LENGTH - 1] = _vocab["[SEP]"];
+    attention_mask[MAX_SEQ_LENGTH - 1] = 1;
+    token_type_ids[MAX_SEQ_LENGTH - 1] = 0;
+    return true;
+  };
+
+  for (int i = 0; i < _tokens_length; ++i) {
+    const string &word = _tokens[i];
+    const int length = static_cast<int>(word.length());
+    vector<string> sub_tokens;
+    if (length > _max_input_chars_per_word) {
+      // Over-long pre-tokens are not split at all.
+      sub_tokens.emplace_back("[UNK]");
+    } else {
+      // Greedy longest-match-first split; non-initial pieces carry the "##" prefix.
+      int start = 0;
+      while (start < length) {
+        int end = length;
+        string cur_substr;
+        while (start < end) {
+          string substr = word.substr(start, end - start);
+          if (start > 0) {
+            substr.insert(0, "##");
+          }
+          if (_vocab.find(substr) != _vocab.end()) {
+            cur_substr = substr;
+            break;
+          }
+          end--;
+        }
+        if (cur_substr.empty()) {
+          // Some remainder has no vocabulary match: the whole pre-token becomes [UNK].
+          sub_tokens.assign(1, "[UNK]");
+          break;
+        }
+        sub_tokens.push_back(cur_substr);
+        start = end;
+      }
+    }
+
+    for (const string &piece : sub_tokens) {
+      if (append(_vocab[piece])) {
+        _tokens_length = 0;
+        seq_length = MAX_SEQ_LENGTH;
+        return;
+      }
+    }
+  }
+
+  input_ids[output_tokens_pos] = _vocab["[SEP]"];
+  attention_mask[output_tokens_pos] = 1;
+  token_type_ids[output_tokens_pos] = 0;
+  for (int i = output_tokens_pos + 1; i < MAX_SEQ_LENGTH; ++i) {
+    input_ids[i] = 0;
+    attention_mask[i] = 0;
+    token_type_ids[i] = 0;
+  }
+  seq_length = output_tokens_pos + 1;
+  // _split_text appends to _tokens, so reset it for the next call.
+  _tokens_length = 0;
+}
+
+/**
+ * Lower-cases one pre-token while leaving the special tokens untouched.
+ *
+ * If token is exactly [UNK], [SEP], [PAD], [CLS] or [MASK], new_token is assigned a copy of it.
+ * Otherwise every byte of token is appended to new_token, with 'A'-'Z' mapped to 'a'-'z' and
+ * all other bytes (including non-ASCII ones) kept unchanged.
+ */
+void CustomizedTokenizer::_lower_token(const string &token, string &new_token) {
+  static const char *const kReservedTokens[] = {"[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"};
+  for (const char *reserved : kReservedTokens) {
+    if (token == reserved) {
+      new_token = token;
+      return;
+    }
+  }
+
+  for (const char c : token) {
+    const bool is_upper_ascii = c >= 'A' && c <= 'Z';
+    new_token += is_upper_ascii ? static_cast<char>(c - 'A' + 'a') : c;
+  }
+}
+
+/**
+ * Normalizes a few invisible or non-standard characters in _text.
+ * U+00A0, U+2800 and U+3000 become a plain space; U+FEFF and U+E312 are removed.
+ *
+ * Every occurrence is replaced, including one at byte offset 0. This matters most for a
+ * leading byte-order mark (U+FEFF).
+ * The byte length of each pattern is taken from the literal itself, not hard-coded.
+ */
+void CustomizedTokenizer::_clean_text() {
+  struct Rule {
+    const char *from;
+    const char *to;
+  };
+  static const Rule kRules[] = {
+    {"\u00A0", " "}, {"\u2800", " "}, {"\u3000", " "}, {"\ufeff", ""}, {"\ue312", ""},
+  };
+
+  for (const Rule &rule : kRules) {
+    const string from(rule.from);
+    const string to(rule.to);
+    string::size_type pos = _text.find(from);
+    while (pos != string::npos) {
+      _text.replace(pos, from.size(), to);
+      // A new match can only overlap the spot that just changed, so resume a little before it
+      // instead of rescanning the whole string from the start.
+      const string::size_type back = from.size() - 1;
+      pos = _text.find(from, pos > back ? pos - back : 0);
+    }
+  }
+}
+
+/**
+ * Loads the vocabulary: one token per line, the token id being its 0-based line index.
+ * Empty lines are skipped but still consume an id. A trailing '\r' (CRLF files) is stripped
+ * so that it does not become part of the token. If the file cannot be opened, the vocabulary
+ * is left unchanged and an error is printed to stderr.
+ */
+void CustomizedTokenizer::_load_vocab(const string &vocab_file) {
+  ifstream fin(vocab_file);
+  if (!fin.is_open()) {
+    cerr << "failed to open vocab file: " << vocab_file << '\n';
+    return;
+  }
+
+  int index = 0;
+  string line;
+  // Looping on getline (rather than !eof()) also terminates on read errors.
+  while (getline(fin, line)) {
+    if (!line.empty() && line.back() == '\r') {
+      line.pop_back();
+    }
+    if (!line.empty()) {
+      _vocab[line] = index;
+    }
+    index++;
+  }
+}
+
+/**
+ * Consumes one UTF-8 encoded character starting at _text[pos], appends its bytes to token
+ * and advances pos past it. _split_text calls this for bytes with the high bit set.
+ *
+ * The sequence length comes from the lead byte: 0xC2-0xDF gives 2 bytes, 0xE0-0xEF gives 3,
+ * and 0xF0-0xF4 gives 4.
+ * The whole sequence is taken only if it is complete and every trailing byte is a
+ * continuation byte (10xxxxxx). Otherwise exactly one byte is consumed, so a malformed or
+ * truncated sequence never swallows the start of the next character.
+ * Bytes are compared as unsigned char because plain char is unsigned on some ABIs
+ * (e.g. ARM/AArch64).
+ */
+void CustomizedTokenizer::_fixed_matching(int &pos, string &token) {
+  const auto lead = static_cast<unsigned char>(_text[pos]);
+  int seq_len = 1;
+  if (lead >= 0xC2 && lead <= 0xDF) {
+    seq_len = 2;
+  } else if (lead >= 0xE0 && lead <= 0xEF) {
+    seq_len = 3;
+  } else if (lead >= 0xF0 && lead <= 0xF4) {
+    seq_len = 4;
+  }
+
+  if (seq_len > _text_length - pos) {
+    seq_len = 1;  // truncated at the end of the text
+  }
+  for (int k = 1; k < seq_len; ++k) {
+    if ((static_cast<unsigned char>(_text[pos + k]) & 0xC0) != 0x80) {
+      seq_len = 1;  // not a continuation byte: malformed sequence
+      break;
+    }
+  }
+
+  token.append(_text, pos, seq_len);
+  pos += seq_len;
+}
+
+/**
+ * Splits _text (after _clean_text) into pre-tokens stored in _tokens[0, _tokens_length).
+ *
+ * _tokens[0] is [CLS]. The remaining pre-tokens are built as follows:
+ *  - ' ' separates words;
+ *  - each ASCII punctuation byte (33-47, 58-64, 91-96, 123-126) is its own pre-token, except
+ *    that the literal special tokens [SEP], [UNK], [PAD], [CLS] and [MASK] are kept whole;
+ *  - each non-ASCII UTF-8 character (see _fixed_matching) is its own pre-token;
+ *  - runs of the remaining ASCII bytes form words, lower-cased when _do_lower_case is set.
+ * At most MAX_SEQ_LENGTH pre-tokens are produced; any further input is dropped.
+ * Expects _tokens_length == 0 on entry; tokenize() resets it after each call.
+ * NOTE(review): only ' ' separates words, so '\t', '\n' and '\r' end up inside words;
+ * confirm whether they should be treated as whitespace.
+ */
+void CustomizedTokenizer::_split_text() {
+  // Performs invalid character removal and whitespace cleanup on text.
+  _clean_text();
+  _text_length = _text.length();
+  int pos = 0;
+  string token;
+
+  // Appends one pre-token if there is room. Several branches below emit two pre-tokens in a
+  // single step (the pending word plus a punctuation/non-ASCII token). Checking capacity only
+  // at the top of the loop therefore allowed a write past _tokens[MAX_SEQ_LENGTH - 1].
+  auto emit = [this](const string &t) {
+    if (_tokens_length < MAX_SEQ_LENGTH) {
+      _tokens[_tokens_length++] = t;
+    }
+  };
+  // Emits the pending ASCII word (lower-cased if requested) and clears it.
+  auto flush_word = [this, &emit, &token]() {
+    if (token.empty()) {
+      return;
+    }
+    if (static_cast<unsigned char>(token[0]) < 0x80 && _do_lower_case) {
+      string new_token;
+      _lower_token(token, new_token);
+      emit(new_token);
+    } else {
+      emit(token);
+    }
+    token.clear();
+  };
+  // Special tokens that are kept whole when they appear literally in the text.
+  static const string kSpecialTokens[] = {"[SEP]", "[UNK]", "[PAD]", "[CLS]", "[MASK]"};
+
+  emit("[CLS]");
+  while (pos < _text_length) {
+    if (_tokens_length == MAX_SEQ_LENGTH) {
+      break;
+    }
+    const char ch = _text[pos];
+    // Compare as unsigned: plain char is unsigned on some ABIs (e.g. ARM/AArch64).
+    const auto byte = static_cast<unsigned char>(ch);
+    if (ch == ' ') {
+      flush_word();
+      pos++;
+      continue;
+    }
+    // We treat all non-letter/number ASCII as punctuation.
+    // Characters such as "^", "$", and "`" are not in the Unicode
+    // Punctuation class but we treat them as punctuation anyways, for
+    // consistency.
+    const bool is_punct = (byte >= 33 && byte <= 47) || (byte >= 58 && byte <= 64) ||
+                          (byte >= 91 && byte <= 96) || (byte >= 123 && byte <= 126);
+    if (is_punct) {
+      flush_word();
+      if (ch == '[') {
+        const string *special = nullptr;
+        for (const string &candidate : kSpecialTokens) {
+          if (_text.compare(pos, candidate.size(), candidate) == 0) {
+            special = &candidate;
+            break;
+          }
+        }
+        if (special != nullptr) {
+          emit(*special);
+          pos += static_cast<int>(special->size());
+          continue;
+        }
+      }
+      emit(string(1, ch));
+      pos++;
+      continue;
+    }
+    if (byte >= 0x80) {
+      // Each non-ASCII character becomes its own pre-token.
+      flush_word();
+      _fixed_matching(pos, token);
+      emit(token);
+      token.clear();
+    } else {
+      token += ch;
+      pos++;
+    }
+  }
+  flush_word();
+}
+
+/**
+ * Manual driver: loads a vocabulary and times two tokenize() calls on a sample text.
+ *
+ * Usage: <binary> [vocab_file] [text]
+ * Both arguments are optional. The defaults are the previously hard-coded developer
+ * vocabulary path and an empty text.
+ *
+ * NOTE(review): CMakeLists.txt compiles this file into the `fl` shared library, which then
+ * contains a `main` symbol; consider moving this driver to a separate executable target.
+ */
+int main(int argc, char **argv)
+{
+  cout << "-----------" << endl;
+  const string vocab_file =
+    argc > 1 ? argv[1] : "/home/lizheng/tcwu/model_save/nlp/albert_chinese_tiny/vocab.txt";
+  const string text = argc > 2 ? argv[2] : "";
+  const bool do_lower_case = true;
+  CustomizedTokenizer customized_tokenizer;
+  customized_tokenizer.init(vocab_file, do_lower_case);
+  clock_t startTime = clock();
+
+  int input_ids[MAX_SEQ_LENGTH];
+  int attention_mask[MAX_SEQ_LENGTH];
+  int token_type_ids[MAX_SEQ_LENGTH];
+  int seq_length = 0;
+  // Two timed passes, as before: the old loop was bounded by 10000 but always broke after i == 1.
+  const int kRuns = 2;
+  for (int i = 0; i < kRuns; ++i) {
+    customized_tokenizer.tokenize(text, input_ids, attention_mask, token_type_ids, seq_length);
+  }
+  cout << "The run time is: " << static_cast<double>(clock() - startTime) / CLOCKS_PER_SEC << endl;
+  cout << "-----------" << endl;
+  return 0;
+}
diff --git a/mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp b/mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp
deleted file mode 100644
index db070c0..0000000
--- a/mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp
+++ /dev/null
@@ -1,157 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <lenet_train_jni.h>
-#include <jni.h>
-#include <cstring>
-#include "include/errorcode.h"
-#include "include/train_session.h"
-#include "include/lenet_train.h"
-#include "src/common/log_adapter.h"
-
-static jobject fbb;
-static jmethodID create_string_char;
-
-char *JstringToChar(JNIEnv *env, jstring jstr) {
-  char *rtn = nullptr;
-  jclass clsstring = env->FindClass("java/lang/String");
-  jstring strencode = env->NewStringUTF("GB2312");
-  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
-  jbyteArray barr = (jbyteArray)env->CallObjectMethod(jstr, mid, strencode);
-  jsize alen = env->GetArrayLength(barr);
-  jbyte *ba = env->GetByteArrayElements(barr, JNI_FALSE);
-  if (alen > 0) {
-    rtn = new char[alen + 1];
-    memcpy(rtn, ba, alen);
-    rtn[alen] = 0;
-  }
-  env->ReleaseByteArrayElements(barr, ba, 0);
-  return rtn;
-}
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_train(JNIEnv *env, jobject thiz, jstring ms_file,
-                                                                           jint batch_num, jint iterations) {
-  return fl_lenet_lite_Train(JstringToChar(env, ms_file), batch_num, iterations);
-}
-
-extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, size_t size) {
-  jstring name1 = env->NewStringUTF(name);
-  jint name_offset = env->CallIntMethod(fbb, create_string_char, name1);
-  // 1. set data size
-  jfloatArray ret = env->NewFloatArray(size);
-  env->SetFloatArrayRegion(ret, 0, size, data);
-  // 2. get methodid createDataVector
-  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
-  jmethodID createDataVector =
-    env->GetStaticMethodID(fm_cls, "createDataVector", "(Lcom/google/flatbuffers/FlatBufferBuilder;[F)I");
-  // 3. calc data offset
-  jint data_offset = env->CallStaticIntMethod(fm_cls, createDataVector, fbb, ret);
-  jmethodID createFeatureMap =
-    env->GetStaticMethodID(fm_cls, "createFeatureMap", "(Lcom/google/flatbuffers/FlatBufferBuilder;II)I");
-  jint fm_offset = env->CallStaticIntMethod(fm_cls, createFeatureMap, fbb, name_offset, data_offset);
-  return fm_offset;
-}
-
-extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_LiteTrain_getFeaturesMap(JNIEnv *env, jobject thiz,
-                                                                                         jstring ms_file,
-                                                                                         jobject builder) {
-  fbb = builder;
-  jclass fb_clazz = env->GetObjectClass(builder);
-  create_string_char = env->GetMethodID(fb_clazz, "createString", "(Ljava/lang/CharSequence;)I");
-  TrainFeatureParam **train_features = nullptr;
-  int feature_size = 0;
-  auto status = fl_lenet_lite_GetFeatures(JstringToChar(env, ms_file), &train_features, &feature_size);
-  if (status != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "get features failed:" << ms_file;
-    return env->NewIntArray(0);
-  }
-  jintArray ret = env->NewIntArray(feature_size);
-  jint *data = env->GetIntArrayElements(ret, NULL);
-
-  for (int i = 0; i < feature_size; i++) {
-    data[i] = CreateFeatureMap(env, train_features[i]->name, reinterpret_cast<float *>(train_features[i]->data),
-                               train_features[i]->elenums);
-    MS_LOG(INFO) << "upload feature:"
-                 << ", name:" << train_features[i]->name << ", elenums:" << train_features[i]->elenums;
-  }
-  env->ReleaseIntArrayElements(ret, data, 0);
-  for (int i = 0; i < feature_size; i++) {
-    delete train_features[i];
-  }
-  return ret;
-}
-
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_updateFeatures(JNIEnv *env, jobject,
-                                                                                    jstring ms_file, jobject features) {
-  jclass arr_cls = env->GetObjectClass(features);
-  jmethodID size_method = env->GetMethodID(arr_cls, "size", "()I");
-  jmethodID get_method = env->GetMethodID(arr_cls, "get", "(I)Ljava/lang/Object;");
-
-  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
-  jmethodID weight_name_method = env->GetMethodID(fm_cls, "weightFullname", "()Ljava/lang/String;");
-  jmethodID data_length_method = env->GetMethodID(fm_cls, "dataLength", "()I");
-  jmethodID data_method = env->GetMethodID(fm_cls, "data", "(I)F");
-  jclass clsstring = env->FindClass("java/lang/String");
-  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
-  int size = env->CallIntMethod(features, size_method);
-  // transform FeatureMap to TrainFeatureParm
-  TrainFeatureParam *features_param = reinterpret_cast<TrainFeatureParam *>(malloc(size * sizeof(TrainFeatureParam)));
-  for (int i = 0; i < size; ++i) {
-    TrainFeatureParam *param = features_param + i;
-    jobject feature = env->CallObjectMethod(features, get_method, i);
-    // set feature_param name
-    jstring weight_full_name = (jstring)env->CallObjectMethod(feature, weight_name_method);
-    jstring strencode = env->NewStringUTF("GB2312");
-    jbyteArray barr = (jbyteArray)env->CallObjectMethod(weight_full_name, mid, strencode);
-    char *name = nullptr;
-    jsize alen = env->GetArrayLength(barr);
-    jbyte *ba = env->GetByteArrayElements(barr, JNI_FALSE);
-    if (alen > 0) {
-      name = new char[alen + 1];
-      if (ba == nullptr) {
-        MS_LOG(ERROR) << "name is nullptr";
-        return mindspore::lite::RET_ERROR;
-      }
-      memcpy(name, ba, alen);
-      name[alen] = 0;
-    }
-    param->name = name;
-    env->ReleaseByteArrayElements(barr, ba, 0);
-    int data_length = env->CallIntMethod(feature, data_length_method);
-    float *data = static_cast<float *>(malloc(data_length * sizeof(float)));
-    memset(data, 0, data_length * sizeof(float));
-    for (int j = 0; j < data_length; ++j) {
-      float *addr = data + j;
-      *addr = env->CallFloatMethod(feature, data_method, j);
-    }
-    param->data = data;
-    param->elenums = data_length;
-    param->type = mindspore::kNumberTypeFloat32;
-    MS_LOG(INFO) << "get feature:" << param->name << ",elenums:" << param->elenums;
-  }
-  return fl_lenet_lite_UpdateFeatures(JstringToChar(env, ms_file), features_param, size);
-}
-
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_setInput(JNIEnv *env, jobject, jstring files,
-                                                                              jint nums) {
-  return fl_lenet_lite_SetInputs(JstringToChar(env, files), nums);
-}
-
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_inference(JNIEnv *env, jobject, jstring ms_file,
-                                                                               jint batch_num, jint test_nums) {
-  return fl_lenet_lite_Inference(JstringToChar(env, ms_file), batch_num, test_nums);
-}
-
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_free(JNIEnv *, jobject) { return 0; }
diff --git a/mindspore/lite/flclient/src/main/native/lite_train.cpp b/mindspore/lite/flclient/src/main/native/lite_train.cpp
new file mode 100644
index 0000000..5c940b6
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/lite_train.cpp
@@ -0,0 +1,374 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "lite_train.h"
+#include <algorithm>
+#include <climits>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <new>
+#include <vector>
+#include "include/context.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+
+// Training samples loaded by fl_lite_SetInputs(): I0 holds the float32 input data, I1 the int32 class labels.
+// Both buffers come from ReadFile() (new[]) and are replaced on every successful fl_lite_SetInputs() call.
+static char *fl_lenet_I0 = nullptr;
+static char *fl_lenet_I1 = nullptr;
+static size_t fl_lenet_I0_size = 0;  // size of fl_lenet_I0 in bytes
+static size_t fl_lenet_I1_size = 0;  // size of fl_lenet_I1 in bytes
+unsigned int seed_ = static_cast<unsigned int>(time(nullptr));
+
+// Copies one batch of samples into the session's data input (inputs[0]) and writes one-hot labels into the label
+// input (inputs[1]). With serially == true, samples [batch_idx * batch_size, (batch_idx + 1) * batch_size) are
+// used; otherwise every sample is drawn at random from the first batch_num * batch_size samples.
+// Returns the class index of each sample in the batch, or an empty vector on error.
+std::vector<int> FillInputData(mindspore::session::TrainSession *train_session, int batch_num, int batch_idx,
+                               bool serially) {
+  std::vector<int> labels_vec;
+  if (train_session == nullptr || fl_lenet_I0 == nullptr || fl_lenet_I1 == nullptr) {
+    MS_LOG(ERROR) << "session or input data is not ready, call setInput first";
+    return labels_vec;
+  }
+  auto inputs = train_session->GetInputs();
+  if (inputs.size() < 2 || inputs[0]->shape().empty() || inputs[1]->shape().size() < 2) {
+    MS_LOG(ERROR) << "model must have a data input and a [batch, classes] label input";
+    return labels_vec;
+  }
+  const int batch_size = inputs[0]->shape()[0];
+  const int num_classes = inputs[1]->shape()[1];
+  if (batch_size <= 0 || num_classes <= 0 || (!serially && batch_num <= 0)) {
+    MS_LOG(ERROR) << "invalid batch_size:" << batch_size << ", num_classes:" << num_classes
+                  << ", batch_num:" << batch_num;
+    return labels_vec;
+  }
+  const int data_size = inputs[0]->ElementsNum() / batch_size;
+  auto input_data = reinterpret_cast<float *>(inputs[0]->MutableData());
+  auto labels = reinterpret_cast<float *>(inputs[1]->MutableData());
+  if (input_data == nullptr || labels == nullptr || data_size <= 0) {
+    MS_LOG(ERROR) << "failed to get input tensor data";
+    return labels_vec;
+  }
+  std::fill(labels, labels + inputs[1]->ElementsNum(), 0.f);
+  const auto *samples = reinterpret_cast<const float *>(fl_lenet_I0);
+  const auto *sample_labels = reinterpret_cast<const int *>(fl_lenet_I1);
+  // Number of complete samples present in both buffers; guards every read below against overrunning them.
+  const size_t sample_count =
+    std::min(fl_lenet_I0_size / (static_cast<size_t>(data_size) * sizeof(float)), fl_lenet_I1_size / sizeof(int));
+  labels_vec.reserve(batch_size);
+  for (int i = 0; i < batch_size; i++) {
+    const size_t idx = serially ? static_cast<size_t>(batch_idx) * batch_size + i
+                                : static_cast<size_t>(rand_r(&seed_) % (batch_num * batch_size));
+    if (idx >= sample_count) {
+      MS_LOG(ERROR) << "sample index " << idx << " exceeds loaded sample count " << sample_count;
+      labels_vec.clear();
+      return labels_vec;
+    }
+    const int label_idx = sample_labels[idx];
+    if (label_idx < 0 || label_idx >= num_classes) {
+      MS_LOG(ERROR) << "label " << label_idx << " out of range [0, " << num_classes << ")";
+      labels_vec.clear();
+      return labels_vec;
+    }
+    std::memcpy(input_data + i * data_size, samples + idx * data_size, data_size * sizeof(float));
+    labels[i * num_classes + label_idx] = 1.0f;  // Model expects labels in onehot representation
+    labels_vec.push_back(label_idx);
+  }
+  return labels_vec;
+}
+
+// Returns an output tensor holding exactly `size` elements, or nullptr if the model has none.
+mindspore::tensor::MSTensor *SearchOutputsForSize(mindspore::session::TrainSession *train_session, size_t size) {
+  for (const auto &output : train_session->GetOutputs()) {
+    if (output.second != nullptr && static_cast<size_t>(output.second->ElementsNum()) == size) {
+      return output.second;
+    }
+  }
+  MS_LOG(ERROR) << "Model does not have an output tensor with size:" << size;
+  return nullptr;
+}
+
+// Returns the training loss, i.e. the value of the single-element output tensor (10000 if there is none).
+float GetLoss(mindspore::session::TrainSession *train_session) {
+  auto outputsv = SearchOutputsForSize(train_session, 1);  // Search for Loss which is a single value tensor
+  auto loss = outputsv == nullptr ? nullptr : reinterpret_cast<float *>(outputsv->MutableData());
+  return loss == nullptr ? 10000 : loss[0];
+}
+
+// Creates a single-threaded CPU session (no core binding) for ms_file. The caller owns the returned session,
+// which is presumably nullptr when the model cannot be loaded (TODO confirm with CreateSession's contract).
+mindspore::session::TrainSession *GetSession(const std::string &ms_file, bool train_mode) {
+  mindspore::lite::Context context;
+  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
+  context.thread_num_ = 1;
+  return mindspore::session::TrainSession::CreateSession(ms_file, &context, train_mode);
+}
+
+// Runs the session in eval mode on the inputs currently loaded and returns the fraction of samples whose
+// highest-scoring class equals the matching entry of `labels`. Returns 0 on error.
+float CalculateAccuracy(mindspore::session::TrainSession *session, const std::vector<int> &labels) {
+  session->Eval();
+  if (session->RunGraph() != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "run graph in eval mode failed";
+    return 0.0f;
+  }
+  auto inputs = session->GetInputs();
+  const int batch_size = inputs[1]->shape()[0];
+  const int num_of_class = inputs[1]->shape()[1];
+  if (batch_size <= 0 || num_of_class <= 0 || labels.size() < static_cast<size_t>(batch_size)) {
+    MS_LOG(ERROR) << "labels do not cover batch size:" << batch_size;
+    return 0.0f;
+  }
+  auto outputsv = SearchOutputsForSize(session, static_cast<size_t>(batch_size) * num_of_class);
+  auto scores = outputsv == nullptr ? nullptr : reinterpret_cast<float *>(outputsv->MutableData());
+  if (scores == nullptr) {
+    return 0.0f;
+  }
+  int correct = 0;
+  for (int b = 0; b < batch_size; b++) {
+    const float *row = scores + static_cast<size_t>(num_of_class) * b;
+    // max_element returns the first maximum, matching a strict '>' scan.
+    const int max_idx = static_cast<int>(std::max_element(row, row + num_of_class) - row);
+    if (labels[b] == max_idx) {
+      correct++;
+    }
+  }
+  return static_cast<float>(correct) / batch_size;
+}
+
+// Evaluates ms_file on the first batch of the loaded inputs and returns its accuracy, or -1 on error.
+// batch_num is forwarded to FillInputData but does not affect the sequential first batch.
+float fl_lite_Inference(const std::string &ms_file, int batch_num) {
+  std::unique_ptr<mindspore::session::TrainSession> session(GetSession(ms_file, false));
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << ms_file;
+    return -1.0f;
+  }
+  auto labels = FillInputData(session.get(), batch_num, 0, true);
+  if (labels.empty()) {
+    return -1.0f;
+  }
+  auto infer_acc = CalculateAccuracy(session.get(), labels);
+  std::cout << "inference acc is:" << infer_acc << std::endl;
+  return infer_acc;
+}
+
+// Trains ms_file for iterations / batch_num epochs of batch_num sequential batches each, printing the mean loss
+// and train accuracy of every epoch, then saves the updated model back to ms_file.
+int fl_lite_Train(const std::string &ms_file, const int batch_num, const int iterations) {
+  if (batch_num <= 0 || iterations <= 0) {
+    MS_LOG(ERROR) << "error iterations or batch_num!, batch_num:" << batch_num << ", iterations:" << iterations;
+    return mindspore::lite::RET_ERROR;
+  }
+  std::unique_ptr<mindspore::session::TrainSession> session(GetSession(ms_file, true));
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  MS_LOG(INFO) << "total iterations :" << iterations << "batch_num:" << batch_num;
+  const int epochs = iterations / batch_num;
+  for (int epoch = 0; epoch < epochs; ++epoch) {
+    float sum_loss_per_epoch = 0.0f;
+    float sum_acc_per_epoch = 0.0f;
+    for (int k = 0; k < batch_num; ++k) {
+      auto labels = FillInputData(session.get(), batch_num, k, true);
+      if (labels.empty()) {
+        return mindspore::lite::RET_ERROR;
+      }
+      if (session->RunGraph(nullptr, nullptr) != mindspore::lite::RET_OK) {
+        MS_LOG(ERROR) << "run graph failed at epoch " << epoch << ", batch " << k;
+        return mindspore::lite::RET_ERROR;
+      }
+      sum_loss_per_epoch += GetLoss(session.get());
+      // CalculateAccuracy switches the session to eval mode, so switch back before the next step.
+      sum_acc_per_epoch += CalculateAccuracy(session.get(), labels);
+      session->Train();
+    }
+    std::cout << "epoch " << "[" << epoch << "]" << ",mean Loss " << sum_loss_per_epoch / batch_num
+              << ",train acc " << sum_acc_per_epoch / batch_num << std::endl;
+  }
+  if (session->SaveToFile(ms_file) != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "save trained model failed:" << ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  return mindspore::lite::RET_OK;
+}
+
+// Overwrites the const tensors of update_ms_file named in new_features[0..size) and saves the model back.
+int fl_lite_UpdateFeatures(const std::string &update_ms_file, TrainFeatureParam *new_features, int size) {
+  std::unique_ptr<mindspore::session::TrainSession> train_session(GetSession(update_ms_file, false));
+  if (train_session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << update_ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  auto status = train_session->UpdateFeatureMaps(update_ms_file, new_features, size);
+  if (status != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "update model feature map failed" << update_ms_file;
+  }
+  return status;
+}
+
+// Frees feature params produced by TrainSession::GetFeatureMaps (name via new[], data via malloc, param via new).
+static void ReleaseFeatureParams(const std::vector<TrainFeatureParam *> &params) {
+  for (auto *param : params) {
+    if (param == nullptr) {
+      continue;
+    }
+    delete[] param->name;
+    free(param->data);
+    delete param;
+  }
+}
+
+// Reads all const tensors of update_ms_file. On success *feature receives a new[]-allocated array of *size
+// pointers; the caller owns that array and every TrainFeatureParam in it.
+int fl_lite_GetFeatures(const std::string &update_ms_file, mindspore::session::TrainFeatureParam ***feature,
+                        int *size) {
+  if (feature == nullptr || size == nullptr) {
+    MS_LOG(ERROR) << "feature or size is nullptr";
+    return mindspore::lite::RET_ERROR;
+  }
+  std::unique_ptr<mindspore::session::TrainSession> train_session(GetSession(update_ms_file, false));
+  if (train_session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << update_ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  std::vector<mindspore::session::TrainFeatureParam *> new_features;
+  auto status = train_session->GetFeatureMaps(&new_features);
+  if (status != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "get model feature map failed" << update_ms_file;
+    ReleaseFeatureParams(new_features);
+    return mindspore::lite::RET_ERROR;
+  }
+  *feature = new (std::nothrow) TrainFeatureParam *[new_features.size()];
+  if (*feature == nullptr) {
+    MS_LOG(ERROR) << "create features failed";
+    ReleaseFeatureParams(new_features);
+    return mindspore::lite::RET_ERROR;
+  }
+  std::copy(new_features.begin(), new_features.end(), *feature);
+  *size = static_cast<int>(new_features.size());
+  return mindspore::lite::RET_OK;
+}
+
+// Resolves path to an absolute path (realpath / _fullpath); returns "" if it is null, too long or unresolvable.
+std::string RealPath(const char *path) {
+  if (path == nullptr) {
+    MS_LOG(ERROR) << "path is nullptr";
+    return "";
+  }
+  if (strlen(path) >= PATH_MAX) {
+    MS_LOG(ERROR) << "path is too long";
+    return "";
+  }
+  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
+#ifdef _WIN32
+  // The buffer holds PATH_MAX chars; advertising a larger size to _fullpath could overflow it.
+  char *real_path = _fullpath(resolved_path.get(), path, PATH_MAX);
+#else
+  char *real_path = realpath(path, resolved_path.get());
+#endif
+  if (real_path == nullptr || real_path[0] == '\0') {
+    MS_LOG(ERROR) << "file path is not valid : " << path;
+    return "";
+  }
+  return std::string(resolved_path.get());
+}
+
+// Reads a whole file in binary mode. Returns a new[]-allocated buffer owned by the caller and stores its length
+// in *size; returns nullptr on error or for an empty file.
+char *ReadFile(const char *file, size_t *size) {
+  if (file == nullptr || size == nullptr) {
+    MS_LOG(ERROR) << "file or size is nullptr";
+    return nullptr;
+  }
+  std::string real_path = RealPath(file);
+  if (real_path.empty()) {
+    return nullptr;
+  }
+  // Binary mode: the inputs are raw float/int dumps and must not go through newline translation.
+  std::ifstream ifs(real_path, std::ios::in | std::ios::binary);
+  if (!ifs.good()) {
+    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
+    return nullptr;
+  }
+  ifs.seekg(0, std::ios::end);
+  const std::streamoff file_size = ifs.tellg();
+  if (file_size <= 0) {
+    MS_LOG(ERROR) << "file: " << real_path << " is empty or unreadable";
+    return nullptr;
+  }
+  *size = static_cast<size_t>(file_size);
+  std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
+  if (buf == nullptr) {
+    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
+    return nullptr;
+  }
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(buf.get(), static_cast<std::streamsize>(*size));
+  if (!ifs) {
+    MS_LOG(ERROR) << "read file failed: " << real_path;
+    return nullptr;
+  }
+  return buf.release();
+}
+
+// Loads the training data from files = "<data file>,<label file>": float32 samples and int32 class indices.
+// `num` is currently unused. Returns 0 on success; on failure the previously loaded inputs are kept.
+int fl_lite_SetInputs(const std::string &files, int num) {
+  (void)num;
+  if (files.empty()) {
+    MS_LOG(ERROR) << "files empty";
+    return -1;
+  }
+  std::vector<std::string> res;
+  for (size_t start = 0;;) {
+    const size_t pos = files.find(',', start);
+    res.push_back(files.substr(start, pos == std::string::npos ? std::string::npos : pos - start));
+    if (pos == std::string::npos) {
+      break;
+    }
+    start = pos + 1;
+  }
+  if (res.size() != 2) {
+    MS_LOG(ERROR) << "res size not equal 2";
+    return -1;
+  }
+  // Read both files before touching the globals so that a failure leaves the current inputs intact.
+  size_t data_size = 0;
+  std::unique_ptr<char[]> data(ReadFile(res[0].c_str(), &data_size));
+  if (data == nullptr) {
+    MS_LOG(ERROR) << "ReadFile return nullptr";
+    return mindspore::lite::RET_ERROR;
+  }
+  size_t label_size = 0;
+  std::unique_ptr<char[]> labels(ReadFile(res[1].c_str(), &label_size));
+  if (labels == nullptr) {
+    MS_LOG(ERROR) << "ReadFile return nullptr";
+    return mindspore::lite::RET_ERROR;
+  }
+  delete[] fl_lenet_I0;
+  delete[] fl_lenet_I1;
+  fl_lenet_I0 = data.release();
+  fl_lenet_I0_size = data_size;
+  fl_lenet_I1 = labels.release();
+  fl_lenet_I1_size = label_size;
+  return 0;
+}
diff --git a/mindspore/lite/flclient/src/main/native/lite_train.h b/mindspore/lite/flclient/src/main/native/lite_train.h
new file mode 100644
index 0000000..6ac7199
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/lite_train.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MSLITE_FL_LITE_H
+#define MSLITE_FL_LITE_H
+
+#include <string>
+#include "include/train/train_session.h"
+
+using mindspore::session::TrainFeatureParam;
+
+/// \brief Train the model in ms_file on the data loaded by fl_lite_SetInputs() and save it back to ms_file.
+/// \param[in] ms_file Model file path; it is overwritten with the trained model.
+/// \param[in] batch_num Number of sequential batches per epoch.
+/// \param[in] iterations Total number of batches; iterations / batch_num epochs are run.
+/// \return mindspore::lite::RET_OK on success, an error code otherwise.
+int fl_lite_Train(const std::string &ms_file, const int batch_num, const int iterations);
+
+/// \brief Evaluate the model in ms_file on the first batch of the data loaded by fl_lite_SetInputs().
+/// \param[in] batch_num Currently does not influence which samples are evaluated.
+/// \return Classification accuracy of that batch.
+float fl_lite_Inference(const std::string &ms_file, int batch_num);
+
+/// \brief Read every const tensor of update_ms_file.
+/// \param[out] features Receives a new[]-allocated array of *size pointers. The caller owns the array (delete[])
+///             and every entry: name (delete[]), data (free) and the TrainFeatureParam itself (delete).
+/// \param[out] size Number of entries in *features.
+/// \return mindspore::lite::RET_OK on success, mindspore::lite::RET_ERROR otherwise.
+int fl_lite_GetFeatures(const std::string &update_ms_file, mindspore::session::TrainFeatureParam ***features,
+                        int *size);
+
+/// \brief Overwrite the const tensors of update_ms_file named in new_features[0, size) and save the model back
+///        to update_ms_file. new_features stays owned by the caller.
+int fl_lite_UpdateFeatures(const std::string &update_ms_file, TrainFeatureParam *new_features, int size);
+
+/// \brief Create a single-threaded CPU session for ms_file; the caller owns (deletes) the returned session.
+mindspore::session::TrainSession *GetSession(const std::string &ms_file, bool train_mode = false);
+
+/// \brief Load training data from files = "<data file>,<label file>" (float32 samples, int32 class labels).
+/// \param[in] num Currently unused.
+/// \return 0 on success, a negative error code otherwise.
+int fl_lite_SetInputs(const std::string &files, int num);
+#endif  // MSLITE_FL_LITE_H
diff --git a/mindspore/lite/flclient/src/main/native/lite_train_jni.cpp b/mindspore/lite/flclient/src/main/native/lite_train_jni.cpp
new file mode 100644
index 0000000..8829afd
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/lite_train_jni.cpp
@@ -0,0 +1,322 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <jni.h>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <vector>
+#include "include/errorcode.h"
+#include "include/train/train_session.h"
+#include "lite_train.h"
+#include "src/common/log_adapter.h"
+
+// FlatBufferBuilder passed to getSeralizeFeaturesMap and its createString method; only valid during that call.
+static jobject fbb = nullptr;
+static jmethodID create_string_char = nullptr;
+// Global ref to the HashMap returned by getFeaturesMap; reused across calls and released by free().
+static jobject jmap = nullptr;
+
+// Converts a Java string to its GB2312-encoded bytes. Returns "" for a null or empty string.
+static std::string JstringToString(JNIEnv *env, jstring jstr) {
+  std::string result;
+  if (env == nullptr || jstr == nullptr) {
+    return result;
+  }
+  jclass clsstring = env->FindClass("java/lang/String");
+  jstring strencode = env->NewStringUTF("GB2312");
+  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
+  auto barr = static_cast<jbyteArray>(env->CallObjectMethod(jstr, mid, strencode));
+  if (barr != nullptr) {
+    const jsize alen = env->GetArrayLength(barr);
+    if (alen > 0) {
+      result.resize(static_cast<size_t>(alen));
+      env->GetByteArrayRegion(barr, 0, alen, reinterpret_cast<jbyte *>(&result[0]));
+    }
+    env->DeleteLocalRef(barr);
+  }
+  env->DeleteLocalRef(strencode);
+  env->DeleteLocalRef(clsstring);
+  return result;
+}
+
+// Same conversion as JstringToString, returned as a new[]-allocated C string the caller must delete[].
+// Returns nullptr for a null or empty string. Kept for compatibility; the entry points below do not use it.
+char *JstringToChar(JNIEnv *env, jstring jstr) {
+  const std::string str = JstringToString(env, jstr);
+  if (str.empty()) {
+    return nullptr;
+  }
+  char *rtn = new char[str.size() + 1];
+  memcpy(rtn, str.data(), str.size());
+  rtn[str.size()] = 0;
+  return rtn;
+}
+
+// Frees the result of fl_lite_GetFeatures: every TrainFeatureParam (name via new[], data via malloc, param via
+// new, as allocated by TrainSession::GetFeatureMaps) and then the new[]-allocated pointer array itself.
+static void ReleaseTrainFeatures(TrainFeatureParam **features, int size) {
+  if (features == nullptr) {
+    return;
+  }
+  for (int i = 0; i < size; i++) {
+    if (features[i] != nullptr) {
+      delete[] features[i]->name;
+      free(features[i]->data);
+      delete features[i];
+    }
+  }
+  delete[] features;
+}
+
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_train(JNIEnv *env, jobject thiz, jstring ms_file,
+                                                                           jint batch_num, jint iterations) {
+  return fl_lite_Train(JstringToString(env, ms_file), batch_num, iterations);
+}
+
+// Serializes one feature into the FlatBufferBuilder `fbb` as a mindspore.schema.FeatureMap and returns its
+// offset. Returns 0 if a Java object could not be created (a Java exception is then pending).
+extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, size_t size) {
+  const auto length = static_cast<jsize>(size);
+  jstring jname = env->NewStringUTF(name == nullptr ? "" : name);
+  jint name_offset = env->CallIntMethod(fbb, create_string_char, jname);
+  env->DeleteLocalRef(jname);
+  // 1. copy the feature values into a Java float[]
+  jfloatArray values = env->NewFloatArray(length);
+  if (values == nullptr) {
+    MS_LOG(ERROR) << "create float array failed, size:" << size;
+    return 0;
+  }
+  env->SetFloatArrayRegion(values, 0, length, data);
+  // 2. FeatureMap.createDataVector(builder, float[]) returns the offset of the data vector
+  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
+  if (fm_cls == nullptr) {
+    MS_LOG(ERROR) << "cannot find class mindspore/schema/FeatureMap";
+    env->DeleteLocalRef(values);
+    return 0;
+  }
+  jmethodID create_data_vector =
+    env->GetStaticMethodID(fm_cls, "createDataVector", "(Lcom/google/flatbuffers/FlatBufferBuilder;[F)I");
+  jint data_offset = env->CallStaticIntMethod(fm_cls, create_data_vector, fbb, values);
+  env->DeleteLocalRef(values);
+  // 3. FeatureMap.createFeatureMap(builder, name offset, data offset) returns the table offset
+  jmethodID create_feature_map =
+    env->GetStaticMethodID(fm_cls, "createFeatureMap", "(Lcom/google/flatbuffers/FlatBufferBuilder;II)I");
+  jint fm_offset = env->CallStaticIntMethod(fm_cls, create_feature_map, fbb, name_offset, data_offset);
+  // Local refs are freed eagerly: this runs once per feature within one native call, so they would pile up.
+  env->DeleteLocalRef(fm_cls);
+  return fm_offset;
+}
+
+// Returns a HashMap<String, float[]> holding every const tensor of ms_file, or null on failure. The map is
+// created once (global ref, released by free()) and reused; same-length float[] values are updated in place.
+extern "C" JNIEXPORT jobject JNICALL Java_com_huawei_flclient_LiteTrain_getFeaturesMap(JNIEnv *env, jobject thiz,
+                                                                                        jstring ms_file) {
+  const std::string model_path = JstringToString(env, ms_file);
+  TrainFeatureParam **train_features = nullptr;
+  int feature_size = 0;
+  auto status = fl_lite_GetFeatures(model_path, &train_features, &feature_size);
+  if (status != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "get features failed:" << model_path;
+    return nullptr;
+  }
+  jclass map_class = env->FindClass("java/util/HashMap");
+  if (map_class == nullptr) {
+    ReleaseTrainFeatures(train_features, feature_size);
+    return nullptr;
+  }
+  jmethodID map_ctor = env->GetMethodID(map_class, "<init>", "()V");
+  jmethodID put_method =
+    env->GetMethodID(map_class, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
+  jmethodID get_method = env->GetMethodID(map_class, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
+  const bool map_exist = jmap != nullptr;
+  if (!map_exist) {
+    jobject local_map = env->NewObject(map_class, map_ctor);
+    if (local_map != nullptr) {
+      jmap = env->NewGlobalRef(local_map);
+      env->DeleteLocalRef(local_map);
+    }
+  }
+  if (jmap == nullptr) {
+    MS_LOG(ERROR) << "create HashMap failed";
+    env->DeleteLocalRef(map_class);
+    ReleaseTrainFeatures(train_features, feature_size);
+    return nullptr;
+  }
+  jclass str_class = env->FindClass("java/lang/String");
+  jmethodID str_ctor = env->GetMethodID(str_class, "<init>", "([BLjava/lang/String;)V");
+  jstring encoding = env->NewStringUTF("GB2312");
+  bool ok = true;
+  for (int i = 0; i < feature_size && ok; i++) {
+    const TrainFeatureParam *feature = train_features[i];
+    const char *name = feature->name == nullptr ? "" : feature->name;
+    const auto name_len = static_cast<jsize>(strlen(name));
+    const auto elenums = static_cast<jsize>(feature->elenums);
+    // key = new String(name bytes, "GB2312")
+    jbyteArray bytes = env->NewByteArray(name_len);
+    env->SetByteArrayRegion(bytes, 0, name_len, reinterpret_cast<const jbyte *>(name));
+    auto key = static_cast<jstring>(env->NewObject(str_class, str_ctor, bytes, encoding));
+    env->DeleteLocalRef(bytes);
+    jfloatArray feature_data = nullptr;
+    if (map_exist) {
+      feature_data = static_cast<jfloatArray>(env->CallObjectMethod(jmap, get_method, key));
+      if (feature_data != nullptr && env->GetArrayLength(feature_data) != elenums) {
+        // The feature changed size: writing into the old array would overrun it, so replace it.
+        env->DeleteLocalRef(feature_data);
+        feature_data = nullptr;
+      }
+    }
+    if (feature_data == nullptr) {
+      feature_data = env->NewFloatArray(elenums);
+    }
+    if (feature_data == nullptr) {
+      MS_LOG(ERROR) << "create feature data failed:" << name;
+      ok = false;
+    } else {
+      env->SetFloatArrayRegion(feature_data, 0, elenums, static_cast<const jfloat *>(feature->data));
+      jobject previous = env->CallObjectMethod(jmap, put_method, key, feature_data);
+      if (previous != nullptr) {
+        env->DeleteLocalRef(previous);
+      }
+      env->DeleteLocalRef(feature_data);
+    }
+    env->DeleteLocalRef(key);
+  }
+  env->DeleteLocalRef(encoding);
+  env->DeleteLocalRef(str_class);
+  env->DeleteLocalRef(map_class);
+  ReleaseTrainFeatures(train_features, feature_size);
+  return ok ? jmap : nullptr;
+}
+
+// Serializes every const tensor of ms_file into `builder` (a FlatBufferBuilder) as FeatureMap tables and
+// returns their offsets; returns an empty array if the features cannot be read.
+extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_LiteTrain_getSeralizeFeaturesMap(JNIEnv *env,
+                                                                                                 jobject thiz,
+                                                                                                 jstring ms_file,
+                                                                                                 jobject builder) {
+  fbb = builder;
+  jclass fb_clazz = env->GetObjectClass(builder);
+  create_string_char = env->GetMethodID(fb_clazz, "createString", "(Ljava/lang/CharSequence;)I");
+  env->DeleteLocalRef(fb_clazz);
+  const std::string model_path = JstringToString(env, ms_file);
+  TrainFeatureParam **train_features = nullptr;
+  int feature_size = 0;
+  auto status = fl_lite_GetFeatures(model_path, &train_features, &feature_size);
+  if (status != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "get features failed:" << model_path;
+    fbb = nullptr;
+    return env->NewIntArray(0);
+  }
+  std::vector<jint> offsets(static_cast<size_t>(feature_size));
+  for (int i = 0; i < feature_size; i++) {
+    const TrainFeatureParam *feature = train_features[i];
+    offsets[i] = CreateFeatureMap(env, feature->name, static_cast<float *>(feature->data), feature->elenums);
+    if (env->ExceptionCheck()) {
+      break;  // no further JNI calls are allowed while a Java exception is pending
+    }
+    MS_LOG(INFO) << "upload feature:"
+                 << ", name:" << (feature->name == nullptr ? "" : feature->name) << ", elenums:" << feature->elenums;
+  }
+  ReleaseTrainFeatures(train_features, feature_size);
+  fbb = nullptr;  // builder is a local ref that becomes invalid once this call returns
+  if (env->ExceptionCheck()) {
+    return nullptr;
+  }
+  jintArray ret = env->NewIntArray(feature_size);
+  if (ret != nullptr) {
+    env->SetIntArrayRegion(ret, 0, feature_size, offsets.data());
+  }
+  return ret;
+}
+
+// Converts `features`, a java.util.List-like collection (size()/get(int)) of mindspore.schema.FeatureMap, into
+// TrainFeatureParams and writes them into the matching const tensors of ms_file via fl_lite_UpdateFeatures.
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_updateFeatures(JNIEnv *env, jobject,
+                                                                                    jstring ms_file, jobject features) {
+  if (features == nullptr) {
+    MS_LOG(ERROR) << "features is null";
+    return mindspore::lite::RET_ERROR;
+  }
+  jclass arr_cls = env->GetObjectClass(features);
+  jmethodID size_method = env->GetMethodID(arr_cls, "size", "()I");
+  jmethodID get_method = env->GetMethodID(arr_cls, "get", "(I)Ljava/lang/Object;");
+  env->DeleteLocalRef(arr_cls);
+  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
+  jmethodID weight_name_method = env->GetMethodID(fm_cls, "weightFullname", "()Ljava/lang/String;");
+  jmethodID data_length_method = env->GetMethodID(fm_cls, "dataLength", "()I");
+  jmethodID data_method = env->GetMethodID(fm_cls, "data", "(I)F");
+  env->DeleteLocalRef(fm_cls);
+  const int size = env->CallIntMethod(features, size_method);
+  if (size < 0) {
+    MS_LOG(ERROR) << "invalid features size:" << size;
+    return mindspore::lite::RET_ERROR;
+  }
+  // Owning storage for the names and values the TrainFeatureParams point into. Everything is sized up front so
+  // the pointers stored in features_param stay valid, and it is released automatically on every return path.
+  std::vector<std::string> names(static_cast<size_t>(size));
+  std::vector<std::vector<float>> values(static_cast<size_t>(size));
+  std::vector<TrainFeatureParam> features_param(static_cast<size_t>(size));
+  for (int i = 0; i < size; ++i) {
+    jobject feature = env->CallObjectMethod(features, get_method, i);
+    if (feature == nullptr) {
+      MS_LOG(ERROR) << "feature " << i << " is null";
+      return mindspore::lite::RET_ERROR;
+    }
+    auto weight_full_name = static_cast<jstring>(env->CallObjectMethod(feature, weight_name_method));
+    names[i] = JstringToString(env, weight_full_name);
+    if (weight_full_name != nullptr) {
+      env->DeleteLocalRef(weight_full_name);
+    }
+    if (names[i].empty()) {
+      MS_LOG(ERROR) << "feature name is empty, index:" << i;
+      env->DeleteLocalRef(feature);
+      return mindspore::lite::RET_ERROR;
+    }
+    const int data_length = env->CallIntMethod(feature, data_length_method);
+    values[i].resize(data_length > 0 ? static_cast<size_t>(data_length) : 0);
+    for (int j = 0; j < data_length; ++j) {
+      values[i][j] = env->CallFloatMethod(feature, data_method, j);
+    }
+    env->DeleteLocalRef(feature);
+    TrainFeatureParam &param = features_param[i];
+    param.name = &names[i][0];
+    param.data = values[i].data();
+    param.elenums = values[i].size();
+    param.type = mindspore::kNumberTypeFloat32;
+    MS_LOG(INFO) << "get feature:" << param.name << ",elenums:" << param.elenums;
+  }
+  return fl_lite_UpdateFeatures(JstringToString(env, ms_file), features_param.data(), size);
+}
+
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_setInput(JNIEnv *env, jobject, jstring files,
+                                                                              jint nums) {
+  return fl_lite_SetInputs(JstringToString(env, files), nums);
+}
+
+extern "C" JNIEXPORT jfloat JNICALL Java_com_huawei_flclient_LiteTrain_inference(JNIEnv *env, jobject, jstring ms_file,
+                                                                               jint batch_num) {
+  return fl_lite_Inference(JstringToString(env, ms_file), batch_num);
+}
+
+// Releases the HashMap cached by getFeaturesMap; the next getFeaturesMap call creates a new one.
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_free(JNIEnv *env, jobject) {
+  if (jmap != nullptr) {
+    env->DeleteGlobalRef(jmap);
+    jmap = nullptr;
+  }
+  return 0;
+}
diff --git a/mindspore/lite/include/train/train_session.h b/mindspore/lite/include/train/train_session.h
index 97a80b3..437e027 100644
--- a/mindspore/lite/include/train/train_session.h
+++ b/mindspore/lite/include/train/train_session.h
@@ -24,6 +24,13 @@
 namespace mindspore {
 namespace session {
 
+struct TrainFeatureParam{  // a named weight buffer exchanged by GetFeatureMaps / UpdateFeatureMaps
+  char* name;  ///< NUL-terminated tensor name
+  void *data;  ///< element buffer; current producers/consumers treat it as elenums float32 values
+  size_t elenums;  ///< number of elements in data
+  enum TypeId type;  ///< element data type
+};
+
 /// \brief TrainSession Defines a class that allows training a MindSpore model
 class TrainSession : public session::LiteSession {
  public:
@@ -142,6 +149,10 @@ class TrainSession : public session::LiteSession {
     return mindspore::lite::RET_OK;
   }
 
+  virtual int GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *>* feature_maps) =0;
+
+  virtual int UpdateFeatureMaps(const std::string &update_ms_file,
+                                TrainFeatureParam* new_features,int size) =0;
  protected:
   bool train_mode_ = false;
   std::string get_loss_name() const { return loss_name_; }
diff --git a/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c b/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c
index a6d85ca..ed00572 100644
--- a/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c
+++ b/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c
@@ -103,4 +103,3 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T
 
 REG_INFER(DivGrad, PrimType_DivGrad, ArithmeticGradInferShape)
 REG_INFER(MulGrad, PrimType_MulGrad, ArithmeticGradInferShape)
-REG_INFER(MinimumGrad, PrimType_MinimumGrad, ArithmeticGradInferShape)
diff --git a/mindspore/lite/nnacl/infer/maximum_grad_infer.c b/mindspore/lite/nnacl/infer/maximum_grad_infer.c
index c06774e..a72de16 100644
--- a/mindspore/lite/nnacl/infer/maximum_grad_infer.c
+++ b/mindspore/lite/nnacl/infer/maximum_grad_infer.c
@@ -61,3 +61,4 @@ int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tens
 }
 
 REG_INFER(MaximumGrad, PrimType_MaximumGrad, MaximumGradInferShape)
+REG_INFER(MinimumGrad, PrimType_MinimumGrad, MaximumGradInferShape)
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index 34d5ef2..d468459 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -22,6 +22,7 @@
 #include <iostream>
 #include <fstream>
 #include <memory>
+#include <cstring>
 #include "include/errorcode.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
@@ -228,7 +229,7 @@ int TrainSession::SaveToFile(const std::string &filename) const {
   ofs.seekp(0, std::ios::beg);
   ofs.write(buf, fb_size);
   ofs.close();
-  return chmod(filename.c_str(), S_IRUSR);
+  return chmod(filename.c_str(), S_IRWXU);
 }
 
 int TrainSession::Train() {
@@ -518,6 +519,94 @@ int TrainSession::SetLossName(std::string loss_name) {
   }
   return RET_OK;
 }
+
+// Copies every const tensor into a newly allocated TrainFeatureParam appended to feature_maps.
+// The caller owns the results: name (new[], NUL-terminated), data (malloc) and the param itself (new).
+// NOTE(review): data is copied as ElementsNum() floats regardless of data_type(); confirm that only float32
+// tensors are exchanged.
+int TrainSession::GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *> *feature_maps) {
+  if (feature_maps == nullptr) {
+    MS_LOG(ERROR) << "feature_maps is nullptr";
+    return RET_ERROR;
+  }
+  for (auto tensor : this->tensors_) {
+    if (!tensor->IsConst()) {
+      continue;
+    }
+    const auto &tensor_name = tensor->tensor_name();
+    if (tensor->data_c() == nullptr) {
+      MS_LOG(WARNING) << "skip const tensor without data:" << tensor_name;
+      continue;
+    }
+    const size_t data_size = tensor->ElementsNum() * sizeof(float);
+    auto param = new (std::nothrow) mindspore::session::TrainFeatureParam();
+    // Always allocate the name (even when empty) so consumers can rely on a non-null C string.
+    auto name = new (std::nothrow) char[tensor_name.length() + 1];
+    void *data = malloc(data_size);
+    if (param == nullptr || name == nullptr || data == nullptr) {
+      MS_LOG(ERROR) << "malloc feature map failed:" << tensor_name;
+      delete param;
+      delete[] name;
+      free(data);
+      return RET_ERROR;
+    }
+    memcpy(name, tensor_name.c_str(), tensor_name.length() + 1);  // includes the terminating NUL
+    memcpy(data, tensor->data_c(), data_size);
+    param->name = name;
+    param->data = data;
+    param->elenums = tensor->ElementsNum();
+    param->type = tensor->data_type();
+    feature_maps->push_back(param);
+  }
+  MS_LOG(INFO) << "get feature map success";
+  return RET_OK;
+}
+
+// Copies new_features[0, size) into the const tensors of the same name and saves the model to update_ms_file.
+// Fails (without saving) if a feature is missing from the model or its element count differs.
+int TrainSession::UpdateFeatureMaps(const std::string &update_ms_file,
+                                    mindspore::session::TrainFeatureParam *new_features, int size) {
+  if (new_features == nullptr && size > 0) {
+    MS_LOG(ERROR) << "new_features is nullptr";
+    return RET_ERROR;
+  }
+  for (int i = 0; i < size; ++i) {
+    mindspore::session::TrainFeatureParam *new_feature = new_features + i;
+    if (new_feature->name == nullptr) {
+      MS_LOG(ERROR) << "feature " << i << " has no name";
+      return RET_ERROR;
+    }
+    // Reset for every feature so that each one must match a const tensor on its own.
+    bool find = false;
+    for (auto tensor : this->tensors_) {
+      if (!tensor->IsConst() || tensor->tensor_name() != new_feature->name) {
+        continue;
+      }
+      if (tensor->ElementsNum() != static_cast<int>(new_feature->elenums)) {
+        MS_LOG(ERROR) << "feature name:" << tensor->tensor_name() << ",len diff:"
+                      << "old is:" << tensor->ElementsNum() << "new is:" << new_feature->elenums;
+        return RET_ERROR;
+      }
+      if (tensor->data_c() == nullptr || new_feature->data == nullptr) {
+        MS_LOG(ERROR) << "feature data is nullptr:" << new_feature->name;
+        return RET_ERROR;
+      }
+      memcpy(tensor->data_c(), new_feature->data, new_feature->elenums * sizeof(float));
+      find = true;
+      break;
+    }
+    if (!find) {
+      MS_LOG(ERROR) << "cannot find feature:" << new_feature->name << ",update failed";
+      return RET_ERROR;
+    }
+  }
+  if (SaveToFile(update_ms_file) != RET_OK) {
+    MS_LOG(ERROR) << "save model failed:" << update_ms_file;
+    return RET_ERROR;
+  }
+  MS_LOG(INFO) << "update model:" << update_ms_file << ",feature map success";
+  return RET_OK;
+}
 }  // namespace lite
 
 session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context,
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index 69e73e8..24fae33 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -92,6 +92,10 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
     return outputs;
   }
 
+  int GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *> *feature_maps) override;
+  int UpdateFeatureMaps(const std::string &update_ms_file,
+                        mindspore::session::TrainFeatureParam* new_features,int size) override;
+
  protected:
   void AllocWorkSpace();
   bool IsLossKernel(const kernel::LiteKernel *kernel) const;
-- 
2.7.4

