<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From 6c08a8d4eb95796a2c862abeedda2a3bb835c06a Mon Sep 17 00:00:00 2001
From: guohongzilong <guohongzilong@huawei.com>
Date: Thu, 4 Mar 2021 19:42:36 +0800
Subject: [PATCH] fl lite lenet demo

---
 .../java/com/huawei/flclient/FLLiteClient.java     | 572 +++++++++++++++++++++
 .../main/java/com/huawei/flclient/GetModel.java    |  12 +
 .../main/java/com/huawei/flclient/LiteTrain.java   |  83 +++
 .../main/java/com/huawei/flclient/StartFLJob.java  |  43 ++
 .../main/java/com/huawei/flclient/UpdateModel.java |  15 +-
 .../lite/flclient/src/main/native/CMakeLists.txt   |  39 +-
 .../flclient/src/main/native/include/lenet_train.h |  35 ++
 .../flclient/src/main/native/lenet_train_jni.cpp   | 158 ++++++
 .../flclient/src/main/native/src/lenet_train.cpp   | 270 ++++++++++
 .../com/huawei/flclient/test/TestFLClient.java     |   8 +-
 mindspore/lite/include/train_session.h             |  13 +-
 mindspore/lite/src/train/train_session.cc          |  59 ++-
 mindspore/lite/src/train/train_session.h           |   3 +
 13 files changed, 1296 insertions(+), 14 deletions(-)
 create mode 100644 mindspore/lite/flclient/src/main/java/com/huawei/flclient/FLLiteClient.java
 create mode 100644 mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
 create mode 100644 mindspore/lite/flclient/src/main/native/include/lenet_train.h
 create mode 100644 mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp
 create mode 100644 mindspore/lite/flclient/src/main/native/src/lenet_train.cpp

diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/FLLiteClient.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/FLLiteClient.java
new file mode 100644
index 0000000..f29bc9c
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/FLLiteClient.java
@@ -0,0 +1,677 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.flclient;
+
+import com.huawei.flclient.cipher.BaseUtil;
+import mindspore.schema.*;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Date;
+import java.util.concurrent.TimeoutException;
+
+/**
+ * Federated-learning client that trains locally through the MindSpore Lite runtime ({@link LiteTrain}).
+ * <p>
+ * A round is: startFLJob (fetch the FL plan and global weights), local training + updateModel
+ * (upload the trained weights, optionally masked) and getModel (download the server's updated weights).
+ */
+public class FLLiteClient {
+    private FLCommunication flCommunication;
+
+    private FLClientStatus status;
+
+    // NOTE(review): static, so the current iteration is shared by every FLLiteClient instance.
+    private static int iteration = 0;
+
+    private int iterations;
+    private int epochs;
+    private int batchSize;
+    private int minSecretNum;
+    private byte[] prime;
+    private int featureSize;
+
+    private String fl_name;
+    private String fl_id;
+    private String ip;
+    private int port;
+    private boolean ifAsync;
+    private EncryptLevel encryptLevel;
+    private boolean use_elb;
+    private int serverNum;
+    private int train_batch_num;
+    private String test_dataset;
+    private int test_batch_num;
+
+    private SecureProtocol secureProtocol = new SecureProtocol();
+
+
+    /**
+     * Stores the connection and training settings. Setting the communication timeout is best
+     * effort: a TimeoutException from setTimeOut is only printed.
+     */
+    public FLLiteClient(String ip, int port, String fl_name, String fl_id, boolean ifAsync, EncryptLevel encryptLevel, boolean use_elb, int serverNum, int train_batch_num, String test_dataset, int test_batch_num) {
+        flCommunication = FLCommunication.getInstance();
+        try {
+            flCommunication.setTimeOut(100);
+        } catch (TimeoutException e) {
+            e.printStackTrace();
+        }
+        this.ip = ip;
+        this.port = port;
+        this.fl_name = fl_name;
+        this.fl_id = fl_id;
+        this.ifAsync = ifAsync;
+        this.encryptLevel = encryptLevel;
+        this.use_elb = use_elb;
+        this.serverNum = serverNum;
+        this.train_batch_num = train_batch_num;
+        this.test_dataset = test_dataset;
+        this.test_batch_num = test_batch_num;
+
+    }
+
+
+    /**
+     * Copies the FL plan (iterations, epochs, mini batch, secret-sharing params) from a startFLJob
+     * response. When the response carries no plan, hard-coded test values are used instead.
+     */
+    public void setGlobalParameters(ResponseFLJob flJob) {
+        FLPlan flPlan = flJob.flPlanConfig();
+        if (flPlan != null) {
+            iterations = flPlan.iterations();
+            epochs = flPlan.epochs();
+            batchSize = flPlan.miniBatch();
+            CipherPublicParams cipherPublicParams = flPlan.cipher();
+            if (cipherPublicParams == null) {
+                // no cipher table in the plan: keep the current values instead of failing with an NPE
+                System.out.println("[Encrypt] no cipher params in the FL plan, keep minSecretNum and prime");
+                return;
+            }
+            minSecretNum = cipherPublicParams.t();
+            int primeLength = cipherPublicParams.primeLength();
+            prime = new byte[primeLength];
+            for (int i = 0; i < primeLength; i++) {
+                prime[i] = (byte) cipherPublicParams.prime(i);
+            }
+            System.out.println("[Encrypt] the minSecretNum from server: " + minSecretNum);
+            System.out.println("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime));
+        } else {  // todo for test
+            iterations = 5;
+            epochs = 20;
+            batchSize = 32;
+            minSecretNum = 3;
+            String prime_s = "238586b57e1179feb154e90ace3e1886a36714ca789702ba31fda502b2c4eab4c7";
+            byte[] prime_b = BaseUtil.hexString2ByteArray(prime_s);
+            prime = prime_b;
+            featureSize = 66404;
+        }
+    }
+
+    public int getIteration() {
+        return iteration;
+    }
+
+    public int getIterations() {
+        return iterations;
+    }
+
+    public int getEpochs() {
+        return epochs;
+    }
+
+    public int getBatchSize() {
+        return batchSize;
+    }
+
+    /** @return the status recorded by the last request/step, or null before any was recorded. */
+    public FLClientStatus checkStatus() {
+        return this.status;
+    }
+
+    /**
+     * Runs one blocking round (startFLJob, train + updateModel, getModel), ignoring ifAsync.
+     *
+     * @param epoch     passed to {@link LiteTrain#train} as its iterations argument
+     * @param batch_num passed to {@link LiteTrain#train} as its batch_num argument
+     * @throws RuntimeException if a request fails or the server rejects one of the steps
+     */
+    public void syncFLJob(int epoch, int batch_num) {
+        String url = Common.generateUrl(use_elb, ip, port, serverNum);
+        // 1. verify server
+        {
+            System.out.println("[startFLJob] ========Verify server========");
+            StartFLJob startFLJob = StartFLJob.getInstance();
+            byte[] msg = startFLJob.getRequestStartFLJob(fl_name, fl_id, iteration);
+            byte[] message = syncRequestOrThrow(url + "/startFLJob", msg);
+            ResponseFLJob flJob = ResponseFLJob.getRootAsResponseFLJob(ByteBuffer.wrap(message));
+            iteration = flJob.iteration();
+            FLClientStatus stepStatus = startFLJob.doResponse(fl_name, flJob);
+            requireSuccess("startFLJob", stepStatus);
+        }
+
+        // 2. update model, and push features map to server
+        {
+            System.out.println("[Train] ===========Training===========");
+            LiteTrain train = LiteTrain.getInstance();
+            train.train(fl_name, batch_num, epoch);
+            UpdateModel updateModel = UpdateModel.getInstance();
+            byte[] updateModelBuffer = updateModel.getRequestUpdateFLJob(encryptLevel, fl_name, fl_id, iteration, secureProtocol);
+            byte[] message = syncRequestOrThrow(url + "/updateModel", updateModelBuffer);
+            ResponseUpdateModel response = ResponseUpdateModel.getRootAsResponseUpdateModel(ByteBuffer.wrap(message));
+            requireSuccess("updateModel", updateModel.doResponse(response));
+        }
+
+        // 3. update featuresMap
+        {
+            System.out.println("[getModel] ===========getModel=============");
+            GetModel getModel = GetModel.getInstance();
+            // ask for the current round (as getModel() does), not the plan's total `iterations`
+            byte[] buffer = getModel.getRequestGetModel(fl_name, iteration);
+            byte[] message = syncRequestOrThrow(url + "/getModel", buffer);
+            ResponseGetModel rgm = ResponseGetModel.getRootAsResponseGetModel(ByteBuffer.wrap(message));
+            requireSuccess("getModel", getModel.doResponse(fl_name, rgm));
+        }
+    }
+
+    /** Blocking request to {@code url}; an IOException is rethrown as a RuntimeException with the cause attached. */
+    private byte[] syncRequestOrThrow(String url, byte[] body) {
+        try {
+            return flCommunication.syncRequest(url, body);
+        } catch (IOException e) {
+            throw new RuntimeException("[syncFLJob] request to " + url + " failed", e);
+        }
+    }
+
+    /** Throws a RuntimeException naming {@code step} unless {@code stepStatus} is SUCCESS. */
+    private static void requireSuccess(String step, FLClientStatus stepStatus) {
+        if (stepStatus != FLClientStatus.SUCCESS) {
+            throw new RuntimeException("[syncFLJob] " + step + " failed with status " + stepStatus);
+        }
+    }
+
+    /**
+     * Sends startFLJob. Sync mode evaluates the answer with {@link #judgeStartFLJob}; async mode does
+     * the same in the callback and, on success, continues with {@link #updateModel()}.
+     *
+     * @return the sync result, or in async mode whether the request could be issued
+     */
+    public FLClientStatus startFLJob() {
+        System.out.println("[startFLJob] ====================================Verify server====================================");
+        String url = Common.generateUrl(use_elb, ip, port, serverNum);
+        StartFLJob startFLJob = StartFLJob.getInstance();
+        byte[] msg = startFLJob.getRequestStartFLJob(fl_name, fl_id, iteration);
+        if (ifAsync) {
+            try {
+                flCommunication.asyncRequest(url + "/startFLJob", msg, new IAsyncCallBack() {
+
+                    @Override
+                    public FLClientStatus onResponse(byte[] msg) {
+                        ByteBuffer buffer = ByteBuffer.wrap(msg);
+                        ResponseFLJob flJob = ResponseFLJob.getRootAsResponseFLJob(buffer);
+                        // same handling as the sync path: fl_name-aware doResponse, iteration, feature size, plan
+                        FLClientStatus status = judgeStartFLJob(startFLJob, flJob);
+                        FLLiteClient.this.status = status;
+                        if (status != FLClientStatus.SUCCESS) {
+                            return status;
+                        }
+                        status = updateModel();
+                        if (status != FLClientStatus.SUCCESS) {
+                            throw new RuntimeException("[startFLJob] updateModel failed with status " + status);
+                        }
+                        return FLClientStatus.SUCCESS;
+                    }
+
+                    @Override
+                    public FLClientStatus onFailure(IOException exception) {
+                        exception.printStackTrace();
+                        return FLClientStatus.FAILED;
+                    }
+                });
+            } catch (Exception e) {
+                // the request was never issued, so do not report SUCCESS (getModel() behaves the same way)
+                e.printStackTrace();
+                status = FLClientStatus.FAILED;
+                return status;
+            }
+        } else {
+            try {
+                byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
+                ByteBuffer buffer = ByteBuffer.wrap(message);
+                ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(buffer);
+                status = judgeStartFLJob(startFLJob, responseDataBuf);
+                return status;
+
+            } catch (IOException e) {
+                System.out.println("[startFLJob] un sloved error code in StartFLJob");
+                e.printStackTrace();
+                status = FLClientStatus.FAILED;
+                return status;
+            }
+        }
+        return FLClientStatus.SUCCESS;
+    }
+
+    /**
+     * Applies a startFLJob response: stores the server iteration and hands the response to
+     * {@link StartFLJob#doResponse(String, ResponseFLJob)}; on SUCCESS also records the feature size and the FL plan.
+     *
+     * @return SUCCESS, or FAILED for any other outcome (WAIT is not retried yet)
+     */
+    public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
+        iteration = responseDataBuf.iteration();
+        FLClientStatus response = startFLJob.doResponse(fl_name,responseDataBuf);
+        if (response == FLClientStatus.SUCCESS) {
+            featureSize = startFLJob.getFeatureSize();    // todo get feature size
+            System.out.println("[startFLJob] ***the feature size get in ResponseFLJob***: "+featureSize);
+            setGlobalParameters(responseDataBuf);
+            return FLClientStatus.SUCCESS;
+        } else if (response == FLClientStatus.WAIT) {
+//            long waitTime = getWaitTime(responseDataBuf.nextReqTime());
+//            sleep(waitTime);
+//            curStatus = startFLJob();
+//            return curStatus;
+            return FLClientStatus.FAILED;
+        } else {
+            return FLClientStatus.FAILED;
+        }
+
+    }
+
+    /**
+     * Continues a round after startFLJob: for PWEncrypt the pairwise feature masks are created first,
+     * then {@link #updateModel()} trains and uploads the model.
+     *
+     * @param response result of the preceding startFLJob
+     * @return result of updateModel, or FAILED if startFLJob or mask creation failed
+     */
+    public FLClientStatus startUpdateModel(FLClientStatus response) {
+        FLClientStatus curStatus;
+        if (response == FLClientStatus.SUCCESS) {
+            System.out.println("[startFLJob] startFLJob succeed");
+            System.out.println("[train] ========================================================global train epoch "+iteration+"========================================================");
+
+            // todo zs : need add: Synchronously start the <encryption part> and <train process>.
+            if (encryptLevel == EncryptLevel.PWEncrypt) {
+                curStatus = getFeatureMask();
+                if (curStatus != FLClientStatus.SUCCESS) {
+                    return FLClientStatus.FAILED;
+                }
+                System.out.println("[Encrypt] create mask for <" + encryptLevel.toString() + ">" + " ok!");
+            } else if (encryptLevel == EncryptLevel.DPEncrypt) {
+                // TODO jxl set parameters
+                System.out.println("[Encrypt] set parameters for DPEncrypt!");
+            } else if (encryptLevel == EncryptLevel.NotEncrypt){
+                System.out.println("[Encrypt] don't mask model");
+            } else {
+                System.out.println("[Encrypt] The encrypt level is error!");
+            }
+
+            curStatus = updateModel();
+            return curStatus;
+        } else {
+            System.out.println("[startFLJob] startFLJob failed");
+            return FLClientStatus.FAILED;
+        }
+    }
+
+
+    /**
+     * Trains fl_name locally (train_batch_num as batch_num, epochs * train_batch_num as the iterations
+     * argument of {@link LiteTrain#train}) and uploads the result. Sync mode continues with
+     * {@link #judgeUpdateModel}; async mode calls {@link #getModel()} from the callback.
+     *
+     * @return the sync result, or in async mode whether the request could be issued
+     */
+    public FLClientStatus updateModel() {
+        String url = Common.generateUrl(use_elb, ip, port, serverNum);
+        System.out.println("[updateModel] ==============updateModel url: "+url+"==============");
+        LiteTrain train = LiteTrain.getInstance();
+        train.train(fl_name,train_batch_num, epochs * train_batch_num);
+        UpdateModel updateModelBuf = UpdateModel.getInstance();
+        byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(encryptLevel, fl_name, fl_id, iteration, secureProtocol);
+        if (ifAsync) {
+            try {
+                flCommunication.asyncRequest(url + "/updateModel", updateModelBuffer, new IAsyncCallBack() {
+                    @Override
+                    public FLClientStatus onResponse(byte[] msg) {
+                        ByteBuffer debugBuffer = ByteBuffer.wrap(msg);
+                        ResponseUpdateModel updateModel = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer);
+                        // NOTE(review): unlike the sync path this skips judgeUpdateModel (no retcode check, no unmasking)
+                        //Debug code
+                        System.out.println("[updateModel] ==========update model content is:================");
+                        System.out.println("[updateModel] ==========retcode: " + updateModel.retcode());
+                        System.out.println("[updateModel] ==========reason: " + updateModel.reason());
+                        System.out.println("[updateModel] ==========time: " + updateModel.timestemp());
+                        FLClientStatus status = getModel();
+                        if (status != FLClientStatus.SUCCESS) {
+                            throw new RuntimeException("[updateModel] getModel failed with status " + status);
+                        }
+                        return FLClientStatus.SUCCESS;
+                    }
+
+                    @Override
+                    public FLClientStatus onFailure(IOException exception) {
+                        exception.printStackTrace();
+                        return FLClientStatus.FAILED;
+                    }
+                });
+            } catch (Exception e) {
+                // the request was never issued, so do not report SUCCESS (getModel() behaves the same way)
+                e.printStackTrace();
+                status = FLClientStatus.FAILED;
+                return status;
+            }
+        } else {
+            try {
+                byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
+                ByteBuffer debugBuffer = ByteBuffer.wrap(message);
+                ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer);
+                status = judgeUpdateModel(updateModelBuf, responseDataBuf);
+                return status;
+
+            } catch (IOException e) {
+                System.out.println("[updateModel] un sloved error code in updateModel");
+                e.printStackTrace();
+                status = FLClientStatus.FAILED;
+                return status;
+            }
+        }
+        return FLClientStatus.SUCCESS;
+    }
+
+    /**
+     * Applies an updateModel response: on SUCCESS runs pairwise unmasking for PWEncrypt and then
+     * calls {@link #getModel()}.
+     *
+     * @return result of getModel, or FAILED otherwise (WAIT is not retried yet)
+     */
+    public FLClientStatus judgeUpdateModel(UpdateModel updateModelBuf, ResponseUpdateModel responseDataBuf) {
+        FLClientStatus response = updateModelBuf.doResponse(responseDataBuf);
+        FLClientStatus curStatus;
+        System.out.println("[updateModel] response updateModel ok!");
+        if (response == FLClientStatus.SUCCESS) {
+
+            if (encryptLevel == EncryptLevel.PWEncrypt) {
+                curStatus = unMasking();
+                if (curStatus != FLClientStatus.SUCCESS) {
+                    return FLClientStatus.FAILED;
+                }
+                System.out.println("[Encrypt] pairwise unmasking ok!");
+            } else if (encryptLevel == EncryptLevel.DPEncrypt) {
+                System.out.println("[Encrypt] DPEncrypt don't need unmasking!");
+            } else if (encryptLevel == EncryptLevel.NotEncrypt){
+                System.out.println("[Encrypt] don't mask model");
+            } else {
+                System.out.println("[Encrypt] The encrypt level is error!");
+            }
+
+            curStatus = getModel();
+            return curStatus;
+        } else if (response == FLClientStatus.WAIT) {
+//            long waitTime = getWaitTime(responseDataBuf.nextReqTime());
+//            sleep(waitTime);
+//            FLClientStatus curResponse = startFLJob();
+//            curStatus = startUpdateModel(curResponse);
+//            return curStatus;
+            return FLClientStatus.FAILED;
+        } else {
+            System.out.println("[updateModel] updateModel failed");
+            return FLClientStatus.FAILED;
+        }
+
+    }
+
+
+
+    /**
+     * Requests the model of the current iteration. Sync mode retries SucNotReady answers every 200 ms
+     * in a loop and hands the final answer to {@link #judgeGetModel}; async mode loads the returned
+     * features into the fl_name model from the callback.
+     *
+     * @return the sync result, or in async mode whether the request could be issued
+     */
+    public FLClientStatus getModel() {
+        String url = Common.generateUrl(use_elb, ip, port, serverNum);
+        System.out.println("[getModel] ===========getModel url: "+url+"==============");
+        GetModel getModelBuf = GetModel.getInstance();
+        byte[] buffer = getModelBuf.getRequestGetModel(fl_name, iteration);
+        if (ifAsync) {
+            try {
+
+                flCommunication.asyncRequest(url + "/getModel", buffer, new IAsyncCallBack() {
+                    @Override
+                    public FLClientStatus onResponse(byte[] msg) {
+                        ByteBuffer debugBuffer = ByteBuffer.wrap(msg);
+                        ResponseGetModel rgm = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
+                        // fl_name-aware overload, as on the sync path; propagate its status
+                        return getModelBuf.doResponse(fl_name, rgm);
+                    }
+
+                    @Override
+                    public FLClientStatus onFailure(IOException exception) {
+                        exception.printStackTrace();
+                        return FLClientStatus.FAILED;
+                    }
+                });
+            } catch (Exception e) {
+                e.printStackTrace();
+                return FLClientStatus.FAILED;
+            }
+        } else {
+            try {
+                while (true) {
+                    byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
+                    System.out.println("[getModel] get model request success");
+                    ByteBuffer debugBuffer = ByteBuffer.wrap(message);
+                    ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
+                    if (responseDataBuf.retcode() != ResponseCode.SucNotReady) {
+                        status = judgeGetModel(getModelBuf, responseDataBuf);
+                        return status;
+                    }
+                    // retry in a loop rather than recursing getModel -> judgeGetModel -> getModel
+                    sleep(200);
+                    if (Thread.currentThread().isInterrupted()) {
+                        status = FLClientStatus.FAILED;
+                        return status;
+                    }
+                    buffer = getModelBuf.getRequestGetModel(fl_name, iteration);
+                }
+            } catch (IOException e) {
+                System.out.println("[getModel] un sloved error code in getModel");
+                e.printStackTrace();
+                status = FLClientStatus.FAILED;
+                return FLClientStatus.FAILED;
+            }
+
+        }
+        return FLClientStatus.SUCCESS;
+    }
+
+    /**
+     * Handles a getModel response: SUCCEED loads the features into fl_name and evaluates the combined
+     * model on the test set, SucNotReady waits 200 ms and calls {@link #getModel()} again.
+     *
+     * @return SUCCESS, the retried getModel() result for SucNotReady, or FAILED otherwise
+     */
+    public FLClientStatus judgeGetModel(GetModel getModelBuf, ResponseGetModel responseDataBuf) {
+        //Debug code
+        System.out.println("[getModel] ==========get model content is:================");
+        System.out.println("[getModel] ==========retcode: " + responseDataBuf.retcode());
+        System.out.println("[getModel] ==========reason: " + responseDataBuf.reason());
+        System.out.println("[getModel] ==========iteration: " + responseDataBuf.iteration());
+        System.out.println("[getModel] ==========time: " + responseDataBuf.timestemp());
+        int retcode = responseDataBuf.retcode();
+        FLClientStatus curStatus;
+        switch (retcode) {
+            case (ResponseCode.SUCCEED):
+                FLClientStatus status = getModelBuf.doResponse(fl_name,responseDataBuf);
+                if (status != FLClientStatus.SUCCESS) {
+                    System.out.println("[getModel] catch error in getModel.doResponse");
+                    return FLClientStatus.FAILED;
+                }
+                System.out.println("[test] ==============test combine model==============");
+                testCombineModel();
+                return FLClientStatus.SUCCESS;
+            case (ResponseCode.SucNotReady):
+                // todo, server need add next_req_time in ResponseGetModel
+                sleep(200);
+                curStatus = getModel();
+                return curStatus;
+            case (ResponseCode.RequestError):
+            case (ResponseCode.SystemError):
+                System.out.println("[getModel] catch RequestError or SystemError");
+                return FLClientStatus.FAILED;
+            default:
+                return FLClientStatus.FAILED;
+
+        }
+    }
+
+    /**
+     * Creates the pairwise (PWEncrypt) feature masks through secureProtocol.
+     *
+     * @return SUCCESS, or FAILED if mask creation failed or the server answered WAIT
+     */
+    public FLClientStatus getFeatureMask() {
+        System.out.println("[Encrypt] creating feature mask of <" + encryptLevel.toString() + ">");
+        secureProtocol.setPWParameter(fl_id, iteration, ip, port, minSecretNum, prime, featureSize, ifAsync, use_elb, serverNum);
+        secureProtocol.pwCreateMask();
+        FLClientStatus response = secureProtocol.getStatus();
+        if (response == FLClientStatus.WAIT) {
+//            System.out.println("Create feature mask OutOfTime, need wait and request again");
+//            long waitTime = getWaitTime(secureProtocol.getNextRequestTime());
+//            sleep(waitTime);
+//            FLClientStatus curResponse = startFLJob();
+//            curStatus = startUpdateModel(curResponse);
+//            status = curStatus;
+//            return curStatus;
+            return FLClientStatus.FAILED;
+        } else if (response == FLClientStatus.SUCCESS) {
+            System.out.println("[Encrypt] Create feature mask succeed");
+            status = FLClientStatus.SUCCESS;
+            return FLClientStatus.SUCCESS;
+        } else {
+            System.out.println("[Encrypt] Create feature mask failed");
+            status = FLClientStatus.FAILED;
+            return FLClientStatus.FAILED;
+        }
+    }
+
+    /**
+     * Runs pairwise (PWEncrypt) unmasking through secureProtocol.
+     *
+     * @return SUCCESS, or FAILED if unmasking failed or the server answered WAIT
+     */
+    public FLClientStatus unMasking() {
+        secureProtocol.pwUnmasking();
+        FLClientStatus response = secureProtocol.getStatus();
+        if (response == FLClientStatus.WAIT) {
+//            System.out.println("unmasking OutOfTime, need wait and request again");
+//            long waitTime = getWaitTime(secureProtocol.getNextRequestTime());
+//            sleep(waitTime);
+//            FLClientStatus curResponse = startFLJob();
+//            curStatus = startUpdateModel(curResponse);
+//            status = curStatus;
+//            return curStatus;
+            return FLClientStatus.FAILED;
+        } else if (response == FLClientStatus.SUCCESS) {
+            System.out.println("[Encrypt] unmasking succeed");
+            status = FLClientStatus.SUCCESS;
+            return FLClientStatus.SUCCESS;
+        } else {
+            System.out.println("[Encrypt] unmasking failed");
+            status = FLClientStatus.FAILED;
+            return FLClientStatus.FAILED;
+        }
+    }
+
+
+    /** Sleeps for {@code millis} ms; if interrupted, returns early and restores the thread's interrupt flag. */
+    public void sleep(long millis) {
+        try {
+            Thread.sleep(millis);                 //1000 milliseconds is one second.
+        } catch(InterruptedException ex) {
+            Thread.currentThread().interrupt();
+        }
+    }
+
+    /**
+     * @param nextRequestTime point in time in ms, comparable to {@code Date#getTime()}, as a decimal string
+     * @return milliseconds until nextRequestTime, clamped to 0 so the result is always valid for {@link #sleep}
+     * @throws NumberFormatException if nextRequestTime is not a decimal long
+     */
+    public long getWaitTime(String nextRequestTime) {
+        long currentTime = new Date().getTime();
+        long waitTime = Long.parseLong(nextRequestTime) - currentTime;
+        return Math.max(0L, waitTime);
+    }
+
+    /** Feeds test_dataset (test_batch_num batches) to LiteTrain and runs {@link #evaluate()} on it. */
+    public void testCombineModel(){
+        setInput(test_dataset, test_batch_num);
+        evaluate();
+    }
+
+
+    /**
+     * init runtime resource only needs to be called once per client (currently a no-op)
+     */
+    public void initRuntimeResource() {
+//        Train train = Train.getInstance();
+//        train.prepare();
+    }
+
+    /**
+     * Feeds a dataset to the LiteTrain runtime.
+     *
+     * @param dataset   train or test dataset and label set
+     * @param batch_num batch num of the dataset file, forwarded to {@link LiteTrain#setInput}
+     */
+    public void setInput(String dataset, int batch_num) {
+        System.out.println("==========set input===========");
+        LiteTrain train = LiteTrain.getInstance();
+        train.setInput(dataset, batch_num);
+    }
+
+    /**
+     * Evaluates the trained model's performance with {@link LiteTrain#inference}.
+     *
+     * @throws RuntimeException if inference does not report SUCCESS
+     */
+    public void evaluate() {
+        System.out.println("===========evaluate=============");
+        LiteTrain train = LiteTrain.getInstance();
+        // NOTE(review): batch_num is hard-coded to 32 instead of coming from batchSize/test_batch_num
+        int status = train.inference(fl_name,32,1);
+        if (status != FLClientStatus.SUCCESS.ordinal()) {
+            System.out.println("inference failed");
+            throw new RuntimeException("inference failed with native status " + status);
+        }
+    }
+
+    /**
+     * Frees the LiteTrain runtime when this client is garbage collected.
+     * NOTE(review): LiteTrain is a process-wide singleton, so this also frees it for any other
+     * FLLiteClient still in use; finalize() is deprecated since Java 9, an explicit close() would be safer.
+     */
+    @Override
+    protected void finalize() {
+        LiteTrain train = LiteTrain.getInstance();
+        train.free();
+    }
+}
diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/GetModel.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/GetModel.java
index 872d7de..eb19c50 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/GetModel.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/GetModel.java
@@ -78,4 +78,16 @@ public class GetModel {
         train.updateFeatures(fms);
         return FLClientStatus.SUCCESS;
     }
+    /** Loads the feature maps of a getModel response into the LiteTrain model {@code modelName}. */
+    public FLClientStatus doResponse(String modelName,ResponseGetModel getModel) {
+        int num = getModel.featureMapLength();
+        ArrayList<FeatureMap> fms = new ArrayList<FeatureMap>();
+        for (int i = 0; i < num; i++) {
+            FeatureMap feature = getModel.featureMap(i);
+            fms.add(feature);
+            System.out.println("get [" + i + "] " + feature.weightFullname() + ", elenums: " + feature.dataLength());
+        }
+        int ret = LiteTrain.getInstance().updateFeatures(modelName, fms);  // native ok == SUCCESS.ordinal(), as FLLiteClient.evaluate() assumes
+        return ret == FLClientStatus.SUCCESS.ordinal() ? FLClientStatus.SUCCESS : FLClientStatus.FAILED;
+    }
 }
diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
new file mode 100644
index 0000000..877ca40
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/LiteTrain.java
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.flclient;
+
+import com.google.flatbuffers.FlatBufferBuilder;
+import mindspore.schema.FeatureMap;
+import java.util.ArrayList;
+
+public  class LiteTrain {
+    static {    // load the native library "fl" that backs the native methods below
+        System.loadLibrary("fl");
+    }
+    // lazily created process-wide instance; access goes through the synchronized getInstance()
+    private static LiteTrain train;
+
+    private LiteTrain() {    // use getInstance()
+    }
+    public static synchronized LiteTrain getInstance() {
+        if (train == null) {
+            train = new LiteTrain();
+        }
+        return train;
+    }
+    /**
+     * Set the inference set or train set.
+     *
+     * @param fileSet input binary file path which format is NHWC
+     * @param num     binary file batch num
+     * @return status
+     */
+    native int setInput(String fileSet, int num);
+
+    /**
+     * Run inference on model {@code modelName} (see lenet_train_jni.cpp for batch_num/test_nums).
+     * @param modelName name of the model, forwarded to the native layer
+     * @return status
+     */
+    public native int inference(String modelName,int batch_num,int test_nums);
+
+    /**
+     * Train model {@code modelName} (see lenet_train_jni.cpp for batch_num/iterations).
+     * @param modelName name of the model, forwarded to the native layer
+     * @return status
+     */
+    public native int train(String modelName, int batch_num,int iterations);
+
+    /**
+     * get the features map of training model
+     * @param modelName name of the model, forwarded to the native layer
+     * @param builder   FlatBufferBuilder
+     * @return features offset
+     */
+    native int[] getFeaturesMap(String modelName,FlatBufferBuilder builder);
+
+    /**
+     * update the features map of training model
+     * @param modelName   name of the model, forwarded to the native layer
+     * @param featureMaps feature maps to load into the model
+     * @return status
+     */
+    native int updateFeatures(String modelName,ArrayList<FeatureMap> featureMaps);
+
+    /**
+     * Free inference or train runtime memory resource.
+     *
+     * @return status
+     */
+    native int free();
+}
diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/StartFLJob.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/StartFLJob.java
index 55641f1..3db5e42 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/StartFLJob.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/StartFLJob.java
@@ -105,6 +105,21 @@ public class StartFLJob {
         return FLClientStatus.SUCCESS;
     }
 
+    /**
+     * Load the feature maps carried by a successful startFLJob response into the local model.
+     *
+     * @param modelName path of the model file whose weights are overwritten
+     * @param flJob     server response holding the feature maps
+     * @return SUCCESS if the native update succeeded, FAILED otherwise
+     */
+    private FLClientStatus parseResponse(String modelName, ResponseFLJob flJob) {
+        int fmCount = flJob.featureMapLength();
+        ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
+        for (int i = 0; i < fmCount; i++) {
+            FeatureMap feature = flJob.featureMap(i);
+            featureMaps.add(feature);
+            System.out.println("get [" + i + "] " + feature.weightFullname() + ", elenums: " + feature.dataLength());
+            featureSize += feature.dataLength();
+        }
+        LiteTrain train = LiteTrain.getInstance();
+        int status = train.updateFeatures(modelName, featureMaps);       // load featureMaps to model
+        if (status != 0) {
+            System.out.println("[startFLJob] update model feature maps failed, status: " + status);
+            return FLClientStatus.FAILED;
+        }
+        return FLClientStatus.SUCCESS;
+    }
     public FLClientStatus doResponse(ResponseFLJob flJob) {
         System.out.println("[startFLJob] return code: " + flJob.retcode());
         System.out.println("[startFLJob] reason: " + flJob.reason());
@@ -133,6 +148,34 @@ public class StartFLJob {
         }
     }
 
+    public FLClientStatus doResponse(String modelName,ResponseFLJob flJob) {
+        System.out.println("[startFLJob] return code: " + flJob.retcode());
+        System.out.println("[startFLJob] reason: " + flJob.reason());
+        System.out.println("[startFLJob] iteration: " + flJob.iteration());
+        System.out.println("[startFLJob] is selected: " + flJob.isSelected());
+        System.out.println("[startFLJob] next request time: " + flJob.nextReqTime());
+
+        // skip mind ir
+        System.out.println("[startFLJob] timestamp: " + flJob.timestemp());
+        int retcode = flJob.retcode();
+
+        switch (retcode) {
+            case (ResponseCode.SUCCEED):
+                parseResponse(modelName,flJob);
+                return FLClientStatus.SUCCESS;
+            case (ResponseCode.OutOfTime):
+//            case (ResponseCode.NotSelected):     // todo: need add to fl_job.fbs
+                return FLClientStatus.WAIT;
+            case (ResponseCode.SucNotMatch):
+            case (ResponseCode.SucNotReady):
+                System.out.println("[startFLJob] catch RequestError or SystemError");
+                return FLClientStatus.FAILED;
+            default:
+                System.out.println("[startFLJob] nresolved error code");
+                return FLClientStatus.FAILED;
+        }
+    }
+
+    /**
+     * @return the value of this job's {@code status} field
+     */
     public FLClientStatus getStatus() {
         return this.status;
     }
diff --git a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/UpdateModel.java b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/UpdateModel.java
index 0b886c7..d3e3a36 100644
--- a/mindspore/lite/flclient/src/main/java/com/huawei/flclient/UpdateModel.java
+++ b/mindspore/lite/flclient/src/main/java/com/huawei/flclient/UpdateModel.java
@@ -66,9 +66,16 @@ public class UpdateModel {
             return this;
         }
 
-        public RequestUpdateModelBuilder featuresmap() {
-            Train train = Train.getInstance();
-            int[] fmOffsets = train.getFeaturesMap(this.builder);
+        /**
+         * Serialize the feature maps of the local model into this request.
+         *
+         * @param modelName path of the model file passed to LiteTrain#getFeaturesMap
+         * @return this builder
+         */
+        public RequestUpdateModelBuilder featuresmap(String modelName) {
+            LiteTrain train = LiteTrain.getInstance();
+            int[] fmOffsets = train.getFeaturesMap(modelName, this.builder);
             this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
             return this;
         }
@@ -118,7 +125,7 @@ public class UpdateModel {
 
+    /**
+     * Build the updateModel request bytes for the given FL job.
+     *
+     * NOTE(review): {@code name} is used both as the FL job name and as the model file passed to
+     * featuresmap(); confirm that the FL name really is the path of the local model.
+     */
     public byte[] getRequestUpdateFLJob(EncryptLevel encryptLevel, String name, String id, int iteration, SecureProtocol secureProtocol) {
         RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(encryptLevel);
-        return builder.flName(name).time().id(id).featuresmap().iteration(iteration).build(secureProtocol);
+        return builder.flName(name).time().id(id).featuresmap(name).iteration(iteration).build(secureProtocol);
     }
 
     public FLClientStatus doResponse(ResponseUpdateModel response) {
diff --git a/mindspore/lite/flclient/src/main/native/CMakeLists.txt b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
index 80f9931..4a3938c 100644
--- a/mindspore/lite/flclient/src/main/native/CMakeLists.txt
+++ b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
@@ -2,14 +2,33 @@ cmake_minimum_required(VERSION 3.14)
 
 project(FederalLearning)
 
+option(SUPPORT_GPU "if support gpu" off)
+set(BUILD_LITE "on")
+set(SUPPORT_TRAIN "on")
+set(PLATFORM_ARM "on")
+
+set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../)
+set(LITE_DIR ${TOP_DIR}/mindspore/lite)
+set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR})
+set(MS_VERSION_MINOR ${MS_VERSION_MINOR})
+set(MS_VERSION_REVISION ${MS_VERSION_REVISION})
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/linux)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
+if (ENABLE_MICRO)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/runtime)
+endif()
+include_directories(${LITE_DIR}) ## lite include
+include_directories(${TOP_DIR}) ## api include
+include_directories(${TOP_DIR}/mindspore/core/) ## core include
+include_directories(${LITE_DIR}/build) ## flatbuffers
 
+if (ENABLE_MICRO)
 set(OP_SRC
     src/nnacl/arithmetic_common.c
     src/nnacl/common_func.c
@@ -43,11 +62,23 @@ set(OP_SRC
     src/fl_lenet.c
     src/weight_files/fl_lenet_weight_epoch_0.c
 )
-
+else()
+set(OP_SRC
+    src/lenet_train.cpp
+)
+endif()
+
+if (ENABLE_MICRO)
+set(SRC_FILES
+    flearning.cpp
+)
+else()
 set(SRC_FILES
-  flearning.cpp
-)
+    lenet_train_jni.cpp
+)
+endif()
 
+# link_directories() only applies to targets created after it is called, so it must stay
+# in front of add_library() for mindspore-lite to be found in lib/.
 link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib/)
 add_library(fl SHARED ${SRC_FILES} ${OP_SRC})
+target_link_libraries(fl mindspore-lite glog)
+
 install(TARGETS fl LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
\ No newline at end of file
diff --git a/mindspore/lite/flclient/src/main/native/include/lenet_train.h b/mindspore/lite/flclient/src/main/native/include/lenet_train.h
new file mode 100644
index 0000000..5108e65
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/include/lenet_train.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MSLITE_FL_LITE_LENET_H
+#define MSLITE_FL_LITE_LENET_H
+
+#include <string>
+#include "include/train_session.h"
+
+// NOTE(review): this using-declaration is visible to every file that includes this header.
+using mindspore::session::TrainFeatureParam;
+
+/// \brief Train the model in ms_file on the buffers loaded by fl_lenet_lite_SetInputs, then save it back to ms_file.
+/// \param[in] ms_file path of the model file
+/// \param[in] batch_num number of samples in the loaded buffers; samples are picked as a random index modulo this value
+/// \param[in] iterations number of training steps; values <= 0 are rejected
+/// \return mindspore::lite status code
+int fl_lenet_lite_Train(const std::string &ms_file,const int batch_num, const int iterations);
+
+/// \brief Evaluate the model in ms_file on test_nums batches taken sequentially (modulo batch_num) from the loaded
+///        buffers. The accuracy is only logged; the return value is a mindspore::lite status code.
+int fl_lenet_lite_Inference(const std::string &ms_file,int batch_num,int test_nums);
+
+/// \brief Read the feature maps (weights) of the model in update_ms_file.
+/// \param[out] features receives a new[]-allocated array of *size heap-allocated entries
+/// \param[out] size number of entries
+/// NOTE(review): the caller releases the array, the entries and their names; confirm who owns each `data` buffer.
+int fl_lenet_lite_GetFeatures(const std::string &update_ms_file,mindspore::session::TrainFeatureParam *** features,int* size);
+/// \brief Write size entries of new_features into the model in update_ms_file via TrainSession::UpdateFeatureMaps.
+int fl_lenet_lite_UpdateFeatures(const std::string &update_ms_file,
+                                 TrainFeatureParam* new_features,int size);
+/// \brief Create a single-threaded TrainSession for ms_file. The caller owns (and deletes) the returned session;
+///        presumably nullptr when creation fails — confirm against TrainSession::CreateSession.
+mindspore::session::TrainSession * GetSession(const std::string& ms_file,bool train_mode=false);
+
+/// \brief Load "data_file,label_file" into the process-wide input buffers used by training and inference.
+/// \param[in] num currently unused
+/// \return 0 on success, a negative value otherwise
+int fl_lenet_lite_SetInputs(const std::string &files, int num);
+#endif  // MSLITE_FL_LITE_LENET_H
diff --git a/mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp b/mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp
new file mode 100644
index 0000000..617c8cb
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/lenet_train_jni.cpp
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <jni.h>
+#include <cstring>
+#include <memory>
+#include <new>
+#include <string>
+#include <vector>
+#include "include/train_session.h"
+#include "include/errorcode.h"
+#include "lenet_train.h"
+#include "src/common/log_adapter.h"
+
+// FlatBufferBuilder of the getFeaturesMap call in progress; a JNI local reference, valid only during that call.
+static jobject fbb = nullptr;
+// Method id of FlatBufferBuilder.createString(CharSequence), resolved at the start of getFeaturesMap.
+static jmethodID create_string_char = nullptr;
+
+/**
+ * Copy a Java string into a newly allocated, NUL-terminated C string.
+ *
+ * @param env  JNI environment of the calling thread
+ * @param jstr Java string to convert; may be null
+ * @return a buffer allocated with new[] (release it with delete[]) holding the (modified) UTF-8 bytes of jstr,
+ *         which is what the native file APIs expect for paths. An empty Java string yields "". Returns nullptr
+ *         if jstr is null or the conversion / allocation failed.
+ */
+char *JstringToChar(JNIEnv *env, jstring jstr) {
+  if (env == nullptr || jstr == nullptr) {
+    return nullptr;
+  }
+  const char *utf = env->GetStringUTFChars(jstr, nullptr);
+  if (utf == nullptr) {  // an OutOfMemoryError is pending in the JVM
+    return nullptr;
+  }
+  const jsize len = env->GetStringUTFLength(jstr);
+  char *rtn = new (std::nothrow) char[len + 1];
+  if (rtn != nullptr) {
+    memcpy(rtn, utf, len);
+    rtn[len] = '\0';
+  }
+  env->ReleaseStringUTFChars(jstr, utf);
+  return rtn;
+}
+// JNI entry of LiteTrain.train: train the model in ms_file and return the native status code.
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_train(JNIEnv *env, jobject thiz, jstring ms_file,
+                                                                           jint batch_num, jint iterations) {
+  // unique_ptr<char[]> releases the new[] buffer from JstringToChar on every path.
+  std::unique_ptr<char[]> model_path(JstringToChar(env, ms_file));
+  if (model_path == nullptr) {
+    MS_LOG(ERROR) << "invalid model file path";
+    return mindspore::lite::RET_ERROR;
+  }
+  return fl_lenet_lite_Train(model_path.get(), batch_num, iterations);
+}
+
+// Serialize one weight into the FlatBufferBuilder `fbb` as a mindspore.schema.FeatureMap table.
+// `data` must hold `size` floats. Returns the offset produced by FeatureMap.createFeatureMap, or 0 on failure.
+extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, size_t size) {
+  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
+  if (fm_cls == nullptr) {
+    MS_LOG(ERROR) << "can not find class mindspore/schema/FeatureMap";
+    return 0;
+  }
+  // Local references are released explicitly: this runs once per feature inside a single native call,
+  // so leaked locals would pile up in the local reference table.
+  jstring j_name = env->NewStringUTF(name != nullptr ? name : "");
+  jint name_offset = env->CallIntMethod(fbb, create_string_char, j_name);
+  env->DeleteLocalRef(j_name);
+  // 1. copy the values into a Java float[]
+  const auto length = static_cast<jsize>(size);
+  jfloatArray j_data = env->NewFloatArray(length);
+  if (j_data == nullptr) {
+    MS_LOG(ERROR) << "create float array failed, size:" << size;
+    env->DeleteLocalRef(fm_cls);
+    return 0;
+  }
+  env->SetFloatArrayRegion(j_data, 0, length, data);
+  // 2. get methodid createDataVector
+  jmethodID create_data_vector =
+    env->GetStaticMethodID(fm_cls, "createDataVector", "(Lcom/google/flatbuffers/FlatBufferBuilder;[F)I");
+  // 3. calc data offset
+  jint data_offset = env->CallStaticIntMethod(fm_cls, create_data_vector, fbb, j_data);
+  env->DeleteLocalRef(j_data);
+  jmethodID create_feature_map =
+    env->GetStaticMethodID(fm_cls, "createFeatureMap", "(Lcom/google/flatbuffers/FlatBufferBuilder;II)I");
+  jint fm_offset = env->CallStaticIntMethod(fm_cls, create_feature_map, fbb, name_offset, data_offset);
+  env->DeleteLocalRef(fm_cls);
+  return fm_offset;
+}
+
+// JNI entry of LiteTrain.getFeaturesMap: write every weight of the model in ms_file into `builder` and return the
+// offsets of the created FeatureMap tables (an empty array on failure).
+extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_LiteTrain_getFeaturesMap(JNIEnv *env, jobject thiz,
+                                                                                         jstring ms_file,
+                                                                                         jobject builder) {
+  // CreateFeatureMap writes into this builder through the cached createString method id.
+  fbb = builder;
+  jclass fb_clazz = env->GetObjectClass(builder);
+  create_string_char = env->GetMethodID(fb_clazz, "createString", "(Ljava/lang/CharSequence;)I");
+  env->DeleteLocalRef(fb_clazz);
+  std::unique_ptr<char[]> model_path(JstringToChar(env, ms_file));
+  if (model_path == nullptr) {
+    MS_LOG(ERROR) << "invalid model file path";
+    return env->NewIntArray(0);
+  }
+  TrainFeatureParam **train_features = nullptr;
+  int feature_size = 0;
+  auto status = fl_lenet_lite_GetFeatures(model_path.get(), &train_features, &feature_size);
+  if (status != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "get features failed:" << model_path.get();
+    return env->NewIntArray(0);
+  }
+  jintArray ret = env->NewIntArray(feature_size);
+  jint *data = (ret == nullptr) ? nullptr : env->GetIntArrayElements(ret, nullptr);
+  for (int i = 0; i < feature_size; i++) {
+    TrainFeatureParam *feature = train_features[i];
+    const char *name = (feature->name != nullptr) ? feature->name : "";
+    if (data != nullptr) {
+      data[i] = CreateFeatureMap(env, name, static_cast<float *>(feature->data), feature->elenums);
+      MS_LOG(INFO) << "upload feature:" << ", name:" << name << ", elenums:" << feature->elenums;
+    }
+    // The entry and its new[] name belong to us once fl_lenet_lite_GetFeatures returns.
+    // NOTE(review): `data` is deliberately not released here. Its ownership depends on the producer
+    // (TrainSession::GetFeatureMaps may point it at session tensor storage); revisit once that contract is settled.
+    delete[] feature->name;
+    delete feature;
+  }
+  if (data != nullptr) {
+    env->ReleaseIntArrayElements(ret, data, 0);
+  }
+  delete[] train_features;
+  return ret;
+}
+
+// JNI entry of LiteTrain.updateFeatures: convert the Java ArrayList<FeatureMap> into TrainFeatureParam entries
+// and write them into the model in ms_file. Every native copy (names and weights) is owned by this function and
+// released when it returns.
+// NOTE(review): this assumes UpdateFeatureMaps copies the values and keeps no pointer to them after returning.
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_updateFeatures(JNIEnv *env, jobject,
+                                                                                    jstring ms_file, jobject features) {
+  std::unique_ptr<char[]> model_path(JstringToChar(env, ms_file));
+  if (model_path == nullptr) {
+    MS_LOG(ERROR) << "invalid model file path";
+    return mindspore::lite::RET_ERROR;
+  }
+  jclass arr_cls = env->GetObjectClass(features);
+  jmethodID size_method = env->GetMethodID(arr_cls, "size", "()I");
+  jmethodID get_method = env->GetMethodID(arr_cls, "get", "(I)Ljava/lang/Object;");
+  env->DeleteLocalRef(arr_cls);
+
+  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
+  if (fm_cls == nullptr) {
+    MS_LOG(ERROR) << "can not find class mindspore/schema/FeatureMap";
+    return mindspore::lite::RET_ERROR;
+  }
+  jmethodID weight_name_method = env->GetMethodID(fm_cls, "weightFullname", "()Ljava/lang/String;");
+  jmethodID data_length_method = env->GetMethodID(fm_cls, "dataLength", "()I");
+  jmethodID data_method = env->GetMethodID(fm_cls, "data", "(I)F");
+  env->DeleteLocalRef(fm_cls);
+
+  const int size = env->CallIntMethod(features, size_method);
+  if (size < 0) {
+    MS_LOG(ERROR) << "invalid feature size:" << size;
+    return mindspore::lite::RET_ERROR;
+  }
+  // transform FeatureMap to TrainFeatureParam; `names` and `weights` own the storage the params point into
+  std::vector<std::string> names(size);
+  std::vector<std::vector<float>> weights(size);
+  std::vector<TrainFeatureParam> params(size);
+  for (int i = 0; i < size; ++i) {
+    jobject feature = env->CallObjectMethod(features, get_method, i);
+    // set feature_param name (UTF-8, the encoding of the model's tensor names); a missing name stays ""
+    auto weight_full_name = static_cast<jstring>(env->CallObjectMethod(feature, weight_name_method));
+    if (weight_full_name != nullptr) {
+      const char *utf_name = env->GetStringUTFChars(weight_full_name, nullptr);
+      if (utf_name == nullptr) {
+        MS_LOG(ERROR) << "name is nullptr";
+        return mindspore::lite::RET_ERROR;
+      }
+      names[i] = utf_name;
+      env->ReleaseStringUTFChars(weight_full_name, utf_name);
+      env->DeleteLocalRef(weight_full_name);
+    }
+    const int data_length = env->CallIntMethod(feature, data_length_method);
+    std::vector<float> &values = weights[i];
+    values.resize(data_length > 0 ? static_cast<size_t>(data_length) : 0);
+    for (size_t j = 0; j < values.size(); ++j) {
+      values[j] = env->CallFloatMethod(feature, data_method, static_cast<jint>(j));
+    }
+    // Release per-iteration locals so large models cannot overflow the local reference table.
+    env->DeleteLocalRef(feature);
+
+    TrainFeatureParam &param = params[i];
+    param.name = &names[i][0];
+    param.data = values.data();
+    param.elenums = values.size();
+    param.type = mindspore::kNumberTypeFloat32;
+    MS_LOG(INFO) << "get feature:" << param.name << ",elenums:" << param.elenums;
+  }
+  return fl_lenet_lite_UpdateFeatures(model_path.get(), params.data(), size);
+}
+
+// JNI entry of LiteTrain.setInput: load the "data_file,label_file" pair given in `files`.
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_setInput(JNIEnv *env, jobject, jstring files,
+                                                                              jint nums) {
+  std::unique_ptr<char[]> file_set(JstringToChar(env, files));
+  if (file_set == nullptr) {
+    MS_LOG(ERROR) << "invalid input file set";
+    return mindspore::lite::RET_ERROR;
+  }
+  return fl_lenet_lite_SetInputs(file_set.get(), nums);
+}
+
+// JNI entry of LiteTrain.inference: evaluate the model in ms_file and return the native status code.
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_inference(JNIEnv *env, jobject, jstring ms_file,
+                                                                               jint batch_num, jint test_nums) {
+  std::unique_ptr<char[]> model_path(JstringToChar(env, ms_file));
+  if (model_path == nullptr) {
+    MS_LOG(ERROR) << "invalid model file path";
+    return mindspore::lite::RET_ERROR;
+  }
+  // fl_lenet_lite_Inference returns a status code; the accuracy itself is only logged on the native side.
+  return fl_lenet_lite_Inference(model_path.get(), batch_num, test_nums);
+}
+
+// JNI entry of LiteTrain.free: currently a no-op that always reports success (0).
+extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_free(JNIEnv *, jobject) {
+  const jint kSuccess = 0;
+  return kSuccess;
+}
\ No newline at end of file
diff --git a/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp b/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
new file mode 100644
index 0000000..9a0a6b1
--- /dev/null
+++ b/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
@@ -0,0 +1,270 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "lenet_train.h"
+#include "include/errorcode.h"
+#include "include/context.h"
+#include <cstring>
+#include <iostream>
+#include <fstream>
+#include "include/api/lite_context.h"
+#include "src/common/log_adapter.h"
+
+// Raw contents of the data file (I0) and the label file (I1) loaded by fl_lenet_lite_SetInputs.
+static char *fl_lenet_I0 = nullptr;
+static char *fl_lenet_I1 = nullptr;
+// State for rand_r() when FillInputData draws random samples; seeded from the wall clock at load time.
+unsigned int seed_ = time(NULL);
+
+// Copy one batch of samples and their one-hot labels from the buffers loaded by fl_lenet_lite_SetInputs into the
+// session inputs (inputs[0] = data, inputs[1] = float one-hot labels). batch_num is the number of samples available
+// in those buffers. Samples are taken sequentially when `serially` is true, randomly otherwise.
+// Returns the label index of every copied sample, or an empty vector if the inputs are not ready.
+std::vector<int> FillInputData(mindspore::session::TrainSession *train_session, int batch_num, bool serially) {
+  std::vector<int> labels_vec;
+  auto inputs = train_session->GetInputs();
+  if (inputs.size() < 2 || batch_num <= 0 || fl_lenet_I0 == nullptr || fl_lenet_I1 == nullptr) {
+    MS_LOG(ERROR) << "input data is not ready, batch_num:" << batch_num;
+    return labels_vec;
+  }
+  const int batch_size = inputs[0]->shape()[0];
+  // The read position persists across calls; the first sequential sample is index 2 % batch_num.
+  static unsigned int idx = 1;
+  // Per-sample size in BYTES: memcpy works on raw bytes, so ElementsNum() would under-copy multi-byte dtypes.
+  const size_t data_size = inputs[0]->Size() / batch_size;
+  const int num_classes = inputs[1]->shape()[1];
+  char *input_data = reinterpret_cast<char *>(inputs.at(0)->MutableData());
+  auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
+  std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
+  labels_vec.reserve(batch_size);
+  for (int i = 0; i < batch_size; i++) {
+    if (serially) {
+      idx = (idx + 1) % batch_num;
+    } else {
+      idx = rand_r(&seed_) % batch_num;
+    }
+    std::memcpy(input_data + i * data_size, fl_lenet_I0 + idx * data_size, data_size);
+    const int label_idx = reinterpret_cast<int *>(fl_lenet_I1)[idx];
+    if (label_idx >= 0 && label_idx < num_classes) {
+      labels[i * num_classes + label_idx] = 1.0;  // Model expects labels in onehot representation
+    } else {
+      MS_LOG(ERROR) << "label " << label_idx << " of sample " << idx << " is out of range [0, " << num_classes << ")";
+    }
+    labels_vec.push_back(label_idx);
+  }
+  return labels_vec;
+}
+
+// Return the first output tensor of the session holding exactly `size` elements, or nullptr (logged) if none does.
+mindspore::tensor::MSTensor *SearchOutputsForSize(mindspore::session::TrainSession *train_session, size_t size) {
+  for (const auto &output : train_session->GetOutputs()) {
+    if (static_cast<size_t>(output.second->ElementsNum()) == size) {
+      return output.second;
+    }
+  }
+  MS_LOG(ERROR) << "Model does not have an output tensor with size " << size;
+  return nullptr;
+}
+
+// Read the loss value, located as an output tensor holding a single element.
+// Returns 10000 when no such output exists.
+float GetLoss(mindspore::session::TrainSession *train_session) {
+  mindspore::tensor::MSTensor *loss_tensor = SearchOutputsForSize(train_session, 1);
+  if (loss_tensor != nullptr) {
+    return reinterpret_cast<float *>(loss_tensor->MutableData())[0];
+  }
+  return 10000;
+}
+// Create a TrainSession for ms_file on a single CPU thread without core binding; the caller owns the result.
+mindspore::session::TrainSession *GetSession(const std::string &ms_file, bool train_mode) {
+  mindspore::lite::Context ctx;
+  ctx.thread_num_ = 1;
+  auto &cpu_info = ctx.device_list_[0].device_info_.cpu_device_info_;
+  cpu_info.cpu_bind_mode_ = mindspore::lite::NO_BIND;
+  return mindspore::session::TrainSession::CreateSession(ms_file, &ctx, train_mode);
+}
+
+// net training function
+// net inference function
+// Evaluate the model in ms_file on test_nums batches taken sequentially from the loaded input buffers
+// (batch_num = number of samples available there). The accuracy is only logged; the return value is a status code.
+int fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums) {
+  // The session is owned here and released on every return path.
+  std::unique_ptr<mindspore::session::TrainSession> session(GetSession(ms_file, false));
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  session->Eval();
+  auto inputs = session->GetInputs();
+  if (inputs.size() < 2 || inputs[1]->shape().size() != 2) {
+    return mindspore::lite::RET_ERROR;
+  }
+  const int batch_size = inputs[1]->shape()[0];
+  const int num_of_class = inputs[1]->shape()[1];
+  float accuracy = 0.0;
+  for (int j = 0; j < test_nums; ++j) {
+    auto labels = FillInputData(session.get(), batch_num, true);
+    if (static_cast<int>(labels.size()) != batch_size) {
+      MS_LOG(ERROR) << "fill input data failed";
+      return mindspore::lite::RET_ERROR;
+    }
+    session->RunGraph();
+    auto outputsv = SearchOutputsForSize(session.get(), batch_size * num_of_class);
+    if (outputsv == nullptr) {
+      return mindspore::lite::RET_ERROR;
+    }
+    auto scores = reinterpret_cast<float *>(outputsv->MutableData());
+    for (int b = 0; b < batch_size; b++) {
+      const float *row = scores + num_of_class * b;
+      int max_idx = 0;
+      float max_score = row[0];
+      for (int c = 0; c < num_of_class; c++) {
+        if (row[c] > max_score) {
+          max_score = row[c];
+          max_idx = c;
+        }
+      }
+      if (labels[b] == max_idx) accuracy += 1.0;
+    }
+  }
+  if (batch_size > 0 && test_nums > 0) {
+    accuracy /= static_cast<float>(batch_size * test_nums);
+  }
+  MS_LOG(INFO) << "accuracy  is " << accuracy;
+  return mindspore::lite::RET_OK;
+}
+
+// net training function
+// net training function
+// Train the model in ms_file for `iterations` steps on random samples of the loaded input buffers
+// (batch_num = number of samples available there) and save the result back to ms_file.
+int fl_lenet_lite_Train(const std::string &ms_file, const int batch_num, const int iterations) {
+  // Validate before creating the session so the error path cannot leak it.
+  if (iterations <= 0) {
+    MS_LOG(ERROR) << "error iterations, iterations:" << iterations;
+    return mindspore::lite::RET_ERROR;
+  }
+  std::unique_ptr<mindspore::session::TrainSession> session(GetSession(ms_file, true));
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  MS_LOG(INFO) << "total iterations :" << iterations << "batch_num:" << batch_num;
+  float min_loss = 1000.;
+  for (int j = 0; j < iterations; ++j) {
+    if (FillInputData(session.get(), batch_num, false).empty()) {
+      MS_LOG(ERROR) << "fill input data failed";
+      return mindspore::lite::RET_ERROR;
+    }
+    session->RunGraph(nullptr, nullptr);
+    float loss = GetLoss(session.get());
+    if (min_loss > loss) min_loss = loss;
+    if (j % 50 == 0) {
+      MS_LOG(INFO) << "iteration:" << j << ",Loss is" << loss << " [min=" << min_loss << "]";
+    }
+  }
+  auto ret = session->SaveToFile(ms_file);
+  if (ret != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "save model failed:" << ms_file;
+    return ret;
+  }
+  return mindspore::lite::RET_OK;
+}
+
+// Write `size` entries of new_features into the model in update_ms_file via TrainSession::UpdateFeatureMaps.
+// Returns the status of that call, or RET_ERROR if no session could be created.
+int fl_lenet_lite_UpdateFeatures(const std::string &update_ms_file, TrainFeatureParam *new_features, int size) {
+  std::unique_ptr<mindspore::session::TrainSession> train_session(GetSession(update_ms_file, false));
+  if (train_session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << update_ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  auto status = train_session->UpdateFeatureMaps(update_ms_file, new_features, size);
+  if (status != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "update model feature map failed" << update_ms_file;
+  }
+  return status;
+}
+
+// Byte size of one element of `type`; 0 for types this helper does not handle.
+static size_t FeatureElementSize(mindspore::TypeId type) {
+  switch (type) {
+    case mindspore::kNumberTypeBool:
+    case mindspore::kNumberTypeInt8:
+    case mindspore::kNumberTypeUInt8:
+      return 1;
+    case mindspore::kNumberTypeInt16:
+    case mindspore::kNumberTypeUInt16:
+    case mindspore::kNumberTypeFloat16:
+      return 2;
+    case mindspore::kNumberTypeInt32:
+    case mindspore::kNumberTypeUInt32:
+    case mindspore::kNumberTypeFloat32:
+      return 4;
+    case mindspore::kNumberTypeInt64:
+    case mindspore::kNumberTypeUInt64:
+    case mindspore::kNumberTypeFloat64:
+      return 8;
+    default:
+      return 0;
+  }
+}
+
+// Read the feature maps (weights) of the model in update_ms_file.
+// On success *feature receives a new[]-allocated array of *size heap-allocated entries. Each entry's `name`
+// (new char[]) and `data` (new char[], a private copy of the weight values) belong to the caller.
+// Returns RET_OK on success, RET_ERROR otherwise.
+int fl_lenet_lite_GetFeatures(const std::string &update_ms_file, mindspore::session::TrainFeatureParam ***feature,
+                              int *size) {
+  if (feature == nullptr || size == nullptr) {
+    MS_LOG(ERROR) << "feature or size is nullptr";
+    return mindspore::lite::RET_ERROR;
+  }
+  std::unique_ptr<mindspore::session::TrainSession> train_session(GetSession(update_ms_file, false));
+  if (train_session == nullptr) {
+    MS_LOG(ERROR) << "create session failed:" << update_ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  std::vector<mindspore::session::TrainFeatureParam *> new_features;
+  auto status = train_session->GetFeatureMaps(&new_features);
+  if (status != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "get model feature map failed" << update_ms_file;
+    return mindspore::lite::RET_ERROR;
+  }
+  // Frees every entry; the first `copied` entries already own a private data buffer.
+  size_t copied = 0;
+  auto release_features = [&new_features, &copied]() {
+    for (size_t i = 0; i < new_features.size(); ++i) {
+      if (i < copied) {
+        delete[] static_cast<char *>(new_features[i]->data);
+      }
+      delete[] new_features[i]->name;
+      delete new_features[i];
+    }
+  };
+  // lite::TrainSession::GetFeatureMaps points `data` at the session's tensor storage, which is destroyed together
+  // with train_session when this function returns, so each weight is copied while the session is still alive.
+  for (; copied < new_features.size(); ++copied) {
+    mindspore::session::TrainFeatureParam *param = new_features[copied];
+    const size_t elem_size = FeatureElementSize(param->type);
+    if (elem_size == 0) {
+      MS_LOG(ERROR) << "unsupported feature data type:" << static_cast<int>(param->type);
+      release_features();
+      return mindspore::lite::RET_ERROR;
+    }
+    const size_t nbytes = param->elenums * elem_size;
+    char *copy = new (std::nothrow) char[nbytes == 0 ? 1 : nbytes];
+    if (copy == nullptr) {
+      MS_LOG(ERROR) << "copy feature data failed";
+      release_features();
+      return mindspore::lite::RET_ERROR;
+    }
+    if (nbytes > 0 && param->data != nullptr) {
+      memcpy(copy, param->data, nbytes);
+    }
+    param->data = copy;
+  }
+  *feature = new (std::nothrow) mindspore::session::TrainFeatureParam *[new_features.size()];
+  if (*feature == nullptr) {
+    MS_LOG(ERROR) << "create features failed";
+    release_features();
+    return mindspore::lite::RET_ERROR;
+  }
+  for (size_t i = 0; i < new_features.size(); i++) {
+    (*feature)[i] = new_features[i];
+  }
+  *size = static_cast<int>(new_features.size());
+  return mindspore::lite::RET_OK;
+}
+
+// Resolve `path` to an absolute path (realpath on POSIX, _fullpath on Windows).
+// Returns "" and logs an error if path is null, too long, or cannot be resolved.
+std::string RealPath(const char *path) {
+  if (path == nullptr) {
+    MS_LOG(ERROR) << "path is nullptr";
+    return "";
+  }
+  if ((strlen(path)) >= PATH_MAX) {
+    MS_LOG(ERROR) << "path is too long";
+    return "";
+  }
+  // make_unique never returns nullptr (it throws on failure), so no null check is needed.
+  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
+#ifdef _WIN32
+  // Pass the real buffer size: PATH_MAX may be smaller than 1024 on Windows, which would overflow the buffer.
+  char *real_path = _fullpath(resolved_path.get(), path, PATH_MAX);
+#else
+  char *real_path = realpath(path, resolved_path.get());
+#endif
+  if (real_path == nullptr || strlen(real_path) == 0) {
+    MS_LOG(ERROR) << "file path is not valid : " << path;
+    return "";
+  }
+  return std::string(resolved_path.get());
+}
+
+// Read the whole of `file` into a buffer allocated with new[] (the caller releases it with delete[]).
+// On success *size receives the byte count; on failure nullptr is returned and an error is logged.
+char *ReadFile(const char *file, size_t *size) {
+  if (file == nullptr) {
+    MS_LOG(ERROR) << "file is nullptr";
+    return nullptr;
+  }
+  if (size == nullptr) {
+    MS_LOG(ERROR) << "size is nullptr";
+    return nullptr;
+  }
+  std::string real_path = RealPath(file);
+  // The inputs are raw binary dumps: binary mode prevents newline translation from altering them.
+  std::ifstream ifs(real_path, std::ios::in | std::ios::binary);
+  if (!ifs.good()) {
+    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
+    return nullptr;
+  }
+
+  if (!ifs.is_open()) {
+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
+    return nullptr;
+  }
+
+  ifs.seekg(0, std::ios::end);
+  const std::streamoff file_size = ifs.tellg();
+  if (file_size < 0) {
+    MS_LOG(ERROR) << "get size of file: " << real_path << " failed";
+    return nullptr;
+  }
+  *size = static_cast<size_t>(file_size);
+  std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
+  if (buf == nullptr) {
+    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
+    return nullptr;
+  }
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(buf.get(), static_cast<std::streamsize>(*size));
+  if (!ifs) {
+    MS_LOG(ERROR) << "read file: " << real_path << " failed";
+    return nullptr;
+  }
+  return buf.release();
+}
+
+// Set input tensors.
+// Set input tensors.
+// Load the comma separated pair "data_file,label_file" into fl_lenet_I0 / fl_lenet_I1. Previously loaded buffers
+// are released only after both new files were read, so a failure leaves the old inputs intact.
+// `num` is currently unused. Returns 0 on success, a negative status otherwise.
+int fl_lenet_lite_SetInputs(const std::string &files, int num) {
+  (void)num;
+  if (files.empty()) {
+    MS_LOG(ERROR) << "files empty";
+    return -1;
+  }
+  std::vector<std::string> res;
+  size_t start = 0;
+  while (true) {
+    const size_t comma = files.find(',', start);
+    res.push_back(files.substr(start, comma == std::string::npos ? std::string::npos : comma - start));
+    if (comma == std::string::npos) {
+      break;
+    }
+    start = comma + 1;
+  }
+  if (res.size() != 2) {
+    MS_LOG(ERROR) << "res size not equal 2";
+    return -1;
+  }
+  std::unique_ptr<char[]> loaded[2];
+  for (int i = 0; i < 2; i++) {
+    size_t size = 0;
+    loaded[i].reset(ReadFile(res[i].c_str(), &size));
+    if (loaded[i] == nullptr) {
+      MS_LOG(ERROR) << "ReadFile return nullptr";
+      return mindspore::lite::RET_ERROR;
+    }
+  }
+  // ReadFile allocates with new[], so the previous buffers are released with delete[].
+  delete[] fl_lenet_I0;
+  delete[] fl_lenet_I1;
+  fl_lenet_I0 = loaded[0].release();
+  fl_lenet_I1 = loaded[1].release();
+  return 0;
+}
\ No newline at end of file
diff --git a/mindspore/lite/flclient/src/test/java/com/huawei/flclient/test/TestFLClient.java b/mindspore/lite/flclient/src/test/java/com/huawei/flclient/test/TestFLClient.java
index 1986e20..fa6ff63 100644
--- a/mindspore/lite/flclient/src/test/java/com/huawei/flclient/test/TestFLClient.java
+++ b/mindspore/lite/flclient/src/test/java/com/huawei/flclient/test/TestFLClient.java
@@ -17,7 +17,7 @@
 package com.huawei.flclient.test;
 
 import com.huawei.flclient.EncryptLevel;
-import com.huawei.flclient.FLClient;
+import com.huawei.flclient.FLLiteClient;
 
 public class TestFLClient {
 
@@ -34,7 +34,7 @@ public class TestFLClient {
         EncryptLevel encryptLevel = EncryptLevel.valueOf(args[9]);
         boolean use_elb = Boolean.parseBoolean(args[10]);
         int serverNum = Integer.parseInt(args[11]);
-        FLClient client = new FLClient(ip, port, fl_name, fl_id, ifAsync, encryptLevel, use_elb, serverNum, train_batch_num, test_dataset, test_batch_num);
+        FLLiteClient client = new FLLiteClient(ip, port, fl_name, fl_id, ifAsync, encryptLevel, use_elb, serverNum, train_batch_num, test_dataset, test_batch_num);
         client.initRuntimeResource();
 
 //        String train_dataset = "/home/user1/gitdir/master/mindspore/lite/flclient/src/main/resources/data_mindir_case1_2/data_bin_new/f0049_32/f0049_32_train_data.bin," +
@@ -62,7 +62,7 @@ public class TestFLClient {
      * @param test_dataset,    the binary files for evaluate model preference, eg: "./test_data.bin,test_label.bin"
      * @param test_batch_num,  which is equal to the test_dataset size divided by batch_size
      */
-    private static void syncFLJobTest(FLClient client, String train_dateset, int epoch, int train_batch_num,
+    private static void syncFLJobTest(FLLiteClient client, String train_dateset, int epoch, int train_batch_num,
                                       String test_dataset, int test_batch_num) {
         int iterations = 100;
         client.setInput(train_dateset, train_batch_num);
@@ -82,7 +82,7 @@ public class TestFLClient {
      * @param test_dataset,    the binary files for evaluate model preference
      * @param test_batch_num,  which is equal to the test_dataset size divided by batch_size
      */
-    private static void asyncFLJobTest(FLClient client, String train_dateset, int epoch, int train_batch_num,
+    private static void asyncFLJobTest(FLLiteClient client, String train_dateset, int epoch, int train_batch_num,
                                        String test_dataset, int test_batch_num) {
         client.setInput(train_dateset, train_batch_num);
         client.startFLJob();
diff --git a/mindspore/lite/include/train_session.h b/mindspore/lite/include/train_session.h
index 0a7faf4..d3ce3bf 100644
--- a/mindspore/lite/include/train_session.h
+++ b/mindspore/lite/include/train_session.h
@@ -24,7 +24,14 @@
 namespace mindspore {
 namespace session {
 
-/// \brief TrainSession Defines a class that allows training a MindSpore model
+/// \brief One named weight (feature map) exchanged through TrainSession::GetFeatureMaps / UpdateFeatureMaps.
+/// NOTE(review): this type does not express ownership of `name` / `data`; each producer must document it.
+struct TrainFeatureParam{
+  char* name;       ///< NUL-terminated weight (tensor) name
+  void *data;       ///< element values of the weight
+  size_t elenums;   ///< number of elements pointed to by data
+  enum TypeId type; ///< element data type
+};
+/// \brief TrainSession Defines a class that allows training a MindSpore model
 class TrainSession : public session::LiteSession {
  public:
   /// \brief Class destructor
@@ -83,6 +90,10 @@ class TrainSession : public session::LiteSession {
   /// \return boolean indication if model is in eval mode
   bool IsEval() { return train_mode_ == false; }
 
+  /// \brief Collect the weights of the session as heap-allocated TrainFeatureParam entries.
+  ///        (lite::TrainSession collects the constant tensors.)
+  /// \param[out] feature_maps receives one entry per collected tensor
+  /// \return RET_OK on success
+  virtual int GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *>* feature_maps) =0;
+
+  /// \brief Overwrite weights of the session with size entries of new_features, matched by name.
+  ///        lite::TrainSession returns RET_ERROR when the element counts of a matched weight differ.
+  /// \param[in] update_ms_file model file the update refers to (presumably where the result is saved — confirm)
+  virtual int UpdateFeatureMaps(const std::string &update_ms_file,
+                                TrainFeatureParam* new_features,int size) =0;
  protected:
   bool train_mode_ = false;
 };
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index 494b417..44e8cb9 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -23,6 +23,7 @@
 #include <iostream>
 #include <fstream>
 #include <memory>
+#include <cstring>
 #include "include/errorcode.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
@@ -101,7 +102,6 @@ int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) {
   }
   orig_output_node_map_ = output_node_map_;
   orig_output_tensor_map_ = output_tensor_map_;
-
   for (auto inTensor : inputs_) inTensor->MutableData();
   RestoreOps(restore);
   CompileTrainKernels();      // Prepare a list of train kernels
@@ -331,6 +331,93 @@ bool TrainSession::IsMaskOutput(kernel::LiteKernel *kernel) const {
    return (IsOptimizer(kernel) || (kernel->Type() == schema::PrimitiveType_Assign));
  }
 
+/// \brief Append a descriptor for every constant tensor of this session to feature_maps.
+///
+/// Each descriptor is allocated with new and is owned by the caller, as is its name buffer (allocated
+/// with new[]; nullptr when the tensor has no name). The data member is not a copy: it points at the
+/// tensor's own storage, which stays owned by the session and must not be freed by the caller.
+/// \return RET_OK on success, RET_ERROR if feature_maps is nullptr.
+int TrainSession::GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *> *feature_maps) {
+  if (feature_maps == nullptr) {
+    MS_LOG(ERROR) << "feature_maps is nullptr";
+    return RET_ERROR;
+  }
+  for (auto tensor : this->tensors_) {
+    if (!tensor->IsConst()) {
+      continue;
+    }
+    auto param = new mindspore::session::TrainFeatureParam();  // value-initialized: name/data are nullptr
+    const std::string &tensor_name = tensor->tensor_name();
+    if (!tensor_name.empty()) {
+      param->name = new char[tensor_name.size() + 1];
+      memcpy(param->name, tensor_name.c_str(), tensor_name.size() + 1);  // also copies the terminating NUL
+    }
+    // Point at the tensor's own buffer (no copy); that storage stays owned by the session.
+    param->data = tensor->data_c();
+    param->elenums = tensor->ElementsNum();
+    param->type = tensor->data_type();
+    feature_maps->push_back(param);
+  }
+  MS_LOG(INFO) << "get feature map success";
+  return RET_OK;
+}
+
+/// \brief Overwrite the data of constant tensors with new_features[0, size) and save the model to update_ms_file.
+///
+/// Each entry is matched by name against the session's constant tensors and must have the same element count;
+/// its data must already use the tensor's element type, since tensor->Size() bytes are copied from it.
+/// All entries are validated before any tensor is written, so a rejected entry leaves every tensor unchanged.
+/// \return RET_OK on success, RET_ERROR on invalid input, or the SaveToFile error code if saving fails.
+int TrainSession::UpdateFeatureMaps(const std::string &update_ms_file,
+                                    mindspore::session::TrainFeatureParam *new_features, int size) {
+  if (size < 0 || (size > 0 && new_features == nullptr)) {
+    MS_LOG(ERROR) << "invalid new features, size:" << size;
+    return RET_ERROR;
+  }
+  // Pass 1: resolve every entry to an index into tensors_ and validate it, without writing anything.
+  std::vector<size_t> target_index(static_cast<size_t>(size));
+  for (int i = 0; i < size; ++i) {
+    const auto &new_feature = new_features[i];
+    if (new_feature.name == nullptr || new_feature.data == nullptr) {
+      MS_LOG(ERROR) << "feature " << i << " has no name or no data";
+      return RET_ERROR;
+    }
+    size_t idx = 0;
+    while (idx < this->tensors_.size() &&
+           !(this->tensors_[idx]->IsConst() && this->tensors_[idx]->tensor_name() == new_feature.name)) {
+      ++idx;
+    }
+    if (idx == this->tensors_.size()) {
+      MS_LOG(ERROR) << "cannot find feature:" << new_feature.name;
+      return RET_ERROR;
+    }
+    auto tensor = this->tensors_[idx];
+    auto old_elenums = static_cast<size_t>(tensor->ElementsNum());
+    if (old_elenums != new_feature.elenums) {
+      MS_LOG(ERROR) << "feature name:" << new_feature.name << ",len diff:" << "old is:" << old_elenums
+                    << "new is:" << new_feature.elenums;
+      return RET_ERROR;
+    }
+    if (tensor->data_c() == nullptr) {
+      MS_LOG(ERROR) << "feature name:" << new_feature.name << " has no data buffer";
+      return RET_ERROR;
+    }
+    target_index[i] = idx;
+  }
+  // Pass 2: every entry is valid, so copy the new values into the matching tensors.
+  for (int i = 0; i < size; ++i) {
+    auto tensor = this->tensors_[target_index[i]];
+    memcpy(tensor->data_c(), new_features[i].data, tensor->Size());
+  }
+  auto ret = SaveToFile(update_ms_file);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "save model:" << update_ms_file << " failed";
+    return ret;
+  }
+  MS_LOG(INFO) << "update model:" << update_ms_file << ",feature map success";
+  return RET_OK;
+}
+
 }  // namespace lite
 
 session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context,
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index fdaabef..1cefe51 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -78,6 +78,9 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
   int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override {
     return lite::RET_ERROR;
   }
+  int GetFeatureMaps(std::vector<session::TrainFeatureParam *> *feature_maps) override;
+  int UpdateFeatureMaps(const std::string &update_ms_file, session::TrainFeatureParam *new_features,
+                        int size) override;
 
  protected:
   void AllocWorkSpace();
-- 
2.7.4

