<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From cb8484a4126de3624c4df16b48d47364a44e69f0 Mon Sep 17 00:00:00 2001
From: guohongzilong <guohongzilong@huawei.com>
Date: Fri, 12 Mar 2021 11:40:35 +0800
Subject: [PATCH] sync lite code

---
 mindspore/lite/examples/train_lenet/Makefile       |   8 +-
 .../lite/examples/train_lenet/src/lenet_train.cc   | 307 +++++++++++++++++++++
 .../lite/examples/train_lenet/src/lenet_train.h    |  35 +++
 .../lite/examples/train_lenet/src/test_run.cc      |  60 ++++
 .../lite/flclient/src/main/native/CMakeLists.txt   |  10 +-
 .../flclient/src/main/native/include/lenet_train.h |   2 +-
 .../flclient/src/main/native/src/lenet_train.cpp   |  80 +++---
 mindspore/lite/include/train_session.h             |  12 +
 mindspore/lite/schema/ops.fbs                      |  17 +-
 mindspore/lite/src/train/train_session.cc          |  58 ++++
 mindspore/lite/src/train/train_session.h           |   4 +
 11 files changed, 538 insertions(+), 55 deletions(-)
 create mode 100644 mindspore/lite/examples/train_lenet/src/lenet_train.cc
 create mode 100644 mindspore/lite/examples/train_lenet/src/lenet_train.h
 create mode 100644 mindspore/lite/examples/train_lenet/src/test_run.cc

diff --git a/mindspore/lite/examples/train_lenet/Makefile b/mindspore/lite/examples/train_lenet/Makefile
index 7e2b69c..5aabd01 100644
--- a/mindspore/lite/examples/train_lenet/Makefile
+++ b/mindspore/lite/examples/train_lenet/Makefile
@@ -5,15 +5,17 @@ LMDLIB:=-lminddata-lite -ljpeg
 LHIAILIB:=-lhiai_ir_build  -lhiai_ir -lhiai
 MSDIR:=$(realpath package-$(TARGET)/lib)
 
-SRC:=src/net_runner.cc
+SRC:=src/test_run.cc src/lenet_train.cc
 OBJ:=$(SRC:.cc=.o)
 
 CFLAGS := -Ofast -std=c++17  \
 	-I . \
+	   -I ../../ \
         -I ./msl \
         -I ./msl/minddata \
-        -I ./msl/third_party/flatbuffers/include
-
+        -I ./msl/third_party/flatbuffers/include \
+           -I ./msl/include \
+                   -I ../../build \
 
 ifeq ($(TARGET),arm64)
 CXX :=  ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin/clang++
diff --git a/mindspore/lite/examples/train_lenet/src/lenet_train.cc b/mindspore/lite/examples/train_lenet/src/lenet_train.cc
new file mode 100644
index 0000000..e6c6a90
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet/src/lenet_train.cc
@@ -0,0 +1,307 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "lenet_train.h"
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include "include/api/lite_context.h"
+#include "include/context.h"
+#include "include/errorcode.h"
+#include "limits.h"
+
+static char *fl_lenet_I0 = 0;
+static char *fl_lenet_I1 = 0;
+unsigned int seed_ = time(NULL);
+
+std::vector<int> FillInputData(mindspore::session::TrainSession *train_session, int batch_num, bool serially) {
+  std::vector<int> labels_vec;
+  auto inputs = train_session->GetInputs();
+  int batch_size = inputs[0]->shape()[0];
+  static unsigned int idx = 1;
+  int data_size = inputs[0]->ElementsNum() / batch_size;
+  int num_classes = inputs[1]->shape()[1];
+  float *input_data = reinterpret_cast<float *>(inputs.at(0)->MutableData());
+  auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
+  std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
+  for (int i = 0; i < batch_size; i++) {
+    if (serially) {
+      idx = ++idx % (batch_num*batch_size);
+    } else {
+      idx = rand_r(&seed_) % (batch_num*batch_size);
+    }
+    std::memcpy(input_data + i * data_size, (float*)fl_lenet_I0 + idx * data_size, data_size*sizeof(float));
+    int label_idx = *(reinterpret_cast<int *>(fl_lenet_I1) + idx);
+    labels[i * num_classes + label_idx] = 1.0;  // Model expects labels in onehot representation
+    labels_vec.push_back(label_idx);
+  }
+  return labels_vec;
+}
+
+
+std::vector<int> FillInputData2(mindspore::session::TrainSession *train_session, int batch_idx, bool serially) {
+  std::vector<int> labels_vec;
+  auto inputs = train_session->GetInputs();
+  int batch_size = inputs[0]->shape()[0];
+  static unsigned int idx = 1;
+  int data_size = inputs[0]->ElementsNum() / batch_size;
+  int num_classes = inputs[1]->shape()[1];
+  float *input_data = reinterpret_cast<float *>(inputs.at(0)->MutableData());
+  auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
+  std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
+  for (int i = 0; i < batch_size; i++) {
+    idx = i;
+    std::memcpy(input_data + i * data_size, (float*)fl_lenet_I0 +batch_idx* inputs[0]->ElementsNum() + idx * data_size, data_size*sizeof(float));
+    int label_idx = *(reinterpret_cast<int *>(fl_lenet_I1)+batch_idx *batch_size + idx);
+    labels[i * num_classes + label_idx] = 1.0;  // Model expects labels in onehot representation
+    labels_vec.push_back(label_idx);
+
+//    std::cout<< "fill input data:"<< ",idx:"<<idx<< ",labels idx:"<< label_idx <<std::endl;
+//    float sum_input = 0.0f;
+//    float sum_fl_input = 0.0f;
+//    for(int j=0;j<data_size;j++) {
+//      sum_input+=*(input_data+i*data_size+j);
+//      sum_fl_input+=*((float*)fl_lenet_I0 + batch_idx* inputs[0]->ElementsNum()+idx * data_size+j);
+//    }
+//    std::cout<< "sum_input:"<<sum_input<<",sum_fl_input"<<sum_fl_input<<std::endl;
+//    std::cout<<"------------------"<<std::endl;
+  }
+//    std::cout<< "-------------next batchsize----------"<<std::endl;
+  return labels_vec;
+}
+
+// Returns the first output tensor holding exactly `size` elements, or nullptr (after logging) if none does.
+mindspore::tensor::MSTensor *SearchOutputsForSize(mindspore::session::TrainSession *train_session, size_t size) {
+  for (const auto &output : train_session->GetOutputs()) {
+    if (static_cast<size_t>(output.second->ElementsNum()) == size) return output.second;
+  }
+  std::cout << "Model does not have an output tensor with size " << size << std::endl;
+  return nullptr;
+}
+
+float GetLoss(mindspore::session::TrainSession *train_session) {
+  auto outputsv = SearchOutputsForSize(train_session, 1);  // Search for Loss which is a single value tensor
+  if (outputsv == nullptr) {
+    return 10000;
+  }
+  auto loss = reinterpret_cast<float *>(outputsv->MutableData());
+  return loss[0];
+}
+mindspore::session::TrainSession *GetSession(const std::string &ms_file, bool train_mode) {
+  // create model file
+  mindspore::lite::Context context;
+  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
+  context.thread_num_ = 1;
+  return mindspore::session::TrainSession::CreateSession(ms_file, &context, train_mode);
+}
+
+// Evaluates one randomly sampled batch and returns the fraction of samples whose highest-scoring
+// class equals the label. Calls session->Eval(), so a training caller has to call Train() again
+// afterwards. Returns 0 when no output tensor of batch_size * num_of_class elements exists.
+float CalculateAccuracy(mindspore::session::TrainSession *session,int batch_num) {
+  session->Eval();
+  auto labels = FillInputData(session, batch_num, false);
+  session->RunGraph();
+  auto inputs = session->GetInputs();
+  auto batch_size = inputs[1]->shape()[0];
+  auto num_of_class = inputs[1]->shape()[1];
+  auto outputsv = SearchOutputsForSize(session, batch_size * num_of_class);
+  if (outputsv == nullptr) return 0.0f;  // the search already logged the failure; do not dereference
+  auto scores = reinterpret_cast<float *>(outputsv->MutableData());
+  float accuracy = 0.0;
+  for (int b = 0; b < batch_size; b++) {
+    int max_idx = 0;  // arg-max over this sample's scores; ties keep the lowest class index
+    for (int c = 1; c < num_of_class; c++) {
+      if (scores[num_of_class * b + c] > scores[num_of_class * b + max_idx]) max_idx = c;
+    }
+    if (labels[b] == max_idx) accuracy += 1.0;
+  }
+  return accuracy/batch_size;
+}
+
+
+// Net inference function: returns the mean accuracy over test_nums randomly sampled batches, or 0 on failure.
+float fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums) {
+  if (test_nums <= 0) return 0.0f;  // nothing to average; avoids dividing by zero below
+  auto session = GetSession(ms_file, false);
+  if (session == nullptr) return 0.0f;
+  float sum_acc = 0.0f;
+  for (int j = 0; j < test_nums; ++j) {
+    auto acc_per_test = CalculateAccuracy(session,batch_num);
+    sum_acc+=acc_per_test;
+    std::cout << "infer:"<< j <<"times,acc is " << acc_per_test << std::endl;
+  }
+  delete session;  // the session is owned by this function
+  return sum_acc/test_nums;
+}
+
+
+// Net training function: runs iterations/batch_num epochs, each feeding batch_num consecutive
+// batches from the buffers loaded by fl_lenet_lite_SetInputs, then saves the model back to ms_file.
+// Returns RET_ERROR for non-positive arguments or when the session cannot be created.
+int fl_lenet_lite_Train(const std::string &ms_file, const int batch_num, const int iterations) {
+  if (iterations <= 0 || batch_num <= 0) {  // batch_num is used as a divisor below
+    std::cout << "error iterations or batch_num!, batch_num:" << batch_num
+              << ", iterations" << iterations << std::endl;
+    return mindspore::lite::RET_ERROR;
+  }
+  auto session = GetSession(ms_file, true);  // created after validation so the error path owns nothing
+  if (session == nullptr) return mindspore::lite::RET_ERROR;
+  std::cout << "total iterations :" << iterations << "batch_num:" << batch_num <<std::endl;
+  for (int j = 0; j < iterations/batch_num; ++j) {
+    float sum_loss_per_epoch = 0.0f;
+    for(int k=0;k<batch_num;++k) {
+      FillInputData2(session,k,false);
+      session->RunGraph(nullptr, nullptr);
+      sum_loss_per_epoch+=GetLoss(session);
+    }
+    std::cout << "epoch " << "[" <<j<<"]" << ",mean Loss " << sum_loss_per_epoch/batch_num <<",train acc "<< CalculateAccuracy(session,batch_num) <<std::endl;
+    session->Train();  // CalculateAccuracy called Eval(); switch back before the next epoch
+  }
+  session->SaveToFile(ms_file);
+  delete session;  // the session is owned by this function
+  return mindspore::lite::RET_OK;
+}
+
+int fl_lenet_lite_UpdateFeatures(const std::string &update_ms_file, TrainFeatureParam *new_features, int size) {
+  auto train_session = GetSession(update_ms_file, false);
+  auto status = train_session->UpdateFeatureMaps(update_ms_file, new_features, size);
+  if (status != mindspore::lite::RET_OK) {
+   std::cout << "update model feature map failed" << update_ms_file;
+  }
+  delete train_session;
+  return status;
+}
+
+int fl_lenet_lite_GetFeatures(const std::string &update_ms_file, mindspore::session::TrainFeatureParam ***feature,
+                              int *size) {
+  auto train_session = GetSession(update_ms_file, false);
+  std::vector<mindspore::session::TrainFeatureParam *> new_features;
+  auto status = train_session->GetFeatureMaps(&new_features);
+  if (status != mindspore::lite::RET_OK) {
+   std::cout << "get model feature map failed" << update_ms_file;
+    delete train_session;
+    return mindspore::lite::RET_ERROR;
+  }
+  *feature = new (std::nothrow) TrainFeatureParam *[new_features.size()];
+  if (*feature == nullptr) {
+   std::cout << "create features failed";
+    delete train_session;
+    return mindspore::lite::RET_ERROR;
+  }
+  for (int i = 0; i < new_features.size(); i++) {
+    (*feature)[i] = new_features[i];
+  }
+  *size = new_features.size();
+  delete train_session;
+  return mindspore::lite::RET_OK;
+}
+
+// Resolves `path` to an absolute path via realpath() (or _fullpath() on Windows). Returns an empty
+// string, after logging, when `path` is null, too long, or cannot be resolved.
+std::string RealPath(const char *path) {
+  if (path == nullptr) {
+    std::cout << "path is nullptr";
+    return "";
+  }
+  if ((strlen(path)) >= PATH_MAX) {
+    std::cout << "path is too long";
+    return "";
+  }
+  // make_unique throws on allocation failure rather than returning null, so no null check follows.
+  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
+#ifdef _WIN32
+  // Pass the real buffer size: a fixed 1024 overruns the buffer whenever PATH_MAX is smaller.
+  char *real_path = _fullpath(resolved_path.get(), path, PATH_MAX);
+#else
+  char *real_path = realpath(path, resolved_path.get());
+#endif
+  if (real_path == nullptr || strlen(real_path) == 0) {
+    std::cout << "file path is not valid : " << path;
+    return "";
+  }
+  std::string res = resolved_path.get();
+  return res;
+}
+
+// Reads the whole of `file` into a new[]-allocated buffer and stores its byte length in *size.
+// Returns nullptr on any failure; on success the caller owns the buffer and must delete[] it.
+char *ReadFile(const char *file, size_t *size) {
+  if (file == nullptr || size == nullptr) {
+    std::cout << "file or size is nullptr";
+    return nullptr;
+  }
+  std::string real_path = RealPath(file);
+  std::ifstream ifs(real_path, std::ios::binary);  // contents are consumed as raw bytes, not text
+  if (!ifs.good()) {
+    std::cout << "file: " << real_path << " is not exist";
+    return nullptr;
+  }
+
+  if (!ifs.is_open()) {
+    std::cout << "file: " << real_path << " open failed";
+    return nullptr;
+  }
+
+  ifs.seekg(0, std::ios::end);
+  const std::streamoff file_len = ifs.tellg();
+  if (file_len < 0) return nullptr;  // tellg() reports failure as -1, which must not reach new[]
+  *size = static_cast<size_t>(file_len);
+  std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
+  if (buf == nullptr) {
+    std::cout << "malloc buf failed, file: " << real_path;
+    return nullptr;
+  }
+  ifs.seekg(0, std::ios::beg);
+  if (!ifs.read(buf.get(), *size)) return nullptr;  // short read: buf is released by unique_ptr
+  return buf.release();
+}
+
+// Set input tensors: `files` is "<data file>,<label file>"; both are read into the file-level buffers.
+// Returns 0 on success, -1 for a malformed list, RET_ERROR when a file cannot be read. `num` is unused.
+int fl_lenet_lite_SetInputs(const std::string &files, int num) {
+  std::vector<std::string> res;
+  if (files.empty()) {
+   std::cout << "files empty";
+    return -1;
+  }
+  std::string pattern = ",";
+  std::string strs = files + pattern;
+  size_t pos = strs.find(pattern);
+  while (pos != strs.npos) {
+    std::string temp = strs.substr(0, pos);
+    res.push_back(temp);
+    strs = strs.substr(pos + 1, strs.size());
+    pos = strs.find(pattern);
+  }
+  if (res.size() != 2) {
+   std::cout << "res size not equal 2";
+    return -1;
+  }
+  for (int i = 0; i < 2; i++) {
+    size_t size;
+    char *bin_buf = ReadFile(res[i].c_str(), &size);
+    if (bin_buf == nullptr) {
+     std::cout << "ReadFile return nullptr";
+      return mindspore::lite::RET_ERROR;
+    }
+    // Release the buffer from any earlier call before replacing it, so loading a second
+    // dataset does not leak the first. ReadFile allocates with new[], hence delete[].
+    char *&slot = (i == 0) ? fl_lenet_I0 : fl_lenet_I1;
+    delete[] slot;
+    slot = bin_buf;
+  }
+  return 0;
+}
diff --git a/mindspore/lite/examples/train_lenet/src/lenet_train.h b/mindspore/lite/examples/train_lenet/src/lenet_train.h
new file mode 100644
index 0000000..ad173f5
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet/src/lenet_train.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MSLITE_FL_LITE_LENET_H
+#define MSLITE_FL_LITE_LENET_H
+
+#include <string>
+#include "include/train_session.h"
+
+using mindspore::session::TrainFeatureParam;
+
+int fl_lenet_lite_Train(const std::string &ms_file, const int batch_num, const int iterations);
+
+float fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums);
+
+int fl_lenet_lite_GetFeatures(const std::string &update_ms_file, mindspore::session::TrainFeatureParam ***features,
+                              int *size);
+int fl_lenet_lite_UpdateFeatures(const std::string &update_ms_file, TrainFeatureParam *new_features, int size);
+mindspore::session::TrainSession *GetSession(const std::string &ms_file, bool train_mode = false);
+
+int fl_lenet_lite_SetInputs(const std::string &files, int num);
+#endif  // MSLITE_FL_LITE_LENET_H
diff --git a/mindspore/lite/examples/train_lenet/src/test_run.cc b/mindspore/lite/examples/train_lenet/src/test_run.cc
new file mode 100644
index 0000000..ae90a7d
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet/src/test_run.cc
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/net_runner.h"
+#include <math.h>
+#include <getopt.h>
+#include <cstring>
+#include <iostream>
+#include <fstream>
+#include <utility>
+#include "include/context.h"
+#include "include/train/loss_monitor.h"
+#include "include/train/ckpt_saver.h"
+#include "include/train/lr_scheduler.h"
+#include "include/train/accuracy_metrics.h"
+#include "include/train/classification_train_accuracy_monitor.h"
+#include "lenet_train.h"
+
+
+int main(int argc, char **argv) {
+  std::string train_dataset =
+    "/home/meng/zj10/hdc/mindspore/mindspore/lite/flclient/src/main/resources/client_data/f0049_32/"
+    "f0049_32_train_data.bin,/home/meng/zj10/hdc/mindspore/mindspore/lite/flclient/src/main/resources/client_data/"
+    "f0049_32/f0049_32_train_label.bin";
+  auto status = fl_lenet_lite_SetInputs(train_dataset, 2);
+  if (status != 0) {
+    std::cout << "set inputs error" << std::endl;
+    return 1;  // stop here: training reads the buffers that failed to load
+  }
+  std::cout << "set input ok" << std::endl;
+  std::string ms_file = "/home/meng/zj10/hdc/fl/mindspore/mindspore/lite/lenet_train.mindir.ms";
+  int batches_per_epoch = 11;
+  fl_lenet_lite_Train(ms_file,batches_per_epoch, 11*2000);
+  // eval
+  std::string test_dataset =
+    "/home/meng/zj10/hdc/mindspore/mindspore/lite/flclient/src/main/resources/client_data/f0049_32/"
+    "f0049_32_test_data.bin,/home/meng/zj10/hdc/mindspore/mindspore/lite/flclient/src/main/resources/client_data/"
+    "f0049_32/f0049_32_test_label.bin";
+  if (fl_lenet_lite_SetInputs(test_dataset, 2) != 0) {
+    std::cout << "set inputs error" << std::endl;
+    return 1;
+  }
+  batches_per_epoch = 1;
+  auto accuracy =fl_lenet_lite_Inference(ms_file,batches_per_epoch, 1);
+  std::cout << "test accuracy: " << accuracy << std::endl;  // report the result instead of discarding it
+  return 0;
+}
diff --git a/mindspore/lite/flclient/src/main/native/CMakeLists.txt b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
index 81a9ff9..1ee548e 100644
--- a/mindspore/lite/flclient/src/main/native/CMakeLists.txt
+++ b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
@@ -12,12 +12,10 @@ set(LITE_DIR ${TOP_DIR}/mindspore/lite)
 set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR})
 set(MS_VERSION_MINOR ${MS_VERSION_MINOR})
 set(MS_VERSION_REVISION ${MS_VERSION_REVISION})
-set(CMAKE_C_FLAGS
-    "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} "
-    "-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
-set(CMAKE_CXX_FLAGS
-      "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} "
-      "-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
+  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
+  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/mindspore/lite/flclient/src/main/native/include/lenet_train.h b/mindspore/lite/flclient/src/main/native/include/lenet_train.h
index 4d448af..ad173f5 100644
--- a/mindspore/lite/flclient/src/main/native/include/lenet_train.h
+++ b/mindspore/lite/flclient/src/main/native/include/lenet_train.h
@@ -24,7 +24,7 @@ using mindspore::session::TrainFeatureParam;
 
 int fl_lenet_lite_Train(const std::string &ms_file, const int batch_num, const int iterations);
 
-int fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums);
+float fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums);
 
 int fl_lenet_lite_GetFeatures(const std::string &update_ms_file, mindspore::session::TrainFeatureParam ***features,
                               int *size);
diff --git a/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp b/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
index 4455efb..728f0fe 100644
--- a/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
+++ b/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
@@ -21,6 +21,7 @@
 #include "include/context.h"
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
+#include "limits.h"
 
 static char *fl_lenet_I0 = 0;
 static char *fl_lenet_I1 = 0;
@@ -33,16 +34,16 @@ std::vector<int> FillInputData(mindspore::session::TrainSession *train_session,
   static unsigned int idx = 1;
   int data_size = inputs[0]->ElementsNum() / batch_size;
   int num_classes = inputs[1]->shape()[1];
-  char *input_data = reinterpret_cast<char *>(inputs.at(0)->MutableData());
+  float* input_data = reinterpret_cast<float *>(inputs.at(0)->MutableData());
   auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
   std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
   for (int i = 0; i < batch_size; i++) {
     if (serially) {
-      idx = ++idx % batch_num;
+      idx = ++idx % (batch_num*batch_size);
     } else {
-      idx = rand_r(&seed_) % batch_num;
+      idx = rand_r(&seed_) % (batch_num*batch_size);
     }
-    std::memcpy(input_data + i * data_size, fl_lenet_I0 + idx * data_size, data_size);
+    std::memcpy(input_data + i * data_size, (float*)fl_lenet_I0 + idx * data_size, data_size*sizeof(float));
     int label_idx = *(reinterpret_cast<int *>(fl_lenet_I1) + idx);
     labels[i * num_classes + label_idx] = 1.0;  // Model expects labels in onehot representation
     labels_vec.push_back(label_idx);
@@ -75,42 +76,47 @@ mindspore::session::TrainSession *GetSession(const std::string &ms_file, bool tr
   return mindspore::session::TrainSession::CreateSession(ms_file, &context, train_mode);
 }
 
-// net training function
-int fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums) {
-  auto session = GetSession(ms_file, false);
-  char *origin_input[] = {fl_lenet_I0, fl_lenet_I1};
-  float accuracy = 0.0;
+float CalculateAccuracy(mindspore::session::TrainSession *session,int batch_num) {
   session->Eval();
+  auto labels = FillInputData(session, batch_num, false);
+  session->RunGraph();
   auto inputs = session->GetInputs();
-  if (inputs[1]->shape().size() != 2) {
-    return mindspore::lite::RET_ERROR;
-  }
   auto batch_size = inputs[1]->shape()[0];
   auto num_of_class = inputs[1]->shape()[1];
-  for (int j = 0; j < test_nums; ++j) {
-    auto labels = FillInputData(session, batch_num, true);
-    session->RunGraph();
-    auto outputsv = SearchOutputsForSize(session, batch_size * num_of_class);
-    auto scores = reinterpret_cast<float *>(outputsv->MutableData());
-    for (int b = 0; b < batch_size; b++) {
-      int max_idx = 0;
-      float max_score = scores[num_of_class * b];
-      for (int c = 0; c < num_of_class; c++) {
-        if (scores[num_of_class * b + c] > max_score) {
-          max_score = scores[num_of_class * b + c];
-          max_idx = c;
-        }
+  auto outputsv = SearchOutputsForSize(session, batch_size * num_of_class);
+  auto scores = reinterpret_cast<float *>(outputsv->MutableData());
+  float accuracy = 0.0;
+  for (int b = 0; b < batch_size; b++) {
+    int max_idx = 0;
+    float max_score = scores[num_of_class * b];
+    for (int c = 0; c < num_of_class; c++) {
+      if (scores[num_of_class * b + c] > max_score) {
+        max_score = scores[num_of_class * b + c];
+        max_idx = c;
       }
-      if (labels[b] == max_idx) accuracy += 1.0;
     }
+    if (labels[b] == max_idx) accuracy += 1.0;
+  }
+  return accuracy/batch_size;
+}
+
+
+// net training function
+float fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums) {
+  auto session = GetSession(ms_file, false);
+  char *origin_input[] = {fl_lenet_I0, fl_lenet_I1};
+  float sum_acc = 0.0f;
+  for (int j = 0; j < test_nums; ++j) {
+    auto acc_per_test = CalculateAccuracy(session,batch_num);
+    sum_acc+=acc_per_test;
+    std::cout << "infer:"<< j <<"times,acc is " << acc_per_test << std::endl;
   }
   fl_lenet_I0 = origin_input[0];
   fl_lenet_I1 = origin_input[1];
-  accuracy /= static_cast<float>(batch_size * test_nums);
-  MS_LOG(INFO) << "accuracy  is " << accuracy;
-  return mindspore::lite::RET_OK;
+  return sum_acc/test_nums;
 }
 
+
 // net training function
 int fl_lenet_lite_Train(const std::string &ms_file, const int batch_num, const int iterations) {
   auto session = GetSession(ms_file, true);
@@ -121,15 +127,15 @@ int fl_lenet_lite_Train(const std::string &ms_file, const int batch_num, const i
   }
   MS_LOG(INFO) << "total iterations :" << iterations << "batch_num:" << batch_num;
   char *origin_input[] = {fl_lenet_I0, fl_lenet_I1};
-  float min_loss = 1000.;
-  for (int j = 0; j < iterations; ++j) {
-    FillInputData(session, batch_num, false);
-    session->RunGraph(nullptr, nullptr);
-    float loss = GetLoss(session);
-    if (min_loss > loss) min_loss = loss;
-    if (j % 50 == 0) {
-      MS_LOG(INFO) << "iteration:" << j << ",Loss is" << loss << " [min=" << min_loss << "]";
+  for (int j = 0; j < iterations/batch_num; ++j) {
+    float sum_loss_per_epoch = 0.0f;
+    for(int k=0;k<batch_num;++k) {
+      FillInputData(session, batch_num, false);
+      session->RunGraph(nullptr, nullptr);
+      sum_loss_per_epoch+=GetLoss(session);
     }
+    std::cout << "epoch " << "[" <<j<<"]" << ",mean Loss " << sum_loss_per_epoch/batch_num <<",train acc "<< CalculateAccuracy(session,batch_num) <<std::endl;
+    session->Train();
   }
   session->SaveToFile(ms_file);
   fl_lenet_I0 = origin_input[0];
diff --git a/mindspore/lite/include/train_session.h b/mindspore/lite/include/train_session.h
index f5d3dfb..08e42a7 100644
--- a/mindspore/lite/include/train_session.h
+++ b/mindspore/lite/include/train_session.h
@@ -23,6 +23,13 @@
 namespace mindspore {
 namespace session {
 
+struct TrainFeatureParam{
+  char* name;
+  void *data;
+  size_t elenums;
+  enum TypeId type;
+};
+
 /// \brief TrainSession Defines a class that allows training a MindSpore model
 class TrainSession : public session::LiteSession {
  public:
@@ -137,6 +144,11 @@ class TrainSession : public session::LiteSession {
   /// \param[in] loss_name Identifucation name for loss kernels
   void SetLossName(std::string loss_name) { loss_name_ = loss_name; }
 
+  virtual int GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *>* feature_maps) =0;
+
+  virtual int UpdateFeatureMaps(const std::string &update_ms_file,
+                                TrainFeatureParam* new_features,int size) =0;
+
  protected:
   bool train_mode_ = false;
   std::string get_loss_name() const { return loss_name_; }
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index d07261e..0916185 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -953,6 +953,14 @@ table StridedSlice {
     shrink_axis_mask: long;
 }
 
+table StridedSliceGrad {
+    begin_mask: long;
+    end_mask: long;
+    ellipsis_mask: long;
+    new_axis_mask: long;
+    shrink_axis_mask: long;
+}
+
 table SubFusion {
     activation_type: ActivationType = 0;
 }
@@ -1056,14 +1064,6 @@ table CropAndResize {
 table Erf {
 }
 
-table StridedSliceGrad {
-    begin_mask: long;
-    end_mask: long;
-    ellipsis_mask: long;
-    new_axis_mask: long;
-    shrink_axis_mask: long;
-}
-
 table IsFinite {
 }
 
@@ -1077,3 +1077,4 @@ table UniformReal {
 
 table AbsGrad {
 }
+
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index badd1fd..a3685cb 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -22,6 +22,7 @@
 #include <iostream>
 #include <fstream>
 #include <memory>
+#include <cstring>
 #include "include/errorcode.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
@@ -480,6 +481,63 @@ bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
           (kernel->Type() == schema::PrimitiveType_FusedBatchNorm));
 }
 
+// Appends one heap-allocated TrainFeatureParam per constant tensor to *feature_maps. `name` is a new[]
+// copy owned by the param (null for an unnamed tensor); `data` points at the tensor's own buffer.
+int lite::TrainSession::GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *> *feature_maps) {
+  for (auto tensor : this->tensors_) {
+    if (tensor->IsConst()) {
+      auto param = new mindspore::session::TrainFeatureParam();
+      const auto &tensor_name = tensor->tensor_name();
+      char* name = nullptr;
+      if (!tensor_name.empty()) {
+        name = new char[tensor_name.length() + 1];
+        memcpy(name, tensor_name.c_str(), tensor_name.length() + 1);  // includes the terminating NUL
+      }
+      param->name =  name;
+      // Point at the live buffer, not a copy: UpdateFeatureMaps writes through this to change the weights.
+      param->data = tensor->data_c();
+      param->elenums = tensor->ElementsNum();
+      param->type = tensor->data_type();
+      feature_maps->push_back(param);
+    }
+  }
+  MS_LOG(INFO) << "get feature map success";
+  return RET_OK;
+}
+// Copies each of the `size` entries of new_features into the same-named constant tensor, then saves the model.
+int lite::TrainSession::UpdateFeatureMaps(const std::string &update_ms_file,
+                                          mindspore::session::TrainFeatureParam* new_features,int size) {
+  struct FeatureList {  // frees the params on every exit path; `data` points at tensor memory and is not freed
+    std::vector<mindspore::session::TrainFeatureParam *> items;
+    ~FeatureList() { for (auto item : items) { delete[] item->name; delete item; } }
+  } old_features;
+  if (GetFeatureMaps(&old_features.items) != RET_OK) {
+    MS_LOG(ERROR) << "get features map failed:";
+  }
+  for (int i=0;i<size;++i) {
+    mindspore::session::TrainFeatureParam* new_feature = new_features + i;
+    bool find = false;
+    for (auto old_feature : old_features.items) {
+      if (old_feature->name != nullptr && strcmp(old_feature->name, new_feature->name) == 0) {  // unnamed tensors have a null name
+        if(old_feature->elenums != new_feature->elenums) {
+          MS_LOG(ERROR) << "feature name:"<<old_feature->name<<",len diff:"<<"old is:"<<old_feature->elenums<<"new is:"<<new_feature->elenums;
+          return RET_ERROR;
+        }
+        find = true;
+        memcpy(old_feature->data, new_feature->data, new_feature->elenums*sizeof(float));
+        break;
+      }
+    }
+    if (!find) {
+      MS_LOG(ERROR) << "cannot find feature:" << new_feature->name;
+      return RET_ERROR;
+    }
+  }
+  SaveToFile(update_ms_file);
+  MS_LOG(INFO) << "update model:" << update_ms_file << ",feature map success";
+  return RET_OK;
+}
+
 }  // namespace lite
 
 session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context,
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index 266baef..c69fa5d 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -91,6 +91,10 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
     return outputs;
   }
 
+  int GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *> *feature_maps) override;
+  int UpdateFeatureMaps(const std::string &update_ms_file,
+                        mindspore::session::TrainFeatureParam* new_features,int size) override;
+
  protected:
   void AllocWorkSpace();
   bool IsLossKernel(const kernel::LiteKernel *kernel) const;
-- 
2.7.4

