<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit → Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From 698054adf09cec47b7208cc1da8f66c7e3c67d79 Mon Sep 17 00:00:00 2001
From: zhengjun10 <zhengjun10@huawei.com>
Date: Tue, 15 Jun 2021 15:57:11 +0800
Subject: [PATCH] add fl lite

---
 mindspore/lite/include/lite_session.h              | 15 ++++++++
 .../src/main/java/com/mindspore/lite/MSTensor.java |  6 ++++
 mindspore/lite/java/java/linux_x86/build.gradle    |  1 +
 .../main/java/com.mindspore.lite/LiteSession.java  | 21 +++++++++++
 .../lite/java/native/runtime/lite_session.cpp      | 42 ++++++++++++++++++++++
 mindspore/lite/java/native/runtime/ms_tensor.cpp   | 17 +++++++++
 mindspore/lite/src/train/train_session.cc          | 35 ++++++++++++++++++
 mindspore/lite/src/train/train_session.h           |  4 +++
 8 files changed, 141 insertions(+)

diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h
index 3c2942a..eaada07 100644
--- a/mindspore/lite/include/lite_session.h
+++ b/mindspore/lite/include/lite_session.h
@@ -210,6 +210,21 @@ class MS_API LiteSession {
                      lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS) {
     return mindspore::lite::RET_ERROR;
   }
+
+  /// \brief Get the feature map tensors of the model held by this session.
+  ///
+  /// \return Feature map tensors (MindSpore Lite MSTensor); this default implementation returns an empty vector.
+  virtual std::vector<tensor::MSTensor *> GetFeatureMaps() const { return {}; }
+
+  /// \brief Overwrite the data of the feature map tensors that have the same names as the given tensors.
+  ///
+  /// \param[in] features Tensors whose data replaces that of the same-named feature maps.
+  ///
+  /// \return STATUS as an error code of the update operation, STATUS is defined in errorcode.h.
+  ///         This default implementation does not support updating and always returns RET_ERROR.
+  virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) {
+    return mindspore::lite::RET_ERROR;
+  }
 };
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java
index 8384a63..0a21451 100644
--- a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java
+++ b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java
@@ -29,6 +29,27 @@ public class MSTensor {
         this.tensorPtr = tensorPtr;
     }
 
+    /**
+     * Creates a one-dimensional float32 tensor whose data is copied from a direct ByteBuffer.
+     *
+     * @param tensorName name of the tensor; feature maps are matched by this name when a session is updated.
+     * @param buffer     direct ByteBuffer holding the float32 data; its whole capacity is copied.
+     * @throws IllegalArgumentException if tensorName is null, or buffer is null or not a direct buffer.
+     * @throws IllegalStateException    if the native tensor could not be created.
+     */
+    public MSTensor(String tensorName, ByteBuffer buffer) {
+        if (tensorName == null) {
+            throw new IllegalArgumentException("tensorName must not be null");
+        }
+        if (buffer == null || !buffer.isDirect()) {
+            throw new IllegalArgumentException("buffer must be a non-null direct ByteBuffer");
+        }
+        this.tensorPtr = createTensor(tensorName, buffer);
+        if (this.tensorPtr == 0) {
+            throw new IllegalStateException("failed to create native tensor " + tensorName);
+        }
+    }
+
     public int[] getShape() {
         return this.getShape(this.tensorPtr);
     }
@@ -82,6 +103,8 @@ public class MSTensor {
         return tensorPtr;
     }
 
+    private native long createTensor(String tensorName, ByteBuffer buffer);
+
     private native int[] getShape(long tensorPtr);
 
     private native int getDataType(long tensorPtr);
diff --git a/mindspore/lite/java/java/linux_x86/build.gradle b/mindspore/lite/java/java/linux_x86/build.gradle
index 5d3c76b..a28c4cf 100644
--- a/mindspore/lite/java/java/linux_x86/build.gradle
+++ b/mindspore/lite/java/java/linux_x86/build.gradle
@@ -2,6 +2,7 @@ apply plugin: 'java'
 
 dependencies {
     implementation fileTree(dir: "libs", include: ["*.jar"])
+    implementation project(':common')
 }
 
 archivesBaseName = 'mindspore-lite-java'
diff --git a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java
index 96526b0..301e14c 100644
--- a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java
+++ b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java
@@ -183,6 +183,45 @@ public class LiteSession {
         return this.setupVirtualBatch(this.sessionPtr, virtualBatchMultiplier, learningRate, momentum);
     }
 
+    /**
+     * Get the feature map tensors of the session.
+     *
+     * @return MSTensor wrappers of the native feature map tensors; empty when the session exposes none.
+     */
+    public List<MSTensor> getFeaturesMap() {
+        List<Long> ret = this.getFeaturesMap(this.sessionPtr);
+        ArrayList<MSTensor> tensors = new ArrayList<MSTensor>();
+        if (ret == null) {
+            return tensors;
+        }
+        for (Long msTensorAddr : ret) {
+            tensors.add(new MSTensor(msTensorAddr));
+        }
+        return tensors;
+    }
+
+    /**
+     * Replace the data of the session's feature maps with that of the same-named tensors.
+     *
+     * @param features tensors carrying the new feature map data.
+     * @return true if the native session accepted every tensor; false if features is null, contains null,
+     *         or the native update failed.
+     */
+    public boolean updateFeatures(List<MSTensor> features) {
+        if (features == null) {
+            return false;
+        }
+        long[] inputsArray = new long[features.size()];
+        int i = 0;
+        for (MSTensor feature : features) {
+            if (feature == null) {
+                return false;
+            }
+            inputsArray[i++] = feature.getMSTensorPtr();
+        }
+        return this.updateFeatures(this.sessionPtr, inputsArray);
+    }
+
     private native long createSession(long msConfigPtr);
 
     private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
@@ -224,4 +263,8 @@ public class LiteSession {
     private native boolean setLearningRate(long sessionPtr, float learning_rate);
 
     private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
+
+    private native boolean updateFeatures(long sessionPtr, long[] newFeatures);
+
+    private native List<Long> getFeaturesMap(long sessionPtr);
 }
diff --git a/mindspore/lite/java/native/runtime/lite_session.cpp b/mindspore/lite/java/native/runtime/lite_session.cpp
index ae3bc74..d84453c 100644
--- a/mindspore/lite/java/native/runtime/lite_session.cpp
+++ b/mindspore/lite/java/native/runtime/lite_session.cpp
@@ -369,3 +369,69 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupV
   return (jboolean)(ret == mindspore::lite::RET_OK);
 }
 
+// Native side of LiteSession.updateFeatures(long sessionPtr, long[] newFeatures): copies the data of the given
+// native MSTensors into the same-named feature maps of the session. Returns JNI_FALSE on any invalid pointer or
+// when the session rejects the update.
+extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_updateFeatures(JNIEnv *env, jobject,
+                                                                                          jlong session_ptr,
+                                                                                          jlongArray features) {
+  auto *session = reinterpret_cast<mindspore::session::LiteSession *>(session_ptr);
+  if (session == nullptr) {
+    MS_LOGE("Session pointer from java is nullptr");
+    return JNI_FALSE;
+  }
+  if (features == nullptr) {
+    MS_LOGE("Feature array from java is null");
+    return JNI_FALSE;
+  }
+  jsize size = env->GetArrayLength(features);
+  jlong *input_data = env->GetLongArrayElements(features, nullptr);
+  if (input_data == nullptr) {
+    MS_LOGE("GetLongArrayElements return null");
+    return JNI_FALSE;
+  }
+  std::vector<mindspore::tensor::MSTensor *> new_features;
+  new_features.reserve(static_cast<size_t>(size));
+  bool all_valid = true;
+  for (jsize i = 0; i < size; ++i) {
+    auto *ms_tensor_ptr = reinterpret_cast<mindspore::tensor::MSTensor *>(input_data[i]);
+    if (ms_tensor_ptr == nullptr) {
+      MS_LOGE("Tensor pointer from java is nullptr");
+      all_valid = false;
+      break;
+    }
+    new_features.emplace_back(ms_tensor_ptr);
+  }
+  // The array was only read, so JNI_ABORT frees any copy without writing it back; this runs on every path.
+  env->ReleaseLongArrayElements(features, input_data, JNI_ABORT);
+  if (!all_valid) {
+    return JNI_FALSE;
+  }
+  auto ret = session->UpdateFeatureMaps(new_features);
+  return (jboolean)(ret == mindspore::lite::RET_OK);
+}
+
+// Native side of LiteSession.getFeaturesMap(long sessionPtr): returns a java.util.ArrayList of java.lang.Long
+// holding the addresses of the session's feature map tensors (an empty list when the session pointer is null).
+extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getFeaturesMap(JNIEnv *env, jobject thiz,
+                                                                                         jlong session_ptr) {
+  jclass array_list = env->FindClass("java/util/ArrayList");
+  jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
+  jobject ret = env->NewObject(array_list, array_list_construct);
+  jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
+
+  jclass long_object = env->FindClass("java/lang/Long");
+  jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
+  auto *session = reinterpret_cast<mindspore::session::LiteSession *>(session_ptr);
+  if (session == nullptr) {
+    MS_LOGE("Session pointer from java is nullptr");
+    return ret;
+  }
+  for (auto *feature : session->GetFeatureMaps()) {
+    jobject tensor_addr = env->NewObject(long_object, long_object_construct, reinterpret_cast<jlong>(feature));
+    env->CallBooleanMethod(ret, array_list_add, tensor_addr);
+    // Release each boxed Long immediately so many feature maps cannot overflow the local reference table.
+    env->DeleteLocalRef(tensor_addr);
+  }
+  return ret;
+}
diff --git a/mindspore/lite/java/native/runtime/ms_tensor.cpp b/mindspore/lite/java/native/runtime/ms_tensor.cpp
index 1c32245..6d1b8e1 100644
--- a/mindspore/lite/java/native/runtime/ms_tensor.cpp
+++ b/mindspore/lite/java/native/runtime/ms_tensor.cpp
@@ -246,3 +246,43 @@ extern "C" JNIEXPORT jstring JNICALL Java_com_mindspore_lite_MSTensor_tensorName
 
   return env->NewStringUTF(ms_tensor_ptr->tensor_name().c_str());
 }
+
+// Native side of MSTensor.createTensor(String tensorName, ByteBuffer buffer): creates a one-dimensional float32
+// MSTensor named tensor_name from a copy of the whole direct buffer. Returns the tensor address, or 0 on failure.
+extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createTensor(JNIEnv *env, jobject thiz,
+                                                                                 jstring tensor_name, jobject buffer) {
+  if (tensor_name == nullptr || buffer == nullptr) {
+    MS_LOGE("Tensor name or buffer from java is null");
+    return 0;
+  }
+  auto *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer));  // get buffer pointer
+  jlong data_len = env->GetDirectBufferCapacity(buffer);                          // get buffer capacity
+  if (p_data == nullptr) {
+    MS_LOGE("GetDirectBufferAddress return null");
+    return 0;
+  }
+  if (data_len < 0 || data_len % static_cast<jlong>(sizeof(float)) != 0) {
+    MS_LOGE("Buffer capacity must be a multiple of sizeof(float)");
+    return 0;
+  }
+  const char *name_chars = env->GetStringUTFChars(tensor_name, nullptr);
+  if (name_chars == nullptr) {
+    MS_LOGE("GetStringUTFChars return null");
+    return 0;
+  }
+  char *tensor_data = new char[data_len];
+  memcpy(tensor_data, p_data, data_len);
+  std::vector<int> shape = {static_cast<int>(data_len / static_cast<jlong>(sizeof(float)))};
+  // CreateTensor takes the data length in bytes (per its API docs; confirm against ms_tensor.h), not elements.
+  auto *tensor = mindspore::tensor::MSTensor::CreateTensor(name_chars, mindspore::kNumberTypeFloat32, shape,
+                                                           tensor_data, static_cast<size_t>(data_len));
+  env->ReleaseStringUTFChars(tensor_name, name_chars);
+  if (tensor == nullptr) {
+    MS_LOGE("CreateTensor failed");
+    delete[] tensor_data;
+    return 0;
+  }
+  // NOTE(review): tensor_data is not freed here on success, so the tensor is assumed to own it; confirm it is
+  // released with a deallocator that matches new[].
+  return reinterpret_cast<jlong>(tensor);
+}
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index b4c947f..e0a4042 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -707,6 +707,68 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
   if (orig_train_state) Train();
   return status;
 }
+
+/// \brief Collect the constant float32 tensors of the session, which are exposed as its feature maps.
+/// \return Pointers to the session's own tensors (not copies), in the order they appear in tensors_.
+std::vector<tensor::MSTensor *> TrainSession::GetFeatureMaps() const {
+  std::vector<tensor::MSTensor *> features;
+  for (auto cur_tensor : this->tensors_) {
+    if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
+      features.push_back(cur_tensor);
+    }
+  }
+  return features;
+}
+
+/// \brief Overwrite the data of every constant float32 tensor whose name matches an entry of features_map.
+/// \param[in] features_map Tensors whose names select the feature maps to update and whose data is copied in.
+/// \return RET_OK on success; RET_ERROR if an entry is null, has no matching feature map, differs in Size()
+///         or has no data. All entries are validated before any copy, so a failed call leaves the model unchanged.
+int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) {
+  struct PendingCopy {
+    void *dst;
+    const void *src;
+    size_t len;
+  };
+  std::vector<PendingCopy> pending;
+  for (auto feature : features_map) {
+    if (feature == nullptr) {
+      MS_LOG(ERROR) << "feature is nullptr, update failed";
+      return RET_ERROR;
+    }
+    bool found = false;
+    for (auto tensor : tensors_) {
+      if (!tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
+        continue;
+      }
+      if (feature->tensor_name() != tensor->tensor_name()) {
+        continue;
+      }
+      if (feature->Size() != tensor->Size()) {
+        MS_LOG(ERROR) << "feature name:" << feature->tensor_name() << ", len diff, old is:" << tensor->Size()
+                      << ", new is:" << feature->Size();
+        return RET_ERROR;
+      }
+      if (tensor->Size() != 0 && (tensor->data() == nullptr || feature->data() == nullptr)) {
+        MS_LOG(ERROR) << "feature name:" << feature->tensor_name() << ", data is nullptr, update failed";
+        return RET_ERROR;
+      }
+      found = true;
+      pending.push_back({tensor->data(), feature->data(), tensor->Size()});
+    }
+    if (!found) {
+      MS_LOG(ERROR) << "cannot find feature:" << feature->tensor_name() << ",update failed";
+      return RET_ERROR;
+    }
+  }
+  for (const auto &entry : pending) {
+    // A tensor obtained from GetFeatureMaps() may be passed back unchanged; memcpy onto itself is undefined.
+    if (entry.len != 0 && entry.dst != entry.src) {
+      memcpy(entry.dst, entry.src, entry.len);
+    }
+  }
+  return RET_OK;
+}
 }  // namespace lite
 
 session::LiteSession *session::LiteSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index 53de3d7..e712726 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -91,6 +91,10 @@ class TrainSession : virtual public lite::LiteSession {
   }
   int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType) override;
 
+  std::vector<tensor::MSTensor *> GetFeatureMaps() const override;  // constant float32 tensors of the session
+  // Copies data into the same-named constant float32 tensors; returns RET_ERROR on a miss or size mismatch.
+  int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) override;
+
  protected:
   int AllocWorkSpace();
   bool IsLossKernel(const kernel::LiteKernel *kernel) const;
-- 
2.7.4

