<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From bf0a70e426ec66139dbb82a511af50e3fbef5350 Mon Sep 17 00:00:00 2001
From: zhengjun10 <zhengjun10@huawei.com>
Date: Wed, 7 Jul 2021 17:18:33 +0800
Subject: [PATCH] fix unify lite session train and inference

---
 .../lite/examples/train_lenet/src/net_runner.cc    |  3 +-
 .../examples/transfer_learning/src/net_runner.cc   |  3 +-
 mindspore/lite/include/lite_session.h              | 23 ----------
 mindspore/lite/include/train/train_session.h       | 49 ++++++++++++++++++++++
 mindspore/lite/src/train/train_session.cc          |  2 +-
 mindspore/lite/src/train/train_session.h           |  1 +
 mindspore/lite/src/train/transfer_session.cc       |  2 +-
 .../runtime/kernel/arm/fp32_grad/network_test.cc   |  7 ++--
 mindspore/lite/tools/benchmark_train/net_train.cc  |  5 ++-
 9 files changed, 63 insertions(+), 32 deletions(-)
 create mode 100644 mindspore/lite/include/train/train_session.h

diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc
index 4143c72..6cbfa8b 100644
--- a/mindspore/lite/examples/train_lenet/src/net_runner.cc
+++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc
@@ -29,6 +29,7 @@
 #include "include/train/ckpt_saver.h"
 #include "include/train/lr_scheduler.h"
 #include "include/train/accuracy_metrics.h"
+#include "include/train/train_session.h"
 #include "include/train/classification_train_accuracy_monitor.h"
 #include "src/utils.h"
 #include "include/dataset/datasets.h"
@@ -142,7 +143,7 @@ void NetRunner::InitAndFigureInputs() {
   context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
   context.thread_num_ = 2;
 
-  session_ = mindspore::session::LiteSession::CreateTrainSession(ms_file_, &context, true);
+  session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true);
   MS_ASSERT(session_ != nullptr);
 
   session_->SetupVirtualBatch(virtual_batch_);
diff --git a/mindspore/lite/examples/transfer_learning/src/net_runner.cc b/mindspore/lite/examples/transfer_learning/src/net_runner.cc
index 3e5e87b..3cc4086 100644
--- a/mindspore/lite/examples/transfer_learning/src/net_runner.cc
+++ b/mindspore/lite/examples/transfer_learning/src/net_runner.cc
@@ -25,6 +25,7 @@
 #include <iostream>
 #include "include/context.h"
 #include "include/lite_session.h"
+#include "include/train/train_session.h"
 #include "src/utils.h"
 
 static unsigned int seed = time(NULL);
@@ -77,7 +78,7 @@ void NetRunner::InitAndFigureInputs() {
   context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = enable_fp16_;
   context.thread_num_ = 1;
 
-  session_ = mindspore::session::LiteSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context);
+  session_ = mindspore::session::TrainSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context);
   MS_ASSERT(session_ != nullptr);
 
   auto inputs = session_->GetInputs();
diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h
index 4733107..6827a13 100644
--- a/mindspore/lite/include/lite_session.h
+++ b/mindspore/lite/include/lite_session.h
@@ -128,29 +128,6 @@ class MS_API LiteSession {
   /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
   virtual int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) = 0;
 
-  /// \brief Static method to create a TrainSession object
-  ///
-  /// \param[in] filename name of flatbuffer that holds the flatbuffer
-  /// \param[in] context Defines the context of the session to be created
-  /// \param[in] train_mode training mode to initialize Session with
-  /// \param[in] cfg training configuration, set to null for default configuration
-  ///
-  /// \return Pointer of MindSpore LiteSession
-  static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
-                                         bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
-
-  /// \brief Static method to create a TransferSession object
-  ///
-  /// \param[in] filename_backbone Filename to read backbone net flatbuffer from
-  /// \param[in] filename_head Filename to read head net flatbuffer from
-  /// \param[in] context Defines the context of the session to be created
-  /// \param[in] train_mode training mode to initialize Session with
-  ///
-  /// \return Pointer of MindSpore LiteSession
-  static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
-                                            const lite::Context *context, bool train_mode = false,
-                                            const lite::TrainCfg *cfg = nullptr);
-
   /// \brief Set model to train mode
   /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
   virtual int Train() { return mindspore::lite::RET_ERROR; }
diff --git a/mindspore/lite/include/train/train_session.h b/mindspore/lite/include/train/train_session.h
new file mode 100644
index 0000000..44534a5
--- /dev/null
+++ b/mindspore/lite/include/train/train_session.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
+#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
+#include "include/lite_session.h"
+
+namespace mindspore {
+namespace session {
+class MS_API TrainSession {  // MS_API, like LiteSession: these factories were moved here from that class
+ public:
+  /// \brief Static method to create a TransferSession object from a backbone model and a head model
+  ///
+  /// \param[in] filename_backbone Filename to read backbone net flatbuffer from
+  /// \param[in] filename_head Filename to read head net flatbuffer from
+  /// \param[in] context Defines the context of the session to be created
+  /// \param[in] train_mode training mode to initialize Session with
+  /// \param[in] cfg training configuration, set to null for default configuration
+  /// \return Pointer of MindSpore LiteSession, owned by the caller; check it against nullptr before use
+  static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
+                                            const lite::Context *context, bool train_mode = false,
+                                            const lite::TrainCfg *cfg = nullptr);
+
+  /// \brief Static method to create a TrainSession object
+  ///
+  /// \param[in] filename Filename to read the model flatbuffer from
+  /// \param[in] context Defines the context of the session to be created
+  /// \param[in] train_mode training mode to initialize Session with
+  /// \param[in] cfg training configuration, set to null for default configuration
+  ///
+  /// \return Pointer of MindSpore LiteSession, owned by the caller; check it against nullptr before use
+  static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
+                                         bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
+};
+}  // namespace session
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index 9adafd6..014d8a2 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -747,7 +747,7 @@ int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &featu
 }
 }  // namespace lite
 
-session::LiteSession *session::LiteSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
+session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
                                                                bool train_mode, const lite::TrainCfg *cfg) {
   auto session = std::make_unique<lite::TrainSession>();
   if (session == nullptr) {
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index 929ad78..2f8f2a1 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -22,6 +22,7 @@
 #include <memory>
 #include <map>
 #include "include/train/train_cfg.h"
+#include "include/train/train_session.h"
 #include "src/lite_session.h"
 
 /*
diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc
index b132831..ecaf2ac 100644
--- a/mindspore/lite/src/train/transfer_session.cc
+++ b/mindspore/lite/src/train/transfer_session.cc
@@ -290,7 +290,7 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back
   return session;
 }
 
-session::LiteSession *session::LiteSession::CreateTransferSession(const std::string &filename_backbone,
+session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
                                                                   const std::string &filename_head,
                                                                   const lite::Context *ctxt, bool train_mode,
                                                                   const lite::TrainCfg *cfg) {
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc
index f2a209a..89de284 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc
@@ -28,6 +28,7 @@
 #include "include/context.h"
 #include "include/errorcode.h"
 #include "include/train/train_cfg.h"
+#include "include/train/train_session.h"
 #include "src/common/log_adapter.h"
 #include "src/common/file_utils.h"
 #include "src/kernel_registry.h"
@@ -102,7 +103,7 @@ TEST_F(NetworkTest, efficient_net) {
   context->thread_num_ = 1;
 
   std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
-  auto session = session::LiteSession::CreateTrainSession(net, context, false);
+  auto session = session::TrainSession::CreateTrainSession(net, context, false);
   ASSERT_NE(session, nullptr);
 
   std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin";
@@ -150,7 +151,7 @@ TEST_F(NetworkTest, noname) {
 
   lite::TrainCfg cfg;
   cfg.loss_name_ = "nhwc";
-  auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &cfg);
+  auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &cfg);
   ASSERT_NE(session, nullptr);
   auto tensors_map = session->GetOutputs();
   auto tensor_names = session->GetOutputTensorNames();
@@ -170,7 +171,7 @@ TEST_F(NetworkTest, setname) {
   lite::TrainCfg train_cfg;
   train_cfg.loss_name_ = "nhwc";
 
-  auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &train_cfg);
+  auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg);
   ASSERT_NE(session, nullptr);
 
   auto tensors_map = session->GetOutputs();
diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc
index 39578c5..4f7c2b4 100644
--- a/mindspore/lite/tools/benchmark_train/net_train.cc
+++ b/mindspore/lite/tools/benchmark_train/net_train.cc
@@ -30,6 +30,7 @@
 #include "include/version.h"
 #include "include/model.h"
 #include "include/train/train_cfg.h"
+#include "include/train/train_session.h"
 
 namespace mindspore {
 namespace lite {
@@ -338,7 +339,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
       MS_LOG(INFO) << "CreateTransferSession from models files" << filename << " and " << bb_filename;
       std::cout << "CreateTranferSession from model file " << filename << " and " << bb_filename << std::endl;
       session = std::unique_ptr<session::LiteSession>(
-        session::LiteSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
+        session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
       if (session == nullptr) {
         MS_LOG(ERROR) << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str();
         std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl;
@@ -349,7 +350,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
       MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str();
       std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
       session = std::unique_ptr<session::LiteSession>(
-        session::LiteSession::CreateTrainSession(filename, &context, true, &train_cfg));
+        session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
       if (session == nullptr) {
         MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str();
         std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl;
-- 
2.7.4

