<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * JNI bridge for TrainSession.getFeaturesMap(long).
 *
 * Builds a java.util.ArrayList of java.lang.Long, one entry per tensor returned by
 * TrainSession::GetFeatureMaps(), each holding that tensor's native address. The Java caller
 * wraps every address in an MSTensor.
 *
 * @param env         JNI environment of the calling thread.
 * @param thiz        The Java TrainSession instance (unused).
 * @param session_ptr Native TrainSession address held by the Java object.
 * @return The list; an empty list when session_ptr is 0; nullptr if a JDK class/method lookup
 *         or the list allocation fails (the JVM leaves the corresponding exception pending).
 */
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getFeaturesMap(JNIEnv *env, jobject thiz,
                                                                                    jlong session_ptr) {
  // Each lookup is checked immediately: once an exception is pending, further JNI calls
  // (other than exception-safe ones) are not allowed.
  jclass array_list = env->FindClass("java/util/ArrayList");
  if (array_list == nullptr) {
    return nullptr;
  }
  jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
  if (array_list_construct == nullptr) {
    return nullptr;
  }
  jobject ret = env->NewObject(array_list, array_list_construct);
  if (ret == nullptr) {
    return nullptr;
  }
  jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
  if (array_list_add == nullptr) {
    return nullptr;
  }

  jclass long_object = env->FindClass("java/lang/Long");
  if (long_object == nullptr) {
    return nullptr;
  }
  jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
  if (long_object_construct == nullptr) {
    return nullptr;
  }

  auto *pointer = reinterpret_cast<void *>(session_ptr);
  if (pointer == nullptr) {
    MS_LOGE("Session pointer from java is nullptr");
    return ret;
  }
  auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
  auto features = train_session_ptr->GetFeatureMaps();
  for (auto *feature : features) {
    jobject tensor_addr = env->NewObject(long_object, long_object_construct, reinterpret_cast<jlong>(feature));
    env->CallBooleanMethod(ret, array_list_add, tensor_addr);
    // One local reference is created per tensor; release it now so large models cannot
    // exhaust the JNI local-reference table before this native method returns.
    env->DeleteLocalRef(tensor_addr);
  }
  return ret;
}

In [None]:
/**
 * JNI bridge for TrainSession.updateFeatures(long, String, long[]).
 *
 * Resolves each element of `features` as a native MSTensor address, keys the tensors by
 * tensor_name(), and forwards them to TrainSession::UpdateFeatureMaps together with the
 * target model path.
 *
 * @param env         JNI environment of the calling thread.
 * @param session_ptr Native TrainSession address held by the Java object.
 * @param model_path  File the updated model is written to.
 * @param features    Native MSTensor addresses (from MSTensor.getMSTensorPtr()).
 * @return JNI true iff UpdateFeatureMaps returned RET_OK; false on a null session or tensor
 *         pointer, or if the Java array cannot be accessed.
 */
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_updateFeatures(JNIEnv *env, jclass,
                                                                                      jlong session_ptr,
                                                                                      jstring model_path,
                                                                                      jlongArray features) {
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  if (session == nullptr) {
    MS_LOGE("Session pointer from java is nullptr");
    return static_cast<jboolean>(false);
  }

  const jsize size = env->GetArrayLength(features);
  jlong *input_data = env->GetLongArrayElements(features, nullptr);
  if (input_data == nullptr) {
    MS_LOGE("Get features array elements failed");
    return static_cast<jboolean>(false);
  }

  std::unordered_map<std::string, mindspore::tensor::MSTensor *> new_features;
  for (jsize i = 0; i < size; ++i) {
    auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
    if (tensor_pointer == nullptr) {
      MS_LOGE("Tensor pointer from java is nullptr");
      // Every exit path must release the elements obtained above.
      env->ReleaseLongArrayElements(features, input_data, JNI_ABORT);
      return static_cast<jboolean>(false);
    }
    auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer);
    new_features.emplace(ms_tensor_ptr->tensor_name(), ms_tensor_ptr);
  }
  // The array was only read, so JNI_ABORT frees/unpins it without copying anything back.
  env->ReleaseLongArrayElements(features, input_data, JNI_ABORT);

  // NOTE(review): ownership of the buffer returned by JstringToChar is not visible here; if it
  // is heap-allocated it is leaked, and if it can be null the std::string conversion is
  // undefined — confirm against its definition.
  const auto ret = session->UpdateFeatureMaps(JstringToChar(env, model_path), new_features);
  return static_cast<jboolean>(ret == mindspore::lite::RET_OK);
}

In [None]:
std::vector<tensor::MSTensor*> TrainSession::GetFeatureMaps() {
  std::vector<tensor::MSTensor*> features;
  for (auto cur_tensor : this->tensors_) {
    if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
      features.push_back(cur_tensor);
    }
  }
  return features;
}
int TrainSession::UpdateFeatureMaps(const std::string &update_ms_file,
                                    std::unordered_map<std::string,tensor::MSTensor*> features_map) {
  std::set<std::string> need_update_features;
  std::transform(
    features_map.begin(),
    features_map.end(),
    std::inserter(need_update_features,need_update_features.begin()),
    [](const std::unordered_map<std::string,tensor::MSTensor*>::value_type &pair){return pair.first;});
   for(auto tensor:tensors_) {
    if(!tensor->IsConst() || tensor->data_type() !=  kNumberTypeFloat32) {
      continue;
    }
    auto feature_name = tensor->tensor_name();
    if(features_map.find(feature_name) != features_map.end()) {
      auto new_feature = features_map[feature_name];
      if(tensor->ElementsNum() != new_feature->ElementsNum()) {
        MS_LOG(ERROR) << "feature name:" << feature_name << ",len diff:"
                      << "old is:" << tensor->ElementsNum() << "new is:" << new_feature->ElementsNum();
        return RET_ERROR;
      }
      memcpy(tensor->data(), new_feature->data(), new_feature->ElementsNum() * sizeof(float));
      need_update_features.erase(feature_name);
    }
  }
  if(!need_update_features.empty()) {
    for(auto it=need_update_features.begin(); it!=need_update_features.end(); ++it) {
      MS_LOG(ERROR) << "cannot find feature:" << *it << ",update failed";
    }
    return RET_ERROR;
  }
  SaveToFile(update_ms_file);
  return RET_OK;
}

In [None]:
diff --git a/build.sh b/build.sh
index e3d4530..53757c0 100755
--- a/build.sh
+++ b/build.sh
@@ -670,9 +670,20 @@ build_lite_java_arm64() {
     tar -zxvf ${JTARBALL}.tar.gz
     [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/
     mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/
-    cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
     mkdir -p ${JAVA_PATH}/native/libs/arm64-v8a/
-    cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+     if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/app/libs/arm64-v8a/
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/arm64-v8a/
+    else
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+    fi
     [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
 }
 
@@ -691,15 +702,31 @@ build_lite_java_arm32() {
     tar -zxvf ${JTARBALL}.tar.gz
     [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/
     mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/
-    cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
     mkdir -p ${JAVA_PATH}/native/libs/armeabi-v7a/
-    cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+
+    if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/armeabi-v7a/
+    else
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+    fi
     [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
 }
 
 build_lite_java_x86() {
     # build mindspore-lite x86
     local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64
+    if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+      JTARBALL=mindspore-lite-${VERSION_STR}-train-linux-x64
+    fi
+
     if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then
       build_lite "x86_64" "off" ""
     fi
@@ -709,9 +736,21 @@ build_lite_java_x86() {
     tar -zxvf ${JTARBALL}.tar.gz
     [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/
     mkdir -p ${JAVA_PATH}/java/linux_x86/libs/
-    cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
     mkdir -p ${JAVA_PATH}/native/libs/linux_x86/
-    cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
+     if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
+
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/
+
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/linux_x86/libs/
+        cp ${BASEPATH}/output/${JTARBALL}/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/linux_x86/
+    else
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
+        cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
+    fi
+    [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
 }
 
 build_jni_arm64() {
@@ -794,7 +833,7 @@ build_java() {
     gradle clean
     gradle build
 
-    # build aar
+#    build aar
     build_lite_java_arm64
     build_jni_arm64
     build_lite_java_arm32
@@ -813,6 +852,18 @@ build_java() {
     # copy output
     cp mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output/
 
+     local inference_or_train=inference
+    if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+      inference_or_train=train
+    fi
+
+    # build linux x86 jar
+    if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
+          local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}-jar
+    else
+          local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-jar
+    fi
+
     # build linux x86 jar
     check_java_home
     build_lite_java_x86
@@ -827,11 +878,16 @@ build_java() {
     # install and package
     mkdir -p ${JAVA_PATH}/java/linux_x86/build/lib
     cp ${JAVA_PATH}/java/linux_x86/libs/*.so ${JAVA_PATH}/java/linux_x86/build/lib/jar
-    cp -r ${JAVA_PATH}/java/linux_x86/build/lib/jar ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-linux-x64/lib/
+     cd ${JAVA_PATH}/java/linux_x86/build/
+    cp -r ${JAVA_PATH}/java/linux_x86/build/lib ${JAVA_PATH}/java/linux_x86/build/${LINUX_X86_PACKAGE_NAME}
+    tar czvf ${LINUX_X86_PACKAGE_NAME}.tar.gz ${LINUX_X86_PACKAGE_NAME}
+
+    cp ${JAVA_PATH}/java/app/build/mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output
+    cp ${LINUX_X86_PACKAGE_NAME}.tar.gz ${BASEPATH}/output
+    cp -r ${JAVA_PATH}/java/linux_x86/build/lib/jar ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64/lib/
     cd ${BASEPATH}/output
-    tar czf mindspore-lite-${VERSION_STR}-inference-linux-x64.tar.gz mindspore-lite-${VERSION_STR}-inference-linux-x64
     # copy output
-    [ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-linux-x64
+    [ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
     exit 0
 }
 


In [None]:
diff --git a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java
index 8384a63..06e2256 100644
--- a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java
+++ b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java
@@ -28,7 +28,9 @@ public class MSTensor {
     public MSTensor(long tensorPtr) {
         this.tensorPtr = tensorPtr;
     }
-
+    public MSTensor(String tensorName,ByteBuffer buffer) {
+      this.tensorPtr = createTensor(tensorName,buffer);
+    }
     public int[] getShape() {
         return this.getShape(this.tensorPtr);
     }
@@ -81,7 +83,7 @@ public class MSTensor {
     protected long getMSTensorPtr() {
         return tensorPtr;
     }
-
+    public static native long createTensor(String tensorName,ByteBuffer buffer);
     private native int[] getShape(long tensorPtr);
 
     private native int getDataType(long tensorPtr);
diff --git a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java
index f3a8d94..d29113e 100644
--- a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java
+++ b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java
@@ -153,8 +153,24 @@ public class TrainSession {
     public boolean setLossName(String lossName) {
         return this.setLossName(this.sessionPtr,lossName);
     }
-    
-    
+    public List<MSTensor> getFeaturesMap() {
+     List<Long> ret = this.getFeaturesMap(this.sessionPtr);
+            ArrayList<MSTensor> tensors = new ArrayList<MSTensor>();
+            for (Long msTensorAddr : ret) {
+                MSTensor msTensor = new MSTensor(msTensorAddr);
+                tensors.add(msTensor);
+            }
+            return tensors;
+    }
+    public boolean updateFeatures(String modelFilename,List<MSTensor> features) {
+        long[] inputsArray = new long[features.size()];
+        for (int i = 0; i < features.size(); i++) {
+            inputsArray[i] = features.get(i).getMSTensorPtr();
+        }
+         return this.updateFeatures(this.sessionPtr,modelFilename, inputsArray);
+    }
+    private native boolean updateFeatures(long sessionPtr,String modelFilename, long[] newFeatures);
+    private native List<Long> getFeaturesMap(long sessionPtr);
     private native long createSession(String modelFilename, long msConfigPtr);
 
     private native void bindThread(long sessionPtr, boolean if_bind);
