<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product, but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From 1625b35d55d1a7f5f69edb1a060f8372913b03e1 Mon Sep 17 00:00:00 2001
From: guohongzilong <guohongzilong@huawei.com>
Date: Fri, 12 Mar 2021 11:40:35 +0800
Subject: [PATCH] sync lite code

---
 .../lite/flclient/src/main/native/CMakeLists.txt   | 10 ++--
 .../flclient/src/main/native/src/lenet_train.cpp   |  3 +-
 mindspore/lite/include/train_session.h             | 12 +++++
 mindspore/lite/schema/ops.fbs                      | 17 ++++---
 mindspore/lite/src/train/train_session.cc          | 58 ++++++++++++++++++++++
 mindspore/lite/src/train/train_session.h           |  4 ++
 6 files changed, 89 insertions(+), 15 deletions(-)

diff --git a/mindspore/lite/flclient/src/main/native/CMakeLists.txt b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
index 81a9ff9..1ee548e 100644
--- a/mindspore/lite/flclient/src/main/native/CMakeLists.txt
+++ b/mindspore/lite/flclient/src/main/native/CMakeLists.txt
@@ -12,12 +12,10 @@ set(LITE_DIR ${TOP_DIR}/mindspore/lite)
 set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR})
 set(MS_VERSION_MINOR ${MS_VERSION_MINOR})
 set(MS_VERSION_REVISION ${MS_VERSION_REVISION})
-set(CMAKE_C_FLAGS
-    "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} "
-    "-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
-set(CMAKE_CXX_FLAGS
-      "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} "
-      "-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
+  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")  # trailing '\' continues the quoted string, so flags stay one string
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
+  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")  # same single-string form as CMAKE_C_FLAGS
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp b/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
index 4455efb..86db7f3 100644
--- a/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
+++ b/mindspore/lite/flclient/src/main/native/src/lenet_train.cpp
@@ -21,6 +21,7 @@
 #include "include/context.h"
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
+#include "limits.h"
 
 static char *fl_lenet_I0 = 0;
 static char *fl_lenet_I1 = 0;
@@ -42,7 +43,7 @@ std::vector<int> FillInputData(mindspore::session::TrainSession *train_session,
     } else {
       idx = rand_r(&seed_) % batch_num;
     }
-    std::memcpy(input_data + i * data_size, fl_lenet_I0 + idx * data_size, data_size);
+    std::memcpy(input_data + i * data_size, (float*)fl_lenet_I0 + idx * data_size, data_size*sizeof(float));
     int label_idx = *(reinterpret_cast<int *>(fl_lenet_I1) + idx);
     labels[i * num_classes + label_idx] = 1.0;  // Model expects labels in onehot representation
     labels_vec.push_back(label_idx);
diff --git a/mindspore/lite/include/train_session.h b/mindspore/lite/include/train_session.h
index f5d3dfb..08e42a7 100644
--- a/mindspore/lite/include/train_session.h
+++ b/mindspore/lite/include/train_session.h
@@ -23,6 +23,13 @@
 namespace mindspore {
 namespace session {
 
+struct TrainFeatureParam{  // one constant tensor as exchanged by GetFeatureMaps / UpdateFeatureMaps
+  char* name;  // NUL-terminated tensor name; GetFeatureMaps leaves it nullptr for unnamed tensors
+  void *data;  // pointer to the element data
+  size_t elenums;  // number of elements (GetFeatureMaps fills it from tensor->ElementsNum())
+  enum TypeId type;  // element data type (GetFeatureMaps fills it from tensor->data_type())
+};
+
 /// \brief TrainSession Defines a class that allows training a MindSpore model
 class TrainSession : public session::LiteSession {
  public:
@@ -137,6 +144,11 @@ class TrainSession : public session::LiteSession {
   /// \param[in] loss_name Identifucation name for loss kernels
   void SetLossName(std::string loss_name) { loss_name_ = loss_name; }
 
+  virtual int GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *>* feature_maps) =0;
+
+  virtual int UpdateFeatureMaps(const std::string &update_ms_file,
+                                TrainFeatureParam* new_features,int size) =0;
+
  protected:
   bool train_mode_ = false;
   std::string get_loss_name() const { return loss_name_; }
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index d07261e..0916185 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -953,6 +953,14 @@ table StridedSlice {
     shrink_axis_mask: long;
 }
 
+table StridedSliceGrad {  // NOTE(review): presumably the gradient of StridedSlice; mask fields appear to mirror it
+    begin_mask: long;
+    end_mask: long;
+    ellipsis_mask: long;
+    new_axis_mask: long;
+    shrink_axis_mask: long;
+}
+
 table SubFusion {
     activation_type: ActivationType = 0;
 }
@@ -1056,14 +1064,6 @@ table CropAndResize {
 table Erf {
 }
 
-table StridedSliceGrad {
-    begin_mask: long;
-    end_mask: long;
-    ellipsis_mask: long;
-    new_axis_mask: long;
-    shrink_axis_mask: long;
-}
-
 table IsFinite {
 }
 
@@ -1077,3 +1077,4 @@ table UniformReal {
 
 table AbsGrad {
 }
+
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index badd1fd..a3685cb 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -22,6 +22,7 @@
 #include <iostream>
 #include <fstream>
 #include <memory>
+#include <cstring>
 #include "include/errorcode.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
@@ -480,6 +481,63 @@ bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
           (kernel->Type() == schema::PrimitiveType_FusedBatchNorm));
 }
 
+// Appends one new TrainFeatureParam per const tensor; the caller must delete it and delete[] its name (may be null).
+int lite::TrainSession::GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *> *feature_maps) {
+  if (feature_maps == nullptr) return RET_ERROR;
+  for (auto tensor : this->tensors_) {
+    if (tensor->IsConst()) {
+      const auto &tensor_name = tensor->tensor_name();
+      auto param = new mindspore::session::TrainFeatureParam();
+      param->name = tensor_name.empty() ? nullptr : new char[tensor_name.size() + 1];
+      if (param->name != nullptr) memcpy(param->name, tensor_name.c_str(), tensor_name.size() + 1);  // incl. NUL
+      param->data = tensor->data_c();  // aliases the tensor buffer; not a copy
+      param->elenums = tensor->ElementsNum();
+      param->type = tensor->data_type();
+      feature_maps->push_back(param);
+    }
+  }
+  MS_LOG(INFO) << "get feature map success";
+  return RET_OK;
+}
+// Copies new_features[0..size) into the same-named constant tensors (first match), then saves the model with
+// SaveToFile(update_ms_file) and returns its status. Fails with RET_ERROR, before saving, on null input, a null or
+// unknown name, an element-count mismatch, null data or a copy larger than the tensor; earlier entries stay written.
+int lite::TrainSession::UpdateFeatureMaps(const std::string &update_ms_file,
+                                          mindspore::session::TrainFeatureParam* new_features,int size) {
+  if (size > 0 && new_features == nullptr) return RET_ERROR;
+  for (int i = 0; i < size; ++i) {
+    const mindspore::session::TrainFeatureParam *new_feature = new_features + i;
+    lite::Tensor *target = nullptr;
+    for (auto tensor : this->tensors_) {
+      if (new_feature->name != nullptr && tensor->IsConst() && tensor->tensor_name() == new_feature->name) {
+        target = tensor;
+        break;
+      }
+    }
+    if (target == nullptr) {
+      MS_LOG(ERROR) << "cannot find feature:" << (new_feature->name != nullptr ? new_feature->name : "(null)");
+      return RET_ERROR;
+    }
+    if (static_cast<size_t>(target->ElementsNum()) != new_feature->elenums) {
+      MS_LOG(ERROR) << "feature name:" << new_feature->name << ",len diff:" << "old is:" << target->ElementsNum()
+                    << "new is:" << new_feature->elenums;
+      return RET_ERROR;
+    }
+    size_t copy_bytes = new_feature->elenums * sizeof(float);  // NOTE(review): assumes 4-byte elements
+    if (new_feature->data == nullptr || target->data_c() == nullptr || copy_bytes > target->Size()) {
+      MS_LOG(ERROR) << "feature name:" << new_feature->name << ", cannot copy " << copy_bytes << " bytes";
+      return RET_ERROR;
+    }
+    memcpy(target->data_c(), new_feature->data, copy_bytes);
+  }
+  int ret = SaveToFile(update_ms_file);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "save model:" << update_ms_file << " failed";
+    return ret;
+  }
+  MS_LOG(INFO) << "update model:" << update_ms_file << ",feature map success";
+  return RET_OK;
+}
 }  // namespace lite
 
 session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context,
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index 266baef..c69fa5d 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -91,6 +91,10 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
     return outputs;
   }
 
+  int GetFeatureMaps(std::vector<mindspore::session::TrainFeatureParam *> *feature_maps) override;
+  int UpdateFeatureMaps(const std::string &update_ms_file,
+                        mindspore::session::TrainFeatureParam* new_features,int size) override;
+
  protected:
   void AllocWorkSpace();
   bool IsLossKernel(const kernel::LiteKernel *kernel) const;
-- 
2.7.4

