<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for an educational purpose.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit → Notebook settings.
- Select TPU from the Hardware accelerator drop-down.

Next, we'll check that we can connect to the TPU:

## Data processing

In [None]:
cmake_minimum_required(VERSION 3.14)

project(FederalLearning)

# Build switches. Nothing in this file reads them; presumably they are consumed by
# MindSpore Lite build logic or headers - TODO confirm they are still needed.
option(SUPPORT_GPU "if support gpu" off)
set(BUILD_LITE "on")
set(SUPPORT_TRAIN "on")
set(PLATFORM_ARM "on")

# Repository root, six directory levels above this file.
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)

# Version macros for C and C++ sources. MS_VERSION_* are presumably supplied with -D on
# the cmake command line; when unset the macros are defined empty.
add_compile_definitions(
  MS_VERSION_MAJOR=${MS_VERSION_MAJOR}
  MS_VERSION_MINOR=${MS_VERSION_MINOR}
  MS_VERSION_REVISION=${MS_VERSION_REVISION}
)

# Plain C++17 (-std=c++17, no GNU extensions).
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/linux)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
if(ENABLE_MICRO)
  include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/runtime)
endif()
include_directories(${LITE_DIR}) ## lite include
include_directories(${TOP_DIR}) ## api include
include_directories(${TOP_DIR}/mindspore/core/) ## core include
include_directories(${LITE_DIR}/build) ## flatbuffers

if(ENABLE_MICRO)
  # Micro build: generated LeNet sources plus the nnacl/runtime kernels they use.
  set(SRC_FILES
    flearning.cpp
  )
  set(OP_SRC
    src/nnacl/arithmetic_common.c
    src/nnacl/common_func.c
    src/nnacl/fp32/activation.c
    src/nnacl/fp32/arithmetic.c
    src/nnacl/fp32/common_func.c
    src/nnacl/fp32/conv.c
    src/nnacl/fp32/matmul.c
    src/nnacl/fp32/softmax.c
    src/nnacl/fp32_grad/activation_grad.c
    src/nnacl/fp32_grad/gemm.c
    src/nnacl/fp32_grad/pack_ext.c
    src/nnacl/fp32_grad/pooling_grad.c
    src/nnacl/int8/conv_int8.c
    src/nnacl/int8/matmul_int8.c
    src/nnacl/minimal_filtering_generator.c
    src/nnacl/pack.c
    src/nnacl/quantization/fixed_point.c
    src/nnacl/reshape.c
    src/nnacl/winograd_transform.c
    src/nnacl/winograd_utils.c
    src/runtime/kernel/fp32/max_pooling.c
    src/runtime/kernel/fp32_grad/apply_momentum.c
    src/runtime/kernel/fp32_grad/biasadd_grad.c
    src/runtime/kernel/fp32_grad/compute_gradient.c
    src/runtime/kernel/fp32_grad/conv_filter_grad.c
    src/runtime/kernel/fp32_grad/conv_input_grad.c
    src/runtime/kernel/fp32_grad/init_matrix.c
    src/runtime/kernel/fp32_grad/sparse_softmax_cross_entropy_with_logist.c
    src/runtime/load_input.c
    src/fl_lenet.c
    src/weight_files/fl_lenet_weight_epoch_0.c
  )
else()
  # Regular build: JNI bridge over the MindSpore Lite train session.
  set(SRC_FILES
    lenet_train_jni.cpp
  )
  set(OP_SRC
    src/lenet_train.cpp
  )
endif()

add_library(fl SHARED ${SRC_FILES} ${OP_SRC})
# Search lib/ for the libraries linked below. This must be target-scoped (or precede
# add_library): directory-level link_directories() only affects targets created after it.
target_link_directories(fl PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/lib/)
target_link_libraries(fl PRIVATE mindspore-lite glog)

# NOTE(review): installs into the source tree (lib/); confirm this is intended packaging.
install(TARGETS fl LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <jni.h>
#include "include/train_session.h"
#include "include/errorcode.h"
#include "lenet_train.h"
#include <cstring>
#include "src/common/log_adapter.h"

// FlatBufferBuilder received by getFeaturesMap; CreateFeatureMap serializes into it.
// NOTE(review): this is the `builder` argument of a JNI call, i.e. a local reference, so it
// is only valid until that getFeaturesMap call returns.
static jobject fbb;
// Method ID of FlatBufferBuilder.createString(CharSequence), looked up by getFeaturesMap.
static jmethodID create_string_char;

// Converts a Java string to a heap-allocated, NUL-terminated C string by calling
// String.getBytes("GB2312").
// @param env  JNI environment of the calling thread.
// @param jstr Java string to convert; may be null.
// @return a new[]-allocated C string the caller must release with delete[]. An empty Java
//         string yields an empty C string. Returns nullptr when `jstr` is null or the
//         conversion fails (any pending Java exception is cleared).
// NOTE(review): GB2312 cannot represent most non-ASCII characters; if these strings are
// file paths on Android, UTF-8 is probably what the file system expects - confirm.
char *JstringToChar(JNIEnv *env, jstring jstr) {
  if (jstr == nullptr) {
    return nullptr;
  }
  jclass clsstring = env->FindClass("java/lang/String");
  jstring strencode = env->NewStringUTF("GB2312");
  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
  jbyteArray barr = static_cast<jbyteArray>(env->CallObjectMethod(jstr, mid, strencode));
  // Drop local references eagerly; callers may run long loops before returning to Java.
  env->DeleteLocalRef(strencode);
  env->DeleteLocalRef(clsstring);
  if (env->ExceptionCheck() || barr == nullptr) {
    env->ExceptionClear();
    if (barr != nullptr) {
      env->DeleteLocalRef(barr);
    }
    return nullptr;
  }

  jsize alen = env->GetArrayLength(barr);
  char *rtn = new char[alen + 1];
  if (alen > 0) {
    jbyte *ba = env->GetByteArrayElements(barr, nullptr);
    if (ba == nullptr) {
      delete[] rtn;
      env->DeleteLocalRef(barr);
      return nullptr;
    }
    memcpy(rtn, ba, alen);
    // The bytes were only read, so there is nothing to copy back.
    env->ReleaseByteArrayElements(barr, ba, JNI_ABORT);
  }
  rtn[alen] = '\0';
  env->DeleteLocalRef(barr);
  return rtn;
}
// JNI: LiteTrain.train(String modelName, int batch_num, int iterations).
// Converts the model path and forwards to fl_lenet_lite_Train.
// @return the status of fl_lenet_lite_Train, or RET_ERROR if the path cannot be converted.
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_train(JNIEnv *env, jobject thiz, jstring ms_file,
                                                                           jint batch_num, jint iterations) {
  char *path = JstringToChar(env, ms_file);
  if (path == nullptr) {
    MS_LOG(ERROR) << "model file path is invalid";
    return mindspore::lite::RET_ERROR;
  }
  // Copy into a std::string and free the new[] buffer returned by JstringToChar.
  std::string model_path(path);
  delete[] path;
  return fl_lenet_lite_Train(model_path, batch_num, iterations);
}

// Serializes one feature as a mindspore.schema.FeatureMap table into the FlatBufferBuilder
// held in `fbb`, and returns the table offset from FeatureMap.createFeatureMap.
// Requires `fbb` and `create_string_char` to have been set by getFeaturesMap.
// @param env  JNI environment of the calling thread.
// @param name NUL-terminated feature name.
// @param data pointer to `size` floats to serialize.
// @param size number of floats in `data`.
extern "C" jint CreateFeatureMap(JNIEnv *env, const char *name, float *data, size_t size) {
  jstring name1 = env->NewStringUTF(name);
  jint name_offset = env->CallIntMethod(fbb, create_string_char, name1);
  // This function is called once per feature inside a single native call, so local
  // references are released here to keep the local reference table from filling up.
  env->DeleteLocalRef(name1);
  // 1. copy the data into a Java float[]
  jsize length = static_cast<jsize>(size);
  jfloatArray ret = env->NewFloatArray(length);
  env->SetFloatArrayRegion(ret, 0, length, data);
  // 2. get methodid createDataVector
  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
  jmethodID createDataVector =
    env->GetStaticMethodID(fm_cls, "createDataVector", "(Lcom/google/flatbuffers/FlatBufferBuilder;[F)I");
  // 3. calc data offset
  jint data_offset = env->CallStaticIntMethod(fm_cls, createDataVector, fbb, ret);
  env->DeleteLocalRef(ret);
  // 4. build the FeatureMap table from the name and data offsets
  jmethodID createFeatureMap =
    env->GetStaticMethodID(fm_cls, "createFeatureMap", "(Lcom/google/flatbuffers/FlatBufferBuilder;II)I");
  jint fm_offset = env->CallStaticIntMethod(fm_cls, createFeatureMap, fbb, name_offset, data_offset);
  env->DeleteLocalRef(fm_cls);
  return fm_offset;
}

// JNI: LiteTrain.getFeaturesMap(String modelName, FlatBufferBuilder builder).
// Reads the feature maps of the model at `ms_file` and serializes each one into `builder`
// as a FeatureMap table.
// @return a Java int[] with one FeatureMap offset per feature; an empty array on failure.
extern "C" JNIEXPORT jintArray JNICALL Java_com_huawei_flclient_LiteTrain_getFeaturesMap(JNIEnv *env, jobject thiz,
                                                                                         jstring ms_file,
                                                                                         jobject builder) {
  char *path = JstringToChar(env, ms_file);
  if (path == nullptr) {
    MS_LOG(ERROR) << "model file path is invalid";
    return env->NewIntArray(0);
  }
  std::string model_path(path);
  delete[] path;

  // CreateFeatureMap serializes through these globals.
  fbb = builder;
  jclass fb_clazz = env->GetObjectClass(builder);
  create_string_char = env->GetMethodID(fb_clazz, "createString", "(Ljava/lang/CharSequence;)I");
  env->DeleteLocalRef(fb_clazz);

  TrainFeatureParam **train_features = nullptr;
  int feature_size = 0;
  auto status = fl_lenet_lite_GetFeatures(model_path, &train_features, &feature_size);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "get features failed:" << model_path;
    fbb = nullptr;
    return env->NewIntArray(0);
  }
  jintArray ret = env->NewIntArray(feature_size);
  jint *data = env->GetIntArrayElements(ret, nullptr);

  for (int i = 0; i < feature_size; i++) {
    data[i] = CreateFeatureMap(env, train_features[i]->name, (float *)train_features[i]->data,
                               train_features[i]->elenums);
    MS_LOG(INFO) << "upload feature:" << ", name:" << train_features[i]->name << ", elenums:"
                 << train_features[i]->elenums;
  }
  env->ReleaseIntArrayElements(ret, data, 0);
  for (int i = 0; i < feature_size; i++) {
    delete train_features[i];
  }
  // The pointer array itself is new[]-allocated by fl_lenet_lite_GetFeatures.
  delete[] train_features;
  // `builder` is a local reference that dies when this call returns; do not keep it.
  fbb = nullptr;
  return ret;
}

// JNI: LiteTrain.updateFeatures(String modelName, ArrayList<FeatureMap> featureMaps).
// Converts each mindspore.schema.FeatureMap into a TrainFeatureParam. The name comes from
// weightFullname() encoded as GB2312, and the float data is copied element by element via
// data(int). The resulting array is passed to fl_lenet_lite_UpdateFeatures for `ms_file`.
// @return the status of fl_lenet_lite_UpdateFeatures, or RET_ERROR if conversion fails.
// NOTE(review): on the success path the TrainFeatureParam array and its name/data buffers
// are not freed here. Whether fl_lenet_lite_UpdateFeatures takes ownership is not visible
// from this file - confirm, and free them after the call if it does not.
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_updateFeatures(JNIEnv *env, jobject,
                                                                                    jstring ms_file, jobject features) {
  char *path = JstringToChar(env, ms_file);
  if (path == nullptr) {
    MS_LOG(ERROR) << "model file path is invalid";
    return mindspore::lite::RET_ERROR;
  }
  std::string model_path(path);
  delete[] path;

  jclass arr_cls = env->GetObjectClass(features);
  jmethodID size_method = env->GetMethodID(arr_cls, "size", "()I");
  jmethodID get_method = env->GetMethodID(arr_cls, "get", "(I)Ljava/lang/Object;");

  jclass fm_cls = env->FindClass("mindspore/schema/FeatureMap");
  jmethodID weight_name_method = env->GetMethodID(fm_cls, "weightFullname", "()Ljava/lang/String;");
  jmethodID data_length_method = env->GetMethodID(fm_cls, "dataLength", "()I");
  jmethodID data_method = env->GetMethodID(fm_cls, "data", "(I)F");
  jclass clsstring = env->FindClass("java/lang/String");
  jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
  int size = env->CallIntMethod(features, size_method);
  // transform FeatureMap to TrainFeatureParam
  TrainFeatureParam *features_param = (TrainFeatureParam *)malloc(size * sizeof(TrainFeatureParam));
  if (features_param == nullptr && size > 0) {
    MS_LOG(ERROR) << "malloc features_param failed";
    return mindspore::lite::RET_ERROR;
  }
  // Error path only: frees the buffers of the first `count` filled entries and the array.
  auto release_params = [features_param](int count) {
    for (int k = 0; k < count; ++k) {
      delete[] features_param[k].name;
      free((void *)features_param[k].data);
    }
    free(features_param);
  };

  for (int i = 0; i < size; ++i) {
    TrainFeatureParam *param = features_param + i;
    jobject feature = env->CallObjectMethod(features, get_method, i);
    // set feature_param name
    jstring weight_full_name = (jstring)env->CallObjectMethod(feature, weight_name_method);
    jbyteArray barr = nullptr;
    if (weight_full_name != nullptr) {
      jstring strencode = env->NewStringUTF("GB2312");
      barr = (jbyteArray)env->CallObjectMethod(weight_full_name, mid, strencode);
      // Release per-iteration local references so large feature lists cannot exhaust the table.
      env->DeleteLocalRef(strencode);
      env->DeleteLocalRef(weight_full_name);
    }
    if (barr == nullptr) {
      MS_LOG(ERROR) << "name is nullptr";
      env->DeleteLocalRef(feature);
      release_params(i);
      return mindspore::lite::RET_ERROR;
    }
    char *name = nullptr;
    jsize alen = env->GetArrayLength(barr);
    if (alen > 0) {
      jbyte *ba = env->GetByteArrayElements(barr, nullptr);
      if (ba == nullptr) {
        MS_LOG(ERROR) << "name is nullptr";
        env->DeleteLocalRef(barr);
        env->DeleteLocalRef(feature);
        release_params(i);
        return mindspore::lite::RET_ERROR;
      }
      name = new char[alen + 1];
      memcpy(name, ba, alen);
      name[alen] = 0;
      // Read-only access: nothing to copy back.
      env->ReleaseByteArrayElements(barr, ba, JNI_ABORT);
    }
    env->DeleteLocalRef(barr);

    int data_length = env->CallIntMethod(feature, data_length_method);
    // calloc zero-fills (like the former malloc + memset) and checks the size product.
    float *data = static_cast<float *>(calloc(data_length, sizeof(float)));
    if (data == nullptr && data_length > 0) {
      MS_LOG(ERROR) << "malloc feature data failed";
      delete[] name;
      env->DeleteLocalRef(feature);
      release_params(i);
      return mindspore::lite::RET_ERROR;
    }
    for (int j = 0; j < data_length; ++j) {
      data[j] = env->CallFloatMethod(feature, data_method, j);
    }
    env->DeleteLocalRef(feature);

    param->name = name;
    param->data = data;
    param->elenums = data_length;
    param->type = mindspore::kNumberTypeFloat32;
    MS_LOG(INFO) << "get feature:" << (name == nullptr ? "" : name) << ",elenums:" << param->elenums;
  }
  return fl_lenet_lite_UpdateFeatures(model_path, features_param, size);
}

// JNI: LiteTrain.setInput(String fileSet, int num).
// Converts the comma-separated file list and forwards to fl_lenet_lite_SetInputs.
// @return the status of fl_lenet_lite_SetInputs, or RET_ERROR if the string cannot be converted.
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_setInput(JNIEnv *env, jobject, jstring files,
                                                                              jint nums) {
  char *file_list = JstringToChar(env, files);
  if (file_list == nullptr) {
    MS_LOG(ERROR) << "input file list is invalid";
    return mindspore::lite::RET_ERROR;
  }
  // Copy into a std::string and free the new[] buffer returned by JstringToChar.
  std::string file_str(file_list);
  delete[] file_list;
  return fl_lenet_lite_SetInputs(file_str, nums);
}

// JNI: LiteTrain.inference(String modelName, int batch_num, int test_nums).
// Converts the model path and forwards to fl_lenet_lite_Inference.
// @return the status code of fl_lenet_lite_Inference (it logs the accuracy but returns a
//         status, not the accuracy), or RET_ERROR if the path cannot be converted.
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_inference(JNIEnv *env, jobject, jstring ms_file,
                                                                               jint batch_num, jint test_nums) {
  char *path = JstringToChar(env, ms_file);
  if (path == nullptr) {
    MS_LOG(ERROR) << "model file path is invalid";
    return mindspore::lite::RET_ERROR;
  }
  std::string model_path(path);
  delete[] path;
  return fl_lenet_lite_Inference(model_path, batch_num, test_nums);
}

// JNI: LiteTrain.free(). No native resources are released here; always reports success (0).
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_LiteTrain_free(JNIEnv *, jobject) {
  constexpr jint kSuccess = 0;
  return kSuccess;
}

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "lenet_train.h"
#include "include/errorcode.h"
#include "include/context.h"
#include <cstring>
#include <iostream>
#include <fstream>
#include "include/api/lite_context.h"
#include "src/common/log_adapter.h"

// Raw bytes of the first file given to fl_lenet_lite_SetInputs (input samples). The buffer
// is new[]-allocated by ReadFile; FillInputData copies samples out of it.
static char *fl_lenet_I0 = 0;
// Raw bytes of the second file given to fl_lenet_lite_SetInputs (labels). FillInputData
// reads it as an array of int class indices.
static char *fl_lenet_I1 = 0;
// rand_r() state used by FillInputData for random sampling; seeded from time(NULL) at load.
unsigned int seed_ = time(NULL);

// Fills one batch of the session inputs from the buffers loaded by fl_lenet_lite_SetInputs.
// inputs[0] receives `batch_size` samples. inputs[1] is cleared and receives one-hot labels
// (float, batch_size x num_classes).
// @param train_session session whose input tensors are filled.
// @param batch_num     number of samples available in the loaded buffers (must be > 0).
// @param serially      true: walk samples in order (a static cursor persists across calls);
//                      false: draw samples with rand_r(&seed_).
// @return the label index of each filled sample, or an empty vector on error.
// NOTE(review): the loaded buffers carry no length, so an index beyond the actual file
// size cannot be detected here - callers must pass a batch_num that matches the files.
std::vector<int> FillInputData(mindspore::session::TrainSession *train_session, int batch_num, bool serially) {
  std::vector<int> labels_vec;
  if (fl_lenet_I0 == nullptr || fl_lenet_I1 == nullptr) {
    MS_LOG(ERROR) << "input data is not set";
    return labels_vec;
  }
  if (batch_num <= 0) {
    MS_LOG(ERROR) << "invalid batch_num: " << batch_num;
    return labels_vec;
  }
  auto inputs = train_session->GetInputs();
  if (inputs.size() < 2) {
    MS_LOG(ERROR) << "model expects a data input and a label input";
    return labels_vec;
  }
  int batch_size = inputs[0]->shape()[0];
  if (batch_size <= 0) {
    MS_LOG(ERROR) << "invalid batch size: " << batch_size;
    return labels_vec;
  }
  static unsigned int idx = 1;
  // memcpy works in bytes, so the per-sample stride must be in bytes as well; ElementsNum()
  // would under-copy any input whose element type is wider than one byte.
  size_t data_size = inputs[0]->Size() / batch_size;
  int num_classes = inputs[1]->shape()[1];
  char *input_data = reinterpret_cast<char *>(inputs.at(0)->MutableData());
  auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
  std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
  for (int i = 0; i < batch_size; i++) {
    if (serially) {
      idx = (idx + 1) % batch_num;
    } else {
      idx = rand_r(&seed_) % batch_num;
    }
    std::memcpy(input_data + i * data_size, fl_lenet_I0 + idx * data_size, data_size);
    int label_idx = *(reinterpret_cast<int *>(fl_lenet_I1) + idx);
    // An out-of-range label would write outside the label tensor.
    if (label_idx < 0 || label_idx >= num_classes) {
      MS_LOG(ERROR) << "label " << label_idx << " out of range, num_classes: " << num_classes;
      return {};
    }
    labels[i * num_classes + label_idx] = 1.0;  // Model expects labels in onehot representation
    labels_vec.push_back(label_idx);
  }
  return labels_vec;
}

// Returns the first session output tensor (in the iteration order of GetOutputs()) whose
// element count equals `size`, or nullptr if there is none.
// @param train_session session to inspect.
// @param size          required number of elements.
mindspore::tensor::MSTensor *SearchOutputsForSize(mindspore::session::TrainSession *train_session, size_t size) {
  for (const auto &output : train_session->GetOutputs()) {
    auto *tensor = output.second;
    if (tensor != nullptr && static_cast<size_t>(tensor->ElementsNum()) == size) {
      return tensor;
    }
  }
  MS_LOG(ERROR) << "Model does not have an output tensor with size " << size;
  return nullptr;
}

// Returns the current loss value: the first element of the session output that holds
// exactly one element (assumed to be the loss). Returns 10000 when no such output exists.
float GetLoss(mindspore::session::TrainSession *train_session) {
  mindspore::tensor::MSTensor *loss_tensor = SearchOutputsForSize(train_session, 1);
  if (loss_tensor != nullptr) {
    const float *loss_values = static_cast<float *>(loss_tensor->MutableData());
    return loss_values[0];
  }
  return 10000;
}
// Creates a TrainSession for the model file `ms_file`, using one CPU thread with no core
// binding. Returns whatever TrainSession::CreateSession returns; the other functions in this
// file delete the returned session themselves.
mindspore::session::TrainSession *GetSession(const std::string &ms_file, bool train_mode) {
  mindspore::lite::Context ctx;
  ctx.thread_num_ = 1;
  auto &cpu_info = ctx.device_list_[0].device_info_.cpu_device_info_;
  cpu_info.cpu_bind_mode_ = mindspore::lite::NO_BIND;
  return mindspore::session::TrainSession::CreateSession(ms_file, &ctx, train_mode);
}

// net training function
int fl_lenet_lite_Inference(const std::string &ms_file, int batch_num, int test_nums) {
  auto session = GetSession(ms_file, false);
  char *origin_input[] = {fl_lenet_I0, fl_lenet_I1};
  float accuracy = 0.0;
  session->Eval();
  auto inputs = session->GetInputs();
  if (inputs[1]->shape().size() != 2) {
    return mindspore::lite::RET_ERROR;
  }
  auto batch_size = inputs[1]->shape()[0];
  auto num_of_class = inputs[1]->shape()[1];
  for (int j = 0; j < test_nums; ++j) {
    auto labels = FillInputData(session, batch_num, true);
    session->RunGraph();
    auto outputsv = SearchOutputsForSize(session, batch_size * num_of_class);
    auto scores = reinterpret_cast<float *>(outputsv->MutableData());
    for (int b = 0; b < batch_size; b++) {
      int max_idx = 0;
      float max_score = scores[num_of_class * b];
      for (int c = 0; c < num_of_class; c++) {
        if (scores[num_of_class * b + c] > max_score) {
          max_score = scores[num_of_class * b + c];
          max_idx = c;
        }
      }
      if (labels[b] == max_idx) accuracy += 1.0;
    }
  }
  fl_lenet_I0 = origin_input[0];
  fl_lenet_I1 = origin_input[1];
  accuracy /= static_cast<float>(batch_size * test_nums);
  MS_LOG(INFO) << "accuracy  is " << accuracy;
  return mindspore::lite::RET_OK;
}

// net training function
int fl_lenet_lite_Train(const std::string &ms_file, const int batch_num, const int iterations) {
  auto session = GetSession(ms_file, true);
  if (iterations <= 0) {
    MS_LOG(ERROR) << "error iterations or epoch!, epoch:"
                 << ", iterations" << iterations;
    return mindspore::lite::RET_ERROR;
  }
  MS_LOG(INFO) << "total iterations :" << iterations << "batch_num:" << batch_num;
  char *origin_input[] = {fl_lenet_I0, fl_lenet_I1};
  float min_loss = 1000.;
  for (int j = 0; j < iterations; ++j) {
    FillInputData(session, batch_num, false);
    session->RunGraph(nullptr, nullptr);
    float loss = GetLoss(session);
    if (min_loss > loss) min_loss = loss;
    if (j % 50 == 0) {
      MS_LOG(INFO) << "iteration:" << j << ",Loss is" << loss << " [min=" << min_loss << "]";
    }
  }
  session->SaveToFile(ms_file);
  fl_lenet_I0 = origin_input[0];
  fl_lenet_I1 = origin_input[1];
  return mindspore::lite::RET_OK;
}

int fl_lenet_lite_UpdateFeatures(const std::string &update_ms_file, TrainFeatureParam *new_features, int size) {
  auto train_session = GetSession(update_ms_file, false);
  auto status = train_session->UpdateFeatureMaps(update_ms_file, new_features, size);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "update model feature map failed" << update_ms_file;
  }
  delete train_session;
  return status;
}

int fl_lenet_lite_GetFeatures(const std::string &update_ms_file, mindspore::session::TrainFeatureParam ***feature,
                              int *size) {
  auto train_session = GetSession(update_ms_file, false);
  std::vector<mindspore::session::TrainFeatureParam *> new_features;
  auto status = train_session->GetFeatureMaps(&new_features);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "get model feature map failed" << update_ms_file;
    delete train_session;
    return mindspore::lite::RET_ERROR;
  }
  *feature = new (std::nothrow) TrainFeatureParam *[new_features.size()];
  if (*feature == nullptr) {
    MS_LOG(ERROR) << "create features failed";
    delete train_session;
    return mindspore::lite::RET_ERROR;
  }
  for (int i = 0; i < new_features.size(); i++) {
    (*feature)[i] = new_features[i];
  }
  *size = new_features.size();
  delete train_session;
  return mindspore::lite::RET_OK;
}

// Resolves `path` to an absolute path (realpath on POSIX, _fullpath on Windows).
// @return the resolved path, or "" (after logging) if `path` is null, too long, or invalid.
std::string RealPath(const char *path) {
  if (path == nullptr) {
    MS_LOG(ERROR) << "path is nullptr";
    return "";
  }
  if ((strlen(path)) >= PATH_MAX) {
    MS_LOG(ERROR) << "path is too long";
    return "";
  }
  // make_unique throws on allocation failure, so no null check is needed.
  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
#ifdef _WIN32
  // Pass the real buffer size; a larger constant could let _fullpath overflow the buffer.
  char *real_path = _fullpath(resolved_path.get(), path, PATH_MAX);
#else
  char *real_path = realpath(path, resolved_path.get());
#endif
  if (real_path == nullptr || strlen(real_path) == 0) {
    MS_LOG(ERROR) << "file path is not valid : " << path;
    return "";
  }
  return std::string(resolved_path.get());
}

char *ReadFile(const char *file, size_t *size) {
  if (file == nullptr) {
    MS_LOG(ERROR) << "file is nullptr";
    return nullptr;
  }
  //  MS_ASSERT(size != nullptr);
  std::string real_path = RealPath(file);
  std::ifstream ifs(real_path);
  if (!ifs.good()) {
    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
    return nullptr;
  }

  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "file: " << real_path << " open failed";
    return nullptr;
  }

  ifs.seekg(0, std::ios::end);
  *size = ifs.tellg();
  std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
  if (buf == nullptr) {
    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
    ifs.close();
    return nullptr;
  }
  ifs.seekg(0, std::ios::beg);
  ifs.read(buf.get(), *size);
  ifs.close();

  return buf.release();
}

// Set input tensors.
// Loads "<data file>,<label file>" (exactly two comma-separated paths) into the globals
// read by FillInputData. Any previously loaded buffers are released. The globals are only
// replaced once both files have been read, so a failure leaves the previous inputs intact.
// @param files comma-separated list of exactly two file paths.
// @param num   unused by this implementation.
// @return 0 on success, -1 for a malformed list, RET_ERROR if a file cannot be read.
int fl_lenet_lite_SetInputs(const std::string &files, int num) {
  (void)num;
  if (files.empty()) {
    MS_LOG(ERROR) << "files empty";
    return -1;
  }
  // Split on ','; empty fields are kept, so "a,b," yields three parts.
  std::vector<std::string> res;
  size_t start = 0;
  while (true) {
    size_t comma = files.find(',', start);
    if (comma == std::string::npos) {
      res.push_back(files.substr(start));
      break;
    }
    res.push_back(files.substr(start, comma - start));
    start = comma + 1;
  }
  if (res.size() != 2) {
    MS_LOG(ERROR) << "res size not equal 2";
    return -1;
  }
  char *buffers[2] = {nullptr, nullptr};
  for (int i = 0; i < 2; i++) {
    size_t size = 0;
    buffers[i] = ReadFile(res[i].c_str(), &size);
    if (buffers[i] == nullptr) {
      MS_LOG(ERROR) << "ReadFile return nullptr";
      delete[] buffers[0];  // no-op when the first read failed
      return mindspore::lite::RET_ERROR;
    }
  }
  // Previous buffers came from ReadFile (new[]) and would otherwise leak on every call.
  delete[] fl_lenet_I0;
  delete[] fl_lenet_I1;
  fl_lenet_I0 = buffers[0];
  fl_lenet_I1 = buffers[1];
  return 0;
}

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.huawei.flclient;

import com.google.flatbuffers.FlatBufferBuilder;
import mindspore.schema.FeatureMap;
import java.util.ArrayList;

public class LiteTrain {
    static {
        System.loadLibrary("fl");
    }

    private static LiteTrain train;

    private LiteTrain() {
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the LiteTrain singleton
     */
    public static synchronized LiteTrain getInstance() {
        if (train == null) {
            train = new LiteTrain();
        }
        return train;
    }

    /**
     * Set the inference set or train set used by later {@link #train} / {@link #inference} calls.
     *
     * @param fileSet two binary file paths separated by a comma: the input data file (NHWC)
     *                followed by the label file
     * @param num     binary file batch num (currently not used by the native implementation)
     * @return status, 0 on success, non-zero on failure
     */
    native int setInput(String fileSet, int num);

    /**
     * Evaluate the model on the data set loaded by {@link #setInput}. The accuracy is only
     * written to the native log.
     *
     * @param modelName model file path
     * @param batch_num number of samples in the loaded data set
     * @param test_nums number of batches to evaluate
     * @return status code
     */
    public native int inference(String modelName, int batch_num, int test_nums);

    /**
     * Train the model on the data set loaded by {@link #setInput} and save it back to modelName.
     *
     * @param modelName  model file path; overwritten with the trained model
     * @param batch_num  number of samples in the loaded data set
     * @param iterations number of training steps
     * @return status code
     */
    public native int train(String modelName, int batch_num, int iterations);

    /**
     * Get the features map of the training model, serialized into {@code builder}.
     *
     * @param modelName model file path
     * @param builder   FlatBufferBuilder to serialize the FeatureMap tables into
     * @return one FeatureMap offset per feature; an empty array on failure
     */
    native int[] getFeaturesMap(String modelName, FlatBufferBuilder builder);

    /**
     * Update the features map of the training model.
     *
     * @param modelName   model file path to update
     * @param featureMaps new feature values
     * @return status code
     */
    native int updateFeatures(String modelName, ArrayList<FeatureMap> featureMaps);

    /**
     * Free Inference or Train runtime memory resource. The native implementation currently
     * releases nothing.
     *
     * @return status, always 0
     */
    native int free();
}
