<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras API, which supports TPUs as of TensorFlow 2.1.0. This example is adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product, but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit → Notebook settings
- Select TPU from the Hardware accelerator drop-down

Next, we'll check that we can connect to the TPU:

## Data processing

In [None]:
From ee750131058e9469e14f8d37a7f2d70afd7cd63a Mon Sep 17 00:00:00 2001
From: guohongzilong <guohongzilong@huawei.com>
Date: Sat, 10 Apr 2021 15:11:00 +0800
Subject: [PATCH] add android train api

---
 .../main/java/com/huawei/flclient/AsyncFLJob.java  |  77 ++++++++-
 .../main/java/com/huawei/flclient/LiteTrain.java   |   8 +
 .../main/java/com/huawei/flclient/NativeTrain.java |   4 +
 .../lite/flclient/src/main/native/bert_train.cpp   | 173 ++++++++++++++++-----
 .../lite/flclient/src/main/native/bert_train.h     |   9 +-
 .../src/main/native/dataset/CustomizedTokenizer.cc |  23 +--
 .../src/main/native/dataset/CustomizedTokenizer.h  |   4 +-
 .../flclient/src/main/native/lite_train_jni.cpp    | 156 +++++++++++++------
 .../lite/flclient/src/main/native/test_train.cc    |  92 +++++------
 mindspore/lite/flclient/src/main/native/util.cpp   |   8 +-
 10 files changed, 394 insertions(+), 160 deletions(-)

diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/AsyncFLJob.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/AsyncFLJob.java
index d6850f1..d373a7f 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/AsyncFLJob.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/AsyncFLJob.java
@@ -1,8 +1,13 @@
 package com.huawei.flclient;
 
-import java.io.FileInputStream;
-import java.io.InputStream;
+import java.io.*;
 import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
 
 public class AsyncFLJob {
     public AsyncFLJob() {}
@@ -130,6 +135,33 @@ public class AsyncFLJob {
         }
     }
 
+    public int modeTrainFromBuffer(String trainFile,String vocabFile,String modelPath,int batchSize,int epoches,EarlyStopMod earlyStopMod,int numThreads) {
+        LiteTrain train = LiteTrain.getInstance();
+        System.out.println("==========initial model===========");
+        ByteBuffer modelBuffer = loadModelFile(modelPath);
+        if (modelBuffer == null) {
+            throw new IllegalArgumentException("cannot load model file: " + modelPath);
+        }
+        train.initFromBuffer(modelBuffer, numThreads);
+        System.out.println("trainFilePath: " + trainFile);
+        System.out.println("===========model train=============");
+        try {
+            // One vocabulary token per line; the line index becomes the token id on the native side.
+            String[] vocab = Files.readAllLines(Paths.get(vocabFile), StandardCharsets.UTF_8).toArray(new String[0]);
+            if (train.setAlbertInput(trainFile, vocab) < 0) {
+                throw new IllegalStateException("failed to prepare train inputs from " + trainFile);
+            }
+            int status = train.trainAlbert(batchSize, epoches, earlyStopMod.ordinal());
+            // Persist the trained weights only on success (0 is RET_OK); otherwise keep the old model file.
+            if (status == 0) {
+                updateModelFile(modelPath, modelBuffer);
+            }
+            System.out.println("train finish");
+            return status;
+        } catch (IOException e) {
+            throw new UncheckedIOException("train failed", e);
+        } finally { train.free(); }
+    }
     public int modelInferenceSingleDataFromBuffer(String data, ByteBuffer modelPath, String[] vocal_file, int numThreads) {
         LiteTrain train = LiteTrain.getInstance();
         System.out.println("==========initial model===========");
@@ -149,7 +181,7 @@ public class AsyncFLJob {
         }
     }
 
-    public ByteBuffer loadModelFile(String modelPath) {
+    public static ByteBuffer loadModelFile(String modelPath) {
         InputStream input = null;
         try {
             input = new FileInputStream(modelPath);
@@ -163,7 +195,15 @@ public class AsyncFLJob {
         return null;
     }
 
-    public static void main(String[] args) {
+    public static void updateModelFile(String modelPath,ByteBuffer modelBuffer) throws IOException {
+        // Write the full capacity (native code sizes the model by capacity) without moving the caller's position.
+        ByteBuffer view = (ByteBuffer) modelBuffer.duplicate().clear();
+        try (FileChannel wChannel = new FileOutputStream(new File(modelPath), false).getChannel()) {
+            while (view.hasRemaining()) { wChannel.write(view); }
+        }
+    }
+
+    public static void main(String[] args) throws IOException {
         String url = args[0];
         String flName = args[1];
         String modelPath = args[2];
@@ -177,6 +217,20 @@ public class AsyncFLJob {
         String task = args[10];
         boolean ifInteract = Boolean.parseBoolean(args[11]);
 
+//        String url = "null";
+//        String flName = "/home/meng/zj10/fl/mindspore/mindspore/lite/albert.ms";
+//        String modelPath = "/home/meng/zj10/fl/mindspore/mindspore/lite/albert.ms";
+//        String flID = "32";
+//        int batchSize = 16;
+//        int epochs = 1;
+//        int iterations = 2;
+//        String trainDataset =  "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/0.tsv";
+//        String testDataset =  "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/0.tsv";
+//        EarlyStopMod earlyStopMod = EarlyStopMod.NotEarlyStop;
+//        String task = "train";
+//        boolean ifInteract = false;
+        String vocabFile =  "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/vocab.txt";
+
         // todo test
         System.out.println("[args] url: " + url);
         System.out.println("[args] flName: " + flName);
@@ -191,14 +245,21 @@ public class AsyncFLJob {
         System.out.println("[args] task: " + task);
         System.out.println("[args] ifInteract: " + ifInteract);
         AsyncFLJob asyncFLJob = new AsyncFLJob();
+
         if (task.equals("train")) {
-            asyncFLJob.asyncFLJobRun(url, flName, modelPath, flID, batchSize, epochs, iterations, trainDataset, testDataset, earlyStopMod, ifInteract);
+//            asyncFLJob.asyncFLJobRun(url, flName, modelPath, flID, batchSize, epochs, iterations, trainDataset, testDataset, earlyStopMod, ifInteract);
+            asyncFLJob.modeTrainFromBuffer(trainDataset,vocabFile,modelPath,batchSize,epochs,EarlyStopMod.NotEarlyStop,1);
         } else if (task.equals("inference")) {
             asyncFLJob.modelInference(trainDataset, modelPath);
         } else {
-            String[] vocal = {"hello1", "hello2"};
-            ByteBuffer modelPathBuf = asyncFLJob.loadModelFile(modelPath);
-            asyncFLJob.modelInferenceSingleDataFromBuffer("hello", modelPathBuf, vocal, 1);
+            Path path = Paths.get(vocabFile);
+            List<String> allLines = Files.readAllLines(path, StandardCharsets.UTF_8);
+            String[] vocal = new String[allLines.size()];
+            for(int i=0;i<allLines.size();i++) {
+                vocal[i] = allLines.get(i);
+            }
+            ByteBuffer modelPathBuf = loadModelFile(modelPath);
+            asyncFLJob.modelInferenceSingleDataFromBuffer("", modelPathBuf, vocal, 1);
         }
       
 
diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
index 759dd05..73edd36 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
@@ -41,10 +41,18 @@ public class LiteTrain {
         return NativeTrain.setInput(fileSet);
     }
 
+    public int setAlbertInput(final String trainFile, final String[] vocabArray) {
+        return NativeTrain.setBertInputFromArr(trainFile, vocabArray);  // native result returned unchanged
+    }
+
     public int train(int batchSize,int epoches, int earlyStopMod) {
         return NativeTrain.train(sessionPtr, batchSize,epoches, earlyStopMod);
     }
 
+    public int trainAlbert(final int batchSize, final int epoches, final int earlyStopMod) {
+        return NativeTrain.trainFromBuffer(sessionPtr, batchSize, epoches, earlyStopMod);  // native status code
+    }
+
     public float infer() {
         return NativeTrain.infer(sessionPtr);
     }
diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java
index 725aebd..f4bde20 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/NativeTrain.java
@@ -30,6 +30,8 @@ public  class NativeTrain {
 
     static native int setInput(String fileSet);
 
+    static native int setBertInputFromArr(String trainFile,String[] vocabArray);
+
     static native long createSessionFromBuffer(ByteBuffer modelBuffer,int numThread);
 
     static native long createSession(String modelPath,long msConfigPtr);
@@ -44,6 +46,8 @@ public  class NativeTrain {
 
     static native int train(long sessionPtr,int batch_size,int epoches,int earlyStopMod);
 
+    static native int trainFromBuffer(long sessionPtr,int batch_size,int epoches,int earlyStopMod);
+
    static native int[] getSeralizeFeaturesMap(long sessionPtr, FlatBufferBuilder builder);
 
    static native Map<String,float[]> getFeaturesMap(long sessionPtr);
diff --git a/mindspore/lite/flclient/src/main/native/bert_train.cpp b/mindspore/lite/flclient/src/main/native/bert_train.cpp
index 1996a11..0be3b3b 100644
--- a/mindspore/lite/flclient/src/main/native/bert_train.cpp
+++ b/mindspore/lite/flclient/src/main/native/bert_train.cpp
@@ -14,15 +14,16 @@
  * limitations under the License.
  */
 #include "bert_train.h"
-#include <climits>
+#include "util.h"
+#include "dataset/labels.h"
 #include <cstring>
 #include <fstream>
 #include <iostream>
-#include "dataset/CustomizedTokenizer.h"
 #include "include/context.h"
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
-#include "util.h"
+#include "dataset/CustomizedTokenizer.h"
+#include <climits>
 
 static int *lable_ids = 0;
 static int *input_ids = 0;
@@ -60,21 +61,21 @@ std::vector<int> FillBertInput(mindspore::session::TrainSession *train_session,
       labels_vec[i] = lable_ids[batch_idx * batch_size + i];
     }
   }
-  //  std::ofstream ofs("model_input_mask.bin", std::ios::binary | std::ios::out);
-  //  ofs.write((const char *)model_input_mask, sizeof(int) * inputs[0]->ElementsNum());
-  //  ofs.close();
-  //
-  //  std::ofstream ofs1("model_input_ids.bin", std::ios::binary | std::ios::out);
-  //  ofs1.write((const char *)model_input_ids, sizeof(int) * inputs[1]->ElementsNum());
-  //  ofs1.close();
-  //
-  //  std::ofstream ofs3("model_label_ids.bin", std::ios::binary | std::ios::out);
-  //  ofs3.write((const char *)model_label_ids, sizeof(int) * inputs[3]->ElementsNum());
-  //  ofs3.close();
-  //
-  //  std::ofstream ofs2("model_token_id.bin", std::ios::binary | std::ios::out);
-  //  ofs2.write((const char *)model_token_id, sizeof(int) * inputs[2]->ElementsNum());
-  //  ofs2.close();
+    std::ofstream ofs("model_input_mask.bin", std::ios::binary | std::ios::out);
+    ofs.write((const char *)model_input_mask, sizeof(int) * inputs[0]->ElementsNum());
+    ofs.close();
+
+    std::ofstream ofs1("model_input_ids.bin", std::ios::binary | std::ios::out);
+    ofs1.write((const char *)model_input_ids, sizeof(int) * inputs[1]->ElementsNum());
+    ofs1.close();
+
+    std::ofstream ofs3("model_label_ids.bin", std::ios::binary | std::ios::out);
+    ofs3.write((const char *)model_label_ids, sizeof(int) * inputs[3]->ElementsNum());
+    ofs3.close();
+
+    std::ofstream ofs2("model_token_id.bin", std::ios::binary | std::ios::out);
+    ofs2.write((const char *)model_token_id, sizeof(int) * inputs[2]->ElementsNum());
+    ofs2.close();
   return labels_vec;
 }
 
@@ -101,23 +102,23 @@ int InferFromVocabFile(TrainSession *session, const std::string &input_str, cons
   int s_token_type_ids[MAX_SEQ_LENGTH];
   int len = 0;
   customized_tokenizer.tokenize(input_str, s_input_ids, s_attention_mask, s_token_type_ids, len);
+  session->Eval();
   // not support one input,need pad to one batch
   auto ms_inputs = session->GetInputs();
   auto model_input_mask = reinterpret_cast<int *>(ms_inputs.at(0)->MutableData());
   auto model_input_ids = reinterpret_cast<int *>(ms_inputs.at(1)->MutableData());
   auto model_token_id = reinterpret_cast<int *>(ms_inputs.at(2)->MutableData());
   auto model_label_ids = reinterpret_cast<int *>(ms_inputs.at(3)->MutableData());
-  for (int i = 0; i < BATCH_SIZE; i++) {
-    std::memcpy(model_input_mask + i * MAX_SEQ_LENGTH, s_attention_mask, MAX_SEQ_LENGTH * sizeof(int));
-    std::memcpy(model_input_ids + i * MAX_SEQ_LENGTH, s_input_ids, MAX_SEQ_LENGTH * sizeof(int));
-    std::memcpy(model_token_id + i * MAX_SEQ_LENGTH, s_token_type_ids, MAX_SEQ_LENGTH * sizeof(int));
+  for(int i=0;i<BATCH_SIZE;i++) {
+  std::memcpy(model_input_mask+i*MAX_SEQ_LENGTH, s_attention_mask, MAX_SEQ_LENGTH * sizeof(int));
+  std::memcpy(model_input_ids+i*MAX_SEQ_LENGTH, s_input_ids, MAX_SEQ_LENGTH * sizeof(int));
+  std::memcpy(model_token_id+i*MAX_SEQ_LENGTH, s_token_type_ids, MAX_SEQ_LENGTH * sizeof(int));
   }
   std::fill(model_label_ids, model_label_ids + ms_inputs.at(3)->ElementsNum(), 0.f);
 
-  session->Eval();
   session->RunGraph();
   auto inputs = session->GetInputs();
-  auto outputsv = SearchOutputsForSize(session, BATCH_SIZE * LABEL_CLASS);
+  auto outputsv = SearchOutputsForSize(session,  BATCH_SIZE* LABEL_CLASS);
   std::cout << "ouput tensor name:" << outputsv->tensor_name() << std::endl;
   auto scores = reinterpret_cast<float *>(outputsv->MutableData());
 
@@ -132,32 +133,35 @@ int InferFromVocabFile(TrainSession *session, const std::string &input_str, cons
   return max_idx;
 }
 
-int InferFromVocabArr(TrainSession *session, const std::string &input_str, std::string vocab_array[]) {
+int InferFromVocabArr(TrainSession *session, const std::string &input_str, std::string vocab_array[],int vocab_size) {
+  if(input_str.empty()) {
+    return -1;
+  }
   CustomizedTokenizer customized_tokenizer;
   bool do_lower_case = true;
-  customized_tokenizer.initFromVocab(vocab_array, do_lower_case);
+  customized_tokenizer.initFromVocab(vocab_array,vocab_size, do_lower_case);
   int s_input_ids[MAX_SEQ_LENGTH];
   int s_attention_mask[MAX_SEQ_LENGTH];
   int s_token_type_ids[MAX_SEQ_LENGTH];
   int len = 0;
   customized_tokenizer.tokenize(input_str, s_input_ids, s_attention_mask, s_token_type_ids, len);
+  session->Eval();
   // not support one input,need pad to one batch
   auto ms_inputs = session->GetInputs();
   auto model_input_mask = reinterpret_cast<int *>(ms_inputs.at(0)->MutableData());
   auto model_input_ids = reinterpret_cast<int *>(ms_inputs.at(1)->MutableData());
   auto model_token_id = reinterpret_cast<int *>(ms_inputs.at(2)->MutableData());
   auto model_label_ids = reinterpret_cast<int *>(ms_inputs.at(3)->MutableData());
-  for (int i = 0; i < BATCH_SIZE; i++) {
-    std::memcpy(model_input_mask + i * MAX_SEQ_LENGTH, s_attention_mask, MAX_SEQ_LENGTH * sizeof(int));
-    std::memcpy(model_input_ids + i * MAX_SEQ_LENGTH, s_input_ids, MAX_SEQ_LENGTH * sizeof(int));
-    std::memcpy(model_token_id + i * MAX_SEQ_LENGTH, s_token_type_ids, MAX_SEQ_LENGTH * sizeof(int));
+  for(int i=0;i<BATCH_SIZE;i++) {
+    std::memcpy(model_input_mask+i*MAX_SEQ_LENGTH, s_attention_mask, MAX_SEQ_LENGTH * sizeof(int));
+    std::memcpy(model_input_ids+i*MAX_SEQ_LENGTH, s_input_ids, MAX_SEQ_LENGTH * sizeof(int));
+    std::memcpy(model_token_id+i*MAX_SEQ_LENGTH, s_token_type_ids, MAX_SEQ_LENGTH * sizeof(int));
   }
   std::fill(model_label_ids, model_label_ids + ms_inputs.at(3)->ElementsNum(), 0.f);
 
-  session->Eval();
   session->RunGraph();
   auto inputs = session->GetInputs();
-  auto outputsv = SearchOutputsForSize(session, BATCH_SIZE * LABEL_CLASS);
+  auto outputsv = SearchOutputsForSize(session,  BATCH_SIZE* LABEL_CLASS);
   std::cout << "ouput tensor name:" << outputsv->tensor_name() << std::endl;
   auto scores = reinterpret_cast<float *>(outputsv->MutableData());
 
@@ -172,6 +176,7 @@ int InferFromVocabArr(TrainSession *session, const std::string &input_str, std::
   return max_idx;
 }
 
+
 // net training function
 int TrainBert(TrainSession *session, const std::string &ms_file, int batch_size, int epoches) {
   if (epoches <= 0) {
@@ -190,7 +195,7 @@ int TrainBert(TrainSession *session, const std::string &ms_file, int batch_size,
       session->RunGraph(nullptr, nullptr);
       auto loss = GetLoss(session);
       sum_loss_per_epoch += loss;
-      //      std::cout << "batch:" << k << ",loss:" << loss << std::endl;
+            std::cout << "batch:" << k << ",loss:" << loss << std::endl;
       sum_acc_per_epoch += CalculateAccuracy(session, lables, LABEL_CLASS);
     }
     std::cout << "epoch "
@@ -202,6 +207,36 @@ int TrainBert(TrainSession *session, const std::string &ms_file, int batch_size,
   return mindspore::lite::RET_OK;
 }
 
+// Trains `session` for `epoches` epochs over the global inputs prepared by SetBertInputs*,
+// then exports the trained model into model_buffer (ExportToBuf presumably updates *model_len).
+int TrainBertFromBuffer(TrainSession *session, int batch_size, int epoches,char* model_buffer,size_t* model_len) {
+  if (epoches <= 0 || batch_size <= 0) {
+    std::cout << "error batch_size or epoch!, batch_size:" << batch_size << ", epoch:" << epoches << std::endl;
+    return mindspore::lite::RET_ERROR;
+  }
+  batch_num = total_size / batch_size;
+  if (batch_num <= 0) {  // also avoids 0/0 in the per-epoch means below
+    std::cout << "train data less than one batch, total_size:" << total_size << std::endl;
+    return mindspore::lite::RET_ERROR;
+  }
+  std::cout << "total epoches :" << epoches << ",batch_num:" << batch_num << std::endl;
+  for (int j = 0; j < epoches; ++j) {
+    float sum_loss_per_epoch = 0.0f;
+    float sum_acc_per_epoch = 0.0f;
+    for (int k = 0; k < batch_num; ++k) {
+      session->Train();
+      auto lables = FillBertInput(session, k);
+      session->RunGraph(nullptr, nullptr);
+      sum_loss_per_epoch += GetLoss(session);
+      sum_acc_per_epoch += CalculateAccuracy(session, lables, LABEL_CLASS);
+    }
+    std::cout << "epoch [" << j << "],mean Loss " << sum_loss_per_epoch / batch_num
+              << ",train acc " << sum_acc_per_epoch / batch_num << std::endl;
+  }
+  session->ExportToBuf(model_buffer, model_len);  // NOTE(review): result unchecked - confirm ExportToBuf's error contract
+  return mindspore::lite::RET_OK;
+}
+
 void ReadTxt(const std::string &file, std::vector<std::string> *train_data) {
   std::fstream fin;
   fin.open(file, std::ios::in);
@@ -222,17 +257,79 @@ void ReadTxt(const std::string &file, std::vector<std::string> *train_data) {
 }
 
 // Set input tensors.
-int SetBertInputs(const std::string &train_file, const std::string &vocab_file, const std::string &labels_file) {
+// Reads "label<TAB>text" lines from train_file, tokenizes the text with vocab_file (array index ==
+// token id) and fills the global input buffers. A partial last batch is padded with randomly
+// chosen duplicate samples. Returns the padded sample count (total_size), or -1 on error.
+int SetBertInputsFromArray(const std::string &train_file, std::string vocab_file[],int vocab_size) {
+  if (train_file.empty() || vocab_file == nullptr || vocab_size <= 0) {
+    std::cout << "files empty";
+    return -1;
+  }
+  std::vector<std::string> train_data;
+  ReadTxt(train_file, &train_data);
+  int train_size = train_data.size();
+  if (train_size == 0) {  // also guards the rand_r() % train_size padding below
+    std::cout << "train data empty" << std::endl;
+    return -1;
+  }
+  std::map<std::string, int> labels_map;
+  for (int i = 0; i < LABEL_CLASS; i++) {
+    labels_map[labels[i]] = i;
+  }
+  // Pad only a partial last batch; an exact multiple of BATCH_SIZE needs no padding.
+  int remain_size = train_size % BATCH_SIZE;
+  int pad_size = remain_size == 0 ? 0 : BATCH_SIZE - remain_size;
+  if (train_size < BATCH_SIZE) {
+    std::cout << "train data size less than one batch,need random padding" << std::endl;
+  }
+  total_size = train_size + pad_size;
+  input_ids = new (std::nothrow) int[total_size * MAX_SEQ_LENGTH];
+  input_mask = new (std::nothrow) int[total_size * MAX_SEQ_LENGTH];
+  token_type_ids = new (std::nothrow) int[total_size * MAX_SEQ_LENGTH];
+  lable_ids = new (std::nothrow) int[total_size];
+  if (input_ids == nullptr || input_mask == nullptr || token_type_ids == nullptr || lable_ids == nullptr) {
+    std::cout << "alloc train input buffers failed" << std::endl;
+    return -1;
+  }
+
+  CustomizedTokenizer customized_tokenizer;
+  bool do_lower_case = true;
+  customized_tokenizer.initFromVocab(vocab_file, vocab_size, do_lower_case);
+  int s_input_ids[MAX_SEQ_LENGTH];
+  int s_attention_mask[MAX_SEQ_LENGTH];
+  int s_token_type_ids[MAX_SEQ_LENGTH];
+  int seq_length;
+  for (int i = 0; i < total_size; i++) {
+    // Indices past the end of the file are padding: a random duplicate sample.
+    int idx = i < train_size ? i : rand_r(&seed_) % train_size;
+    const std::string &line = train_data[idx];
+    size_t pos = line.find('\t');
+    if (pos == std::string::npos) {
+      std::cout << "train data error,must 2 word.idx::" << idx << std::endl;
+      return -1;
+    }
+    customized_tokenizer.tokenize(line.substr(pos + 1), s_input_ids, s_attention_mask, s_token_type_ids, seq_length);
+    memcpy(input_ids + i * MAX_SEQ_LENGTH, s_input_ids, MAX_SEQ_LENGTH * sizeof(int));
+    memcpy(input_mask + i * MAX_SEQ_LENGTH, s_attention_mask, MAX_SEQ_LENGTH * sizeof(int));
+    memcpy(token_type_ids + i * MAX_SEQ_LENGTH, s_token_type_ids, MAX_SEQ_LENGTH * sizeof(int));
+    lable_ids[i] = labels_map[line.substr(0, pos)];  // NOTE(review): unknown labels silently map to 0
+  }
+  std::cout << "total train size :" << total_size << std::endl << std::endl;
+  return total_size;
+}
+
+// Set input tensors.
+int SetBertInputs(const std::string &train_file, const std::string &vocab_file) {
   if (train_file.empty()) {
     std::cout << "files empty";
     return -1;
   }
   std::vector<std::string> train_data;
   ReadTxt(train_file, &train_data);
-  std::vector<std::string> labels;
-  ReadTxt(labels_file, &labels);
+//  std::vector<std::string> labels;
+//  ReadTxt(labels_file, &labels);
   std::map<std::string, int> labels_map;
-  for (int i = 0; i < labels.size(); i++) {
+  for (int i = 0; i < LABEL_CLASS; i++) {
     labels_map[labels[i]] = i;
   }
   int train_size = train_data.size();
@@ -281,7 +378,7 @@ int SetBertInputs(const std::string &train_file, const std::string &vocab_file,
     std::string key = dataset_tuple[0];
     lable_ids[i] = labels_map[key];
   }
-  std::cout << "total train size :" << std::endl << std::endl;
+  std::cout << "total train size :" << total_size << std::endl << std::endl;
   return total_size;
 }
 
diff --git a/mindspore/lite/flclient/src/main/native/bert_train.h b/mindspore/lite/flclient/src/main/native/bert_train.h
index 240e6e9..edd57db 100644
--- a/mindspore/lite/flclient/src/main/native/bert_train.h
+++ b/mindspore/lite/flclient/src/main/native/bert_train.h
@@ -20,11 +20,14 @@
 #include "include/train/train_session.h"
 #include <string>
 using mindspore::session::TrainSession;
-int SetBertInputs(const std::string &train_file,const std::string &vocab_file,const std::string &labels_file);
+int SetBertInputs(const std::string &train_file, const std::string &vocab_file);
+int SetBertInputsFromArray(const std::string &train_file, std::string vocab_file[],int vocab_size);
 void FreeBertInput();
 std::vector<int> GetBertInferRes(TrainSession *session);
-int TrainBert( TrainSession*session,const std::string &ms_file,int batch_size ,int epoches);
+int InferFromVocabArr(TrainSession *session, const std::string &input_str, std::string vocab_array[],int vocab_size);
+int TrainBert(TrainSession *session, const std::string &ms_file, int batch_size, int epoches);
+int TrainBertFromBuffer(TrainSession *session, int batch_size, int epoches,
+                        char *model_buffer, size_t *model_len);
 float InferBert(TrainSession *session);
 int InferFromVocabFile(TrainSession *session, const std::string &input_str, const std::string &vocab_file);
-int InferFromVocabArr(TrainSession *session, const std::string &input_str, std::string vocab_array[]);
 #endif  // MSLITE_FL_BERT_TRAIN_H
diff --git a/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.cc b/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.cc
index 910f4fd..4d2c2c5 100644
--- a/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.cc
+++ b/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.cc
@@ -5,6 +5,7 @@
 
 CustomizedTokenizer::CustomizedTokenizer() = default;
 
+
 CustomizedTokenizer::~CustomizedTokenizer() {
   _vocab.clear();
 }
@@ -14,6 +15,11 @@ void CustomizedTokenizer::init(const string &vocab_file, bool do_lower_case) {
   _load_vocab(vocab_file);
 }
 
+void CustomizedTokenizer::initFromVocab(string vocab_array[], int vocab_size, bool do_lower_case) {
+  this->_do_lower_case = do_lower_case;  // vocabulary comes from an in-memory array instead of a file
+  this->_load_vocabFromVocab(vocab_array, vocab_size);
+}
+
 void CustomizedTokenizer::tokenize(const string &text, string output_tokens[MAX_SEQ_LENGTH], int &seq_length) {
 //  clock_t startTime;
 //  double time_cost = 0.0;
@@ -254,6 +260,12 @@ void CustomizedTokenizer::_clean_text() {
   }
 }
 
+void CustomizedTokenizer::_load_vocabFromVocab(string vocab_file[],int vocab_size) {
+  for (int token_id = 0; token_id < vocab_size; ++token_id) {  // token id == array position
+    _vocab[vocab_file[token_id]] = token_id;
+  }
+}
+
 void CustomizedTokenizer::_load_vocab(const string &vocab_file) {
   int index = 0;
   fstream fin;
@@ -315,17 +327,6 @@ void CustomizedTokenizer::_fixed_matching(int &pos, string &token) {
   pos++;
 }
 
-void CustomizedTokenizer::initFromVocab(string vocab_array[], bool do_lower_case) {
-  _do_lower_case = do_lower_case;
-  _load_vocabFromVocab(vocab_array);
-}
-
-void CustomizedTokenizer::_load_vocabFromVocab(string vocab_file[]) {
-  for(int i=0;i<vocab_file->length();i++) {
-    _vocab[vocab_file[i]] = i;
-  }
-}
-
 void CustomizedTokenizer::_split_text() {
   // Performs invalid character removal and whitespace cleanup on text.
   _clean_text();
diff --git a/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.h b/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.h
index 9566d08..f8ee67a 100644
--- a/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.h
+++ b/mindspore/lite/flclient/src/main/native/dataset/CustomizedTokenizer.h
@@ -23,10 +23,10 @@ class CustomizedTokenizer
   ~CustomizedTokenizer();
 
   void init(const std::string &vocab_file, bool do_lower_case);
+  void initFromVocab(string vocab_array[], int vocab_size,bool do_lower_case);
   void tokenize(const string &text, string output_tokens[MAX_SEQ_LENGTH], int &seq_length);
   void tokenize(const string &text, int input_ids[MAX_SEQ_LENGTH],
                 int attention_mask[MAX_SEQ_LENGTH], int token_type_ids[MAX_SEQ_LENGTH], int &seq_length);
-                void initFromVocab(string vocab_array[], bool do_lower_case);
 
 //private:
   string _text;
@@ -42,8 +42,8 @@ class CustomizedTokenizer
   void _split_text();
   void _clean_text();
   void _fixed_matching(int &pos, string &token);
+  void _load_vocabFromVocab(string vocab_file[],int vocab_size);
   void _load_vocab(const string &vocab_file);
-  void _load_vocabFromVocab(string vocab_file[]);
 };
 
 
diff --git a/mindspore/lite/flclient/src/main/native/lite_train_jni.cpp b/mindspore/lite/flclient/src/main/native/lite_train_jni.cpp
index b74b40d..ce0f34e 100644
--- a/mindspore/lite/flclient/src/main/native/lite_train_jni.cpp
+++ b/mindspore/lite/flclient/src/main/native/lite_train_jni.cpp
@@ -29,6 +29,7 @@ static jobject fbb;
 static jmethodID create_string_char;
 static jobject jmap;
 static jstring model_path;
+static jobject jmodel_buffer;
 
 char *JstringToChar(JNIEnv *env, jstring jstr) {
   char *rtn = nullptr;
@@ -47,6 +48,30 @@ char *JstringToChar(JNIEnv *env, jstring jstr) {
   return rtn;
 }
 
+void CastJstringArrayToC(JNIEnv *env,jobjectArray vocab_array,std::string c_array[],int size) {
+  for (int i = 0; i < size; i++) {
+    auto jstr = static_cast<jstring>(env->GetObjectArrayElement(vocab_array, i));
+    const char *charBuffer = jstr == nullptr ? nullptr : env->GetStringUTFChars(jstr, nullptr);
+    // A null Java element (or an OOM in GetStringUTFChars) becomes an empty token instead of crashing.
+    c_array[i] = charBuffer == nullptr ? std::string() : std::string(charBuffer, env->GetStringUTFLength(jstr));
+    if (charBuffer != nullptr) env->ReleaseStringUTFChars(jstr, charBuffer);
+    if (jstr != nullptr) env->DeleteLocalRef(jstr);
+  }
+}
+
+char *CreateLocalModelBuffer(JNIEnv *env, jobject modelBuffer) {
+  auto *modelAddr = static_cast<jbyte *>(env->GetDirectBufferAddress(modelBuffer));
+  jlong modelLen = env->GetDirectBufferCapacity(modelBuffer);
+  if (modelAddr == nullptr || modelLen <= 0) return nullptr;  // non-direct buffer (NULL / -1) or empty: caller checks nullptr
+  char *buffer = new char[modelLen];
+  return static_cast<char *>(memcpy(buffer, modelAddr, modelLen));  // memcpy returns its destination
+}
+
+void UpdateModelBuffer(JNIEnv *env,char* updated_buffer,int size) {
+  auto *addr = jmodel_buffer == nullptr ? nullptr : static_cast<jbyte *>(env->GetDirectBufferAddress(jmodel_buffer));
+  if (addr != nullptr && size >= 0 && size <= env->GetDirectBufferCapacity(jmodel_buffer)) memcpy(addr, updated_buffer, size);  // bounds-checked copy into the Java direct buffer
+}
+
 extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, size_t size) {
   jstring name1 = env->NewStringUTF(name);
   jint name_offset = env->CallIntMethod(fbb, create_string_char, name1);
@@ -66,17 +91,34 @@ extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, siz
 }
 
 extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_train(JNIEnv *env, jobject thiz,
-                                                                             jlong session_ptr,
-                                                                             jint batch_size, jint epoches,
-                                                                             jint early_stop_type) {
+                                                                             jlong session_ptr, jint batch_size,
+                                                                             jint epoches, jint early_stop_type) {
   std::string model_name = JstringToChar(env, model_path);
-  if(model_name.find("lenet") != std::string::npos){
+  if (model_name.find("lenet") != std::string::npos) {
     return TrainLenet(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), JstringToChar(env, model_path),
                       batch_size, epoches);
+  } else {
+    return TrainBert(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), JstringToChar(env, model_path),
+                     batch_size, epoches);
   }
-  return TrainBert(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), JstringToChar(env, model_path),
-                   batch_size, epoches);
+}
+
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_trainFromBuffer(JNIEnv *env, jobject thiz,
+                                                                             jlong session_ptr, jint batch_size,
+                                                                             jint epoches, jint early_stop_type) {
 
+  jlong model_len = jmodel_buffer == nullptr ? -1 : env->GetDirectBufferCapacity(jmodel_buffer);
+  if (model_len <= 0) return -1;  // no direct model ByteBuffer registered via createSessionFromBuffer
+  std::vector<char> temp_buffer(model_len);  // freed on every return path (the raw new[] leaked on mismatch)
+  size_t model_size = static_cast<size_t>(model_len);
+  auto status = TrainBertFromBuffer(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), batch_size, epoches, temp_buffer.data(), &model_size);
+  if (status != mindspore::lite::RET_OK) return status;  // never copy an un-exported scratch buffer back to Java
+  if (model_size != static_cast<size_t>(model_len)) {
+    std::cout << "java local bytebuffer size not equal model size" << std::endl;
+    return -1;
+  }
+  UpdateModelBuffer(env, temp_buffer.data(), static_cast<int>(model_size));
+  return status;
 }
 
 extern "C" jlong JNICALL Java_com_huawei_flclient_NativeTrain_createSession(JNIEnv *env, jclass, jstring ms_file,
@@ -84,32 +126,27 @@ extern "C" jlong JNICALL Java_com_huawei_flclient_NativeTrain_createSession(JNIE
   model_path = (jstring)env->NewGlobalRef(ms_file);
   return reinterpret_cast<jlong>(CreateSession(JstringToChar(env, ms_file)));
 }
-char *CreateLocalModelBuffer(JNIEnv *env, jobject modelBuffer) {
-  jbyte *modelAddr = static_cast<jbyte *>(env->GetDirectBufferAddress(modelBuffer));
-  int modelLen = static_cast<int>(env->GetDirectBufferCapacity(modelBuffer));
-  char *buffer(new char[modelLen]);
-  memcpy(buffer, modelAddr, modelLen);
-  return buffer;
-}
 
-extern "C" jlong JNICALL Java_com_huawei_flclient_NativeTrain_createSessionFromBuffer(JNIEnv *env, jclass,jobject model_buffer,jint num_thread) {
-
+extern "C" jlong JNICALL Java_com_huawei_flclient_NativeTrain_createSessionFromBuffer(JNIEnv *env, jclass,
+                                                                                      jobject model_buffer,
+                                                                                      jint num_thread) {
   if (nullptr == model_buffer) {
-//    MS_PRINT("error, buffer is nullptr!");
+    //    MS_PRINT("error, buffer is nullptr!");
     return (jlong) nullptr;
   }
   jlong bufferLen = env->GetDirectBufferCapacity(model_buffer);
-  if (0 == bufferLen) {
-//    MS_PRINT("error, bufferLen is 0!");
+  if (bufferLen <= 0) {  // GetDirectBufferCapacity returns -1 for a non-direct buffer
+    //    MS_PRINT("error, bufferLen is 0!");
     return (jlong) nullptr;
   }
-
+  if (jmodel_buffer != nullptr) { env->DeleteGlobalRef(jmodel_buffer); }  // release the ref from a previous call
+  jmodel_buffer = env->NewGlobalRef(model_buffer);  // read later by trainFromBuffer
   char *modelBuffer = CreateLocalModelBuffer(env, model_buffer);
   if (modelBuffer == nullptr) {
-//    MS_PRINT("modelBuffer create failed!");
+    //    MS_PRINT("modelBuffer create failed!");
     return (jlong) nullptr;
   }
-  return reinterpret_cast<jlong>(CreateSession(modelBuffer,bufferLen));
+  return reinterpret_cast<jlong>(CreateSession(modelBuffer, bufferLen));
 }
 
 extern "C" JNIEXPORT jobject JNICALL Java_com_huawei_flclient_NativeTrain_getFeaturesMap(JNIEnv *env, jclass,
@@ -251,11 +288,12 @@ extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_updateFea
     param->type = mindspore::kNumberTypeFloat32;
     MS_LOG(INFO) << "get feature:" << param->name << ",elenums:" << param->elenums;
   }
-  return UpdateFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), JstringToChar(env, model_path),
-                        features_param, size);
+  return UpdateFeatures(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr),
+                        JstringToChar(env, model_path), features_param, size);
 }
 
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_setInput(JNIEnv *env, jobject, jstring files) {
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_setInput(JNIEnv *env, jobject,
+                                                                                     jstring files) {
   std::string input_files = JstringToChar(env, files);
   std::string pattern = ",";
   std::string strs = input_files + pattern;
@@ -267,59 +305,73 @@ extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_setInput(
     strs = strs.substr(pos + 1, strs.size());
     pos = strs.find(pattern);
   }
-  if (res.size() == 2) {
-    return SetLenetInputs(res[0], res[1]);
-  } else if (res.size() == 3) {
-    return SetBertInputs(res[0], res[1], res[2]);
-  }
-  std::cout << "input files error" << std::endl;
-  return -1;
+  if (res.size() < 2) {  // both loaders below read res[0] and res[1]
+    std::cout << "input files error" << std::endl;
+    return -1;
+  }
+  return res[0].find("bin") != std::string::npos ? SetLenetInputs(res[0], res[1]) : SetBertInputs(res[0], res[1]);
 }
 
-extern "C" JNIEXPORT jfloat JNICALL Java_com_huawei_flclient_NativeTrain_infer(JNIEnv *env, jclass, jlong session_ptr) {
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_setBertInputFromArr(
+  JNIEnv *env, jobject, jstring train_file, jobjectArray vocab_array) {
+  if (train_file == nullptr || vocab_array == nullptr) { return -1; }  // reject null Java refs, as getInferLabelFromVocab does
+  jsize size = env->GetArrayLength(vocab_array);
+  std::vector<std::string> c_vocab_array(size);  // heap storage: a VLA of std::string is not standard C++
+  CastJstringArrayToC(env, vocab_array, c_vocab_array.data(), size);
+  return SetBertInputsFromArray(JstringToChar(env, train_file), c_vocab_array.data(), size);
+}
 
+extern "C" JNIEXPORT jfloat JNICALL Java_com_huawei_flclient_NativeTrain_infer(JNIEnv *env, jclass, jlong session_ptr) {
-  std::string model_name = JstringToChar(env, model_path);
-  if(model_name.find("lenet") != std::string::npos){
+  std::string model_name = (model_path != nullptr) ? JstringToChar(env, model_path) : "";  // not set by createSessionFromBuffer
+  if (model_name.find("lenet") != std::string::npos) {
     return InferLenet(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr));
   }
   return InferBert(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr));
 }
 
-extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabels(JNIEnv *env, jclass, jlong session_ptr) {
-
+extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabels(JNIEnv *env, jclass,
+                                                                                           jlong session_ptr) {
-  std::string model_name = JstringToChar(env, model_path);
+  std::string model_name = (model_path != nullptr) ? JstringToChar(env, model_path) : "";  // not set by createSessionFromBuffer
   std::vector<int> infer_result;
-  if(model_name.find("lenet") != std::string::npos){
-   infer_result = GetLenetInferRes(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr));
+  if (model_name.find("lenet") != std::string::npos) {
+    infer_result = GetLenetInferRes(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr));
   } else {
     infer_result = GetBertInferRes(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr));
   }
   jintArray jArray = env->NewIntArray(infer_result.size());
-  jint *jnum = new jint[infer_result.size()];
-  for(int i=0;i<infer_result.size();i++) {
-    *(jnum+i) = infer_result[i];
+  std::vector<jint> jnum(infer_result.size());  // owned staging buffer; freed when the function returns
+  for (size_t i = 0; i < infer_result.size(); i++) {
+    jnum[i] = infer_result[i];
   }
-  env->SetIntArrayRegion(jArray, 0, infer_result.size(), jnum);
+  env->SetIntArrayRegion(jArray, 0, infer_result.size(), jnum.data());
   return jArray;
 }
 
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabel(JNIEnv *env, jclass, jlong session_ptr,jstring input_str ,jstring vocab_file) {
-  return InferFromVocabFile(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr),JstringToChar(env, input_str),JstringToChar(env, vocab_file));
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabel(JNIEnv *env, jclass,
+                                                                                     jlong session_ptr,
+                                                                                     jstring input_str,
+                                                                                     jstring vocab_file) {
+  return InferFromVocabFile(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr),
+                            JstringToChar(env, input_str), JstringToChar(env, vocab_file));
 }
 
-extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabelFromVocab(JNIEnv *env, jclass, jlong session_ptr,jstring input_str ,jobjectArray vocab_array) {
 
+
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabelFromVocab(
+  JNIEnv *env, jclass, jlong session_ptr, jstring input_str, jobjectArray vocab_array) {
+  if (input_str == nullptr || vocab_array == nullptr) {  // GetArrayLength below also needs a non-null array
+    std::cout << "input cannot empty" << std::endl;
+    return -1;
+  }
+  auto c_input = JstringToChar(env, input_str);
+  if (c_input == nullptr) {
+    std::cout << "input cannot empty" << std::endl;
+    return -1;
+  }
   jsize size = env->GetArrayLength(vocab_array);
-  std::string c_vocab_array[size];
-  for(int i=0;i<size;i++) {
-    jstring jstr = (jstring)env->GetObjectArrayElement(vocab_array, i);
-    const jsize strLen = env->GetStringUTFLength(jstr);
-    const char *charBuffer = env->GetStringUTFChars(jstr, 0);
-    c_vocab_array[i] = std::string (charBuffer, strLen);
-    env->ReleaseStringUTFChars(jstr, charBuffer);
-    env->DeleteLocalRef(jstr);
-  }
-  return InferFromVocabArr(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr),JstringToChar(env, input_str),c_vocab_array);
+  std::vector<std::string> c_vocab_array(size);  // heap storage: a VLA of std::string is not standard C++
+  CastJstringArrayToC(env, vocab_array, c_vocab_array.data(), size);
+  return InferFromVocabArr(reinterpret_cast<mindspore::session::TrainSession *>(session_ptr), c_input, c_vocab_array.data(), size);
 }
 
 extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_free(JNIEnv *env, jclass, jlong session_ptr) {
@@ -328,9 +380,13 @@ extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_free(JNIE
   if (0 != session_ptr) {
     delete (reinterpret_cast<mindspore::session::TrainSession *>(session_ptr));
   }
-  std::string model_name = JstringToChar(env, model_path);
-  if(model_name.find("lenet") != std::string::npos){
-    FreeLenetInput();
+  if (model_path != NULL) {
+    std::string model_name = JstringToChar(env, model_path);
+    if (model_name.find("lenet") != std::string::npos) {
+      FreeLenetInput();
+    } else {
+      FreeBertInput();
+    }
   } else {
     FreeBertInput();
   }
diff --git a/mindspore/lite/flclient/src/main/native/test_train.cc b/mindspore/lite/flclient/src/main/native/test_train.cc
index 94f7584..89d9121 100644
--- a/mindspore/lite/flclient/src/main/native/test_train.cc
+++ b/mindspore/lite/flclient/src/main/native/test_train.cc
@@ -6,49 +6,49 @@
 #include "lenet_train.h"
 #include "util.h"
 int main() {
-  std::cout << "----------begin train lenet-------" << std::endl;
-  std::string lenet_ms_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/lenet_train.mindir.ms";
-  std::string lenet_data_input =
-    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
-    "f0049_32_bn_11_train_data.bin";
-  std::string lenet_label_input =
-    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
-    "f0049_32_bn_11_train_label.bin";
-
-  auto lenet_input_size = SetLenetInputs(lenet_data_input,lenet_label_input);
-  std::cout<< "total train size:"<< lenet_input_size<<std::endl;
-  auto session = CreateSession(lenet_ms_file);
-  auto status = TrainLenet(session, lenet_ms_file, 32, 2);
-  if (status != 0) {
-    std::cout << "train failed" << std::endl;
-  }
-  mindspore::session::TrainFeatureParam **feature;
-  int size = 0;
-  status = GetFeatures(session, &feature, &size);
-  if(status != 0) {
-    std::cout<< "get feature failed"<<std::endl;
-  }
-  std::cout << "get total features:" << size << std::endl;
-  for (int i = 0; i < size; i++) {
-    std::cout << "name:" << feature[i]->name << std::endl;
-  }
-  std::string lenet_test_data = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/f0049_32_bn_1_test_data.bin";
-  std::string lenet_test_label = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/f0049_32_bn_1_test_label.bin";
-
-  (void)SetLenetInputs(lenet_test_data,lenet_test_label);
-  std::cout<< "cal acc:"<< InferLenet(session) << std::endl;
-  auto infer_result = GetLenetInferRes(session);
-  for(auto infer_label:infer_result) {
-    std::cout<< "infer_label:"<< infer_label<<std::endl;
-  }
-  delete session;
+//  std::cout << "----------begin train lenet-------" << std::endl;
+//  std::string lenet_ms_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/lenet_train.mindir.ms";
+//  std::string lenet_data_input =
+//    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
+//    "f0049_32_bn_11_train_data.bin";
+//  std::string lenet_label_input =
+//    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
+//    "f0049_32_bn_11_train_label.bin";
+//
+//  auto lenet_input_size = SetLenetInputs(lenet_data_input,lenet_label_input);
+//  std::cout<< "total train size:"<< lenet_input_size<<std::endl;
+//  auto session = CreateSession(lenet_ms_file);
+//  auto status = TrainLenet(session, lenet_ms_file, 32, 2);
+//  if (status != 0) {
+//    std::cout << "train failed" << std::endl;
+//  }
+//  mindspore::session::TrainFeatureParam **feature;
+//  int size = 0;
+//  status = GetFeatures(session, &feature, &size);
+//  if(status != 0) {
+//    std::cout<< "get feature failed"<<std::endl;
+//  }
+//  std::cout << "get total features:" << size << std::endl;
+//  for (int i = 0; i < size; i++) {
+//    std::cout << "name:" << feature[i]->name << std::endl;
+//  }
+//  std::string lenet_test_data = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/f0049_32_bn_1_test_data.bin";
+//  std::string lenet_test_label = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/f0049_32_bn_1_test_label.bin";
+//
+//  (void)SetLenetInputs(lenet_test_data,lenet_test_label);
+//  std::cout<< "cal acc:"<< InferLenet(session) << std::endl;
+//  auto infer_result = GetLenetInferRes(session);
+//  for(auto infer_label:infer_result) {
+//    std::cout<< "infer_label:"<< infer_label<<std::endl;
+//  }
+//  delete session;
 
 
   std::cout << "----------begin train bert-------" << std::endl;
   std::string vocab_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/vocab.txt";
   std::string train_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/0.tsv";
   std::string labels_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/label.tsv";
-  auto train_data_size = SetBertInputs(train_file, vocab_file, labels_file);
+  auto train_data_size = SetBertInputs(train_file, vocab_file);
   std::cout << "total train data size:" << train_data_size << std::endl;
   if(train_data_size == -1) {
     std::cout<< "set bert inputs failed" << std::endl;
@@ -57,16 +57,18 @@ int main() {
   std::string ms_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/albert.ms";
   int epoches = 2;
   int batch_size = 16;
-  session = CreateSession(ms_file);
-  status = TrainBert(session, ms_file, 16, epoches);
+  auto session = CreateSession(ms_file);
+  auto status = TrainBert(session, ms_file, batch_size, epoches);
   if (status != 0) {
     std::cout << "train failed" << std::endl;
   }
-  size = 0;
-  status = GetFeatures(session, &feature, &size);
-  std::cout << "get total features:" << size << std::endl;
-  for (int i = 0; i < size; i++) {
-    std::cout << "name:" << feature[i]->name << std::endl;
-  }
+//  auto size = 0;
+//  status = GetFeatures(session, &feature, &size);
+//  std::cout << "get total features:" << size << std::endl;
+//  for (int i = 0; i < size; i++) {
+//    std::cout << "name:" << feature[i]->name << std::endl;
+//  }
+  std::string input = "DI控制ME一直是各组件的标杆啊";  // sample sentence for InferFromVocabFile
+  std::cout<<"infer result:"<< InferFromVocabFile(session,input,vocab_file) <<std::endl;
   delete session;
 }
\ No newline at end of file
diff --git a/mindspore/lite/flclient/src/main/native/util.cpp b/mindspore/lite/flclient/src/main/native/util.cpp
index 462f965..8a2a2d0 100644
--- a/mindspore/lite/flclient/src/main/native/util.cpp
+++ b/mindspore/lite/flclient/src/main/native/util.cpp
@@ -37,13 +37,13 @@ float GetLoss(mindspore::session::TrainSession *train_session) {
   auto loss = reinterpret_cast<float *>(outputsv->MutableData());
   return loss[0];
 }
-mindspore::session::TrainSession *CreateSession(const std::string &ms_file) {
+TrainSession *CreateSession(const std::string &ms_file) {
   // create model file
   mindspore::lite::Context context;
   context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
   context.thread_num_ = 1;
   bool train_mode = false;
   return mindspore::session::TrainSession::CreateSession(ms_file, &context, train_mode);
 }
 
 TrainSession *CreateSession(char* model_buffer,size_t buffen_len) {
@@ -64,7 +64,7 @@ std::vector<int> GetInferResult(TrainSession *session,int num_of_class) {
   auto inputs = session->GetInputs();
   auto batch_size = inputs[1]->shape()[0];
   auto outputsv = SearchOutputsForSize(session, batch_size * num_of_class);
-  std::cout<< "ouput tensor name:"<< outputsv->tensor_name()<<std::endl;
+//  std::cout<< "ouput tensor name:"<< outputsv->tensor_name()<<std::endl;
   auto scores = reinterpret_cast<float *>(outputsv->MutableData());
   std::vector<int> infer_result(batch_size);
   for (int b = 0; b < batch_size; b++) {
-- 
2.7.4

