<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product, but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
From aabe6d5804e8b9fb86265f706fe0e7723221a935 Mon Sep 17 00:00:00 2001
From: zhengjun10 <zhengjun10@huawei.com>
Date: Thu, 8 Jul 2021 17:27:24 +0800
Subject: [PATCH] remove redundant load api

---
 include/api/serialization.h                 | 20 +++----
 include/api/types.h                         |  1 +
 mindspore/ccsrc/cxx_api/serialization.cc    | 23 +++++---
 mindspore/lite/src/cxx_api/serialization.cc | 82 +++++++++++++++++++++++------
 tests/ut/cpp/cxx_api/serialization_test.cc  | 16 ++----
 5 files changed, 93 insertions(+), 49 deletions(-)

diff --git a/include/api/serialization.h b/include/api/serialization.h
index 0aaa0f0..1826a9a 100644
--- a/include/api/serialization.h
+++ b/include/api/serialization.h
@@ -27,24 +27,24 @@
 #include "include/api/dual_abi_helper.h"
 
 namespace mindspore {
-using Key = struct Key {
+constexpr char kDecModeAesGcm[] = "AES-GCM";
+
+struct MS_API Key {
   const size_t max_key_len = 32;
   size_t len;
   unsigned char key[32];
   Key() : len(0) {}
+  Key(const char *dec_key, size_t key_len);
 };
 
 class MS_API Serialization {
  public:
-  static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
   inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
-                            const Key &dec_key, const std::string &dec_mode);
-  inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
-  inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
-                            const std::string &dec_mode);
+                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
+  inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
+                            const std::string &dec_mode = kDecModeAesGcm);
   inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
-                            const Key &dec_key = {}, const std::string &dec_mode = "AES-GCM");
-  static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
+                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
   static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
   static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
   static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
@@ -64,10 +64,6 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
   return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
 }
 
-Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) {
-  return Load(StringToChar(file), model_type, graph);
-}
-
 Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
                            const std::string &dec_mode) {
   return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
diff --git a/include/api/types.h b/include/api/types.h
index 6341993..ddd3556 100644
--- a/include/api/types.h
+++ b/include/api/types.h
@@ -36,6 +36,7 @@ enum ModelType : uint32_t {
   kAIR = 1,
   kOM = 2,
   kONNX = 3,
+  kFlatBuffer = 4,
   // insert new data type here
   kUnknownType = 0xFFFFFFFF
 };
diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc
index 8c01426..3295912 100644
--- a/mindspore/ccsrc/cxx_api/serialization.cc
+++ b/mindspore/ccsrc/cxx_api/serialization.cc
@@ -79,8 +79,20 @@ static Buffer ReadFile(const std::string &file) {
   return buffer;
 }
 
-Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
-  return Load(model_data, data_size, model_type, graph, Key{}, StringToChar("AES-GCM"));
+Key::Key(const char *dec_key, size_t key_len) {
+  len = 0;  // remains 0 (no key) unless validation and the copy below both succeed
+  if (key_len > max_key_len) {  // key[] holds max_key_len bytes, so a key of exactly that length fits
+    MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
+    return;
+  }
+
+  auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len);
+  if (sec_ret != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret;
+    return;
+  }
+
+  len = key_len;
 }
 
 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@@ -137,7 +149,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
 }
 
 Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
-  return Load(file, model_type, graph, Key{}, StringToChar("AES-GCM"));
+  return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm));
 }
 
 Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
@@ -256,11 +268,6 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
   return Status(kMEInvalidInput, err_msg.str());
 }
 
-Status Serialization::LoadCheckPoint(const std::string &, std::map<std::string, Buffer> *) {
-  MS_LOG(ERROR) << "Unsupported feature.";
-  return kMEFailed;
-}
-
 Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model *) {
   MS_LOG(ERROR) << "Unsupported feature.";
   return kMEFailed;
diff --git a/mindspore/lite/src/cxx_api/serialization.cc b/mindspore/lite/src/cxx_api/serialization.cc
index ed88bea..6bcc18e 100644
--- a/mindspore/lite/src/cxx_api/serialization.cc
+++ b/mindspore/lite/src/cxx_api/serialization.cc
@@ -17,14 +17,42 @@
 #include "include/api/serialization.h"
 #include <algorithm>
 #include <queue>
+#include <set>
 #include "include/api/graph.h"
+#include "include/api/context.h"
 #include "include/api/types.h"
 #include "include/model.h"
+#include "include/ms_tensor.h"
 #include "src/cxx_api/graph/graph_data.h"
+#include "src/cxx_api/model/model_impl.h"
+#include "src/cxx_api/converters.h"
 #include "src/common/log_adapter.h"
+#include "securec/include/securec.h"
 
 namespace mindspore {
-Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
+Key::Key(const char *dec_key, size_t key_len) {
+  len = 0;  // remains 0 (no key) unless validation and the copy below both succeed
+  if (key_len > max_key_len) {  // key[] holds max_key_len bytes, so a key of exactly that length fits
+    MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
+    return;
+  }
+
+  auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len);
+  if (sec_ret != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret;
+    return;
+  }
+
+  len = key_len;
+}
+
+Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
+                           const Key &dec_key, const std::vector<char> &dec_mode) {
+  if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
+    MS_LOG(ERROR) << "Unsupported Feature.";
+    return kLiteError;
+  }
+
   if (model_data == nullptr) {
     MS_LOG(ERROR) << "model data is nullptr.";
     return kLiteNullptr;
@@ -37,6 +65,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
     MS_LOG(ERROR) << "Unsupported IR.";
     return kLiteInputParamInvalid;
   }
+
   auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
   if (model == nullptr) {
     MS_LOG(ERROR) << "New model failed.";
@@ -51,28 +80,47 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
   return kSuccess;
 }
 
-Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
-                           const Key &dec_key, const std::vector<char> &dec_mode) {
-  MS_LOG(ERROR) << "Unsupported Feature.";
-  return kLiteError;
-}
+Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
+                           const std::vector<char> &dec_mode) {
+  if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
+    MS_LOG(ERROR) << "Unsupported Feature.";
+    return kLiteError;
+  }
 
-Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
-  MS_LOG(ERROR) << "Unsupported Feature.";
-  return kLiteError;
+  if (graph == nullptr) {
+    MS_LOG(ERROR) << "graph is nullptr.";
+    return kLiteNullptr;
+  }
+  if (model_type != kFlatBuffer) {
+    MS_LOG(ERROR) << "Unsupported IR.";
+    return kLiteInputParamInvalid;
+  }
+
+  std::string filename = CharToString(file);  // a vector<char> need not be NUL-terminated, so never read it via data()
+  if (filename.substr(filename.find_last_of('.') + 1) != "ms") {  // npos + 1 wraps to 0 when there is no '.'
+    filename += ".ms";
+  }
+
+  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "New model failed.";
+    return kLiteNullptr;
+  }
+  auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
+  if (graph_data == nullptr) {
+    MS_LOG(ERROR) << "New graph data failed.";
+    return kLiteMemoryFailed;
+  }
+  *graph = Graph(graph_data);
+  return kSuccess;
 }
 
-Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
-                           const std::vector<char> &dec_mode) {
+Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
+                           std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
   MS_LOG(ERROR) << "Unsupported Feature.";
   return kLiteError;
 }
 
-Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
-  MS_LOG(ERROR) << "Unsupported feature.";
-  return kMEFailed;
-}
-
 Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
   MS_LOG(ERROR) << "Unsupported feature.";
   return kMEFailed;
@@ -87,4 +135,4 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
   MS_LOG(ERROR) << "Unsupported feature.";
   return kMEFailed;
 }
-}  // namespace mindspore
+}  // namespace mindspore
\ No newline at end of file
diff --git a/tests/ut/cpp/cxx_api/serialization_test.cc b/tests/ut/cpp/cxx_api/serialization_test.cc
index 61cbad2..5e50368 100644
--- a/tests/ut/cpp/cxx_api/serialization_test.cc
+++ b/tests/ut/cpp/cxx_api/serialization_test.cc
@@ -46,11 +46,8 @@ TEST_F(TestCxxApiSerialization, test_load_file_not_exist_FAILED) {
 TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_SUCCESS) {
   Graph graph;
   std::string key_str = "0123456789ABCDEF";
-  Key key;
-  memcpy(key.key, key_str.c_str(), key_str.size());
-  key.len = key_str.size();
   ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
-                                  key, "AES-GCM") == kSuccess);
+                                  Key(key_str.c_str(), key_str.size()), kDecModeAesGcm) == kSuccess);
 }
 
 TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
@@ -65,21 +62,16 @@ TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
 TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_with_wrong_key_FAILED) {
   Graph graph;
   std::string key_str = "WRONG_KEY";
-  Key key;
-  memcpy(key.key, key_str.c_str(), key_str.size());
-  key.len = key_str.size();
   auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
-                                    key, "AES-GCM");
+                                    Key(key_str.c_str(), key_str.size()), kDecModeAesGcm);
   ASSERT_TRUE(status != kSuccess);
 }
 
 TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_with_wrong_key_FAILED) {
   Graph graph;
   std::string key_str = "WRONG_KEY";
-  Key key;
-  memcpy(key.key, key_str.c_str(), key_str.size());
-  key.len = key_str.size();
-  auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, key, "AES-GCM");
+  auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph,
+                                  Key(key_str.c_str(), key_str.size()), kDecModeAesGcm);
   ASSERT_TRUE(status != kSuccess);
 }
 
-- 
2.7.4

