<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for an educational purpose.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down.

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From a7f4e9ec35ff3e7aaf9c48383dbc9aaa831a706e Mon Sep 17 00:00:00 2001
From: zhengjun10 <zhengjun10@huawei.com>
Date: Thu, 15 Jul 2021 10:22:04 +0800
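Subject: [PATCH] fix java inference bug

Move LiteSession.createTrainSession() into a new TrainSession class whose
native method is built into a separate mindspore-lite-train-jni library, so
mindspore-lite-jni (inference) no longer links the training libraries.
Callers must switch to TrainSession.createTrainSession().
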
---
 .../main/java/com/mindspore/lite/LiteSession.java  | 16 +++------
 .../main/java/com/mindspore/lite/TrainSession.java | 39 ++++++++++++++++++++++
 .../main/java/com.mindspore.lite/LiteSession.java  | 16 +++------
 .../main/java/com.mindspore.lite/TrainSession.java | 39 ++++++++++++++++++++++
 mindspore/lite/java/native/CMakeLists.txt          | 22 +++++-------
 5 files changed, 95 insertions(+), 37 deletions(-)
 create mode 100644 mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java
 create mode 100644 mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java

diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java
index 331da09..7088b77 100644
--- a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java
+++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java
@@ -63,20 +63,21 @@ public class LiteSession {
         }
     }
 
-    public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
-        LiteSession liteSession = new LiteSession();
-        liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
-        if (liteSession.sessionPtr == 0) {
-            return null;
-        } else {
-            return liteSession;
-        }
-    }
-
     public long getSessionPtr() {
         return sessionPtr;
     }
 
+    /**
+     * Attaches a native session handle created outside this class, e.g. by TrainSession.
+     * Package-private so only code in com.mindspore.lite can replace the handle; any handle
+     * held before is overwritten, not freed.
+     *
+     * @param sessionPtr native session handle to wrap
+     */
+    void setSessionPtr(long sessionPtr) {
+        this.sessionPtr = sessionPtr;
+    }
+
     public void bindThread(boolean ifBind) {
         this.bindThread(this.sessionPtr, ifBind);
     }
@@ -204,8 +205,6 @@ public class LiteSession {
 
     private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
 
-    private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
-
     private native boolean compileGraph(long sessionPtr, long modelPtr);
 
     private native void bindThread(long sessionPtr, boolean ifBind);
diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java
new file mode 100644
index 0000000..13af2b4
--- /dev/null
+++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mindspore.lite;
+
+import com.mindspore.lite.config.MSConfig;
+
+/**
+ * Creates {@link LiteSession} instances backed by native training sessions.
+ *
+ * <p>The training JNI bindings live in the separate native library mindspore-lite-train-jni,
+ * loaded when this class is initialized; code that never uses TrainSession does not load it.
+ */
+public final class TrainSession {
+    static {
+        System.loadLibrary("mindspore-lite-train-jni");
+    }
+
+    private TrainSession() {
+        // Holder for static factory methods; not instantiable.
+    }
+
+    /**
+     * Creates a {@link LiteSession} backed by a new native training session.
+     *
+     * @param modelName model file name, forwarded as the native {@code fileName} argument
+     * @param config    context configuration whose native pointer is forwarded as {@code msContextPtr}
+     * @param trainMode forwarded unchanged as the native {@code trainMode} flag
+     * @return the wrapping LiteSession, or {@code null} if the native layer returned 0
+     */
+    public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
+        LiteSession liteSession = new LiteSession();
+        // NOTE(review): msTrainCfgPtr is always 0 here; presumably that selects the native
+        // default train config -- confirm in runtime/train_session.cpp.
+        long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
+        if (sessionPtr == 0) {
+            return null;
+        }
+        liteSession.setSessionPtr(sessionPtr);
+        return liteSession;
+    }
+
+    // NOTE(review): this native method moved here from LiteSession, but this patch does not touch
+    // runtime/train_session.cpp; confirm it now exports
+    // Java_com_mindspore_lite_TrainSession_createTrainSession (static, jclass), otherwise the
+    // call above throws UnsatisfiedLinkError.
+    private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
+                                                  long msTrainCfgPtr);
+}
diff --git a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java
index a49e940..859c021 100644
--- a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java
+++ b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java
@@ -63,20 +63,21 @@ public class LiteSession {
         }
     }
 
-    public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
-        LiteSession liteSession = new LiteSession();
-        liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
-        if (liteSession.sessionPtr == 0) {
-            return null;
-        } else {
-            return liteSession;
-        }
-    }
-
     public long getSessionPtr() {
         return sessionPtr;
     }
 
+    /**
+     * Attaches a native session handle created outside this class, e.g. by TrainSession.
+     * Package-private so only code in com.mindspore.lite can replace the handle; any handle
+     * held before is overwritten, not freed.
+     *
+     * @param sessionPtr native session handle to wrap
+     */
+    void setSessionPtr(long sessionPtr) {
+        this.sessionPtr = sessionPtr;
+    }
+
     public void bindThread(boolean ifBind) {
         this.bindThread(this.sessionPtr, ifBind);
     }
@@ -204,8 +205,6 @@ public class LiteSession {
 
     private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
 
-    private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
-
     private native boolean compileGraph(long sessionPtr, long modelPtr);
 
     private native void bindThread(long sessionPtr, boolean ifBind);
diff --git a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java
new file mode 100644
index 0000000..13af2b4
--- /dev/null
+++ b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mindspore.lite;
+
+import com.mindspore.lite.config.MSConfig;
+
+/**
+ * Creates {@link LiteSession} instances backed by native training sessions.
+ *
+ * <p>The training JNI bindings live in the separate native library mindspore-lite-train-jni,
+ * loaded when this class is initialized; code that never uses TrainSession does not load it.
+ */
+public final class TrainSession {
+    static {
+        System.loadLibrary("mindspore-lite-train-jni");
+    }
+
+    private TrainSession() {
+        // Holder for static factory methods; not instantiable.
+    }
+
+    /**
+     * Creates a {@link LiteSession} backed by a new native training session.
+     *
+     * @param modelName model file name, forwarded as the native {@code fileName} argument
+     * @param config    context configuration whose native pointer is forwarded as {@code msContextPtr}
+     * @param trainMode forwarded unchanged as the native {@code trainMode} flag
+     * @return the wrapping LiteSession, or {@code null} if the native layer returned 0
+     */
+    public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
+        LiteSession liteSession = new LiteSession();
+        // NOTE(review): msTrainCfgPtr is always 0 here; presumably that selects the native
+        // default train config -- confirm in runtime/train_session.cpp.
+        long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
+        if (sessionPtr == 0) {
+            return null;
+        }
+        liteSession.setSessionPtr(sessionPtr);
+        return liteSession;
+    }
+
+    // NOTE(review): this native method moved here from LiteSession, but this patch does not touch
+    // runtime/train_session.cpp; confirm it now exports
+    // Java_com_mindspore_lite_TrainSession_createTrainSession (static, jclass), otherwise the
+    // call above throws UnsatisfiedLinkError.
+    private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
+                                                  long msTrainCfgPtr);
+}
diff --git a/mindspore/lite/java/native/CMakeLists.txt b/mindspore/lite/java/native/CMakeLists.txt
index 1696c0b..f2b3990 100644
--- a/mindspore/lite/java/native/CMakeLists.txt
+++ b/mindspore/lite/java/native/CMakeLists.txt
@@ -92,12 +92,6 @@ set(JNI_SRC
 
 set(LITE_SO_NAME mindspore-lite)
 
-if(SUPPORT_TRAIN)
-  set(JNI_SRC
-          ${JNI_SRC}
-          ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
-  )
-endif()
 add_library(mindspore-lite-jni SHARED ${JNI_SRC})
 
 if(PLATFORM_ARM64 OR PLATFORM_ARM32)
@@ -108,13 +102,17 @@ else()
 endif()
 
 if(SUPPORT_TRAIN)
-  set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
-  if(PLATFORM_ARM64 OR PLATFORM_ARM32)
-    find_library(log-lib log)
-    target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
-  else()
-    target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME})
-  endif()
+  # Training JNI bindings are built into their own library, loaded only by TrainSession.java,
+  # so mindspore-lite-jni (inference) no longer links the training libraries.
+  set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
+  set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
+  add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
+  if(PLATFORM_ARM64 OR PLATFORM_ARM32)
+    find_library(log-lib log)
+    target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
+  else()
+    target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
+  endif()
 endif()
 
 set(NDK_STRIP
-- 
2.7.4

