<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product, but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc
index a44551a..78ecb34 100644
--- a/mindspore/lite/examples/train_lenet/src/net_runner.cc
+++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc
@@ -15,9 +15,9 @@
  */
 
 #include "src/net_runner.h"
-#include <math.h>
+#include <cmath>
 #include <getopt.h>
-#include <stdio.h>
+#include <cstdio>
 #include <cstring>
 #include <iostream>
 #include <fstream>
@@ -43,10 +43,20 @@ using mindspore::lite::AccuracyMetrics;
 using mindspore::session::TrainLoopCallBack;
 using mindspore::session::TrainLoopCallBackData;
 
+constexpr int kPrintNum = 10;  // cap on the elements printed per tensor in after_callback
+constexpr float kScalePoint = 255.0f;  // scale handed to the Rescaler callback
+constexpr int kBatchSize = 2;  // argument passed to Shuffle() in InitDB; batching itself uses batch_size_
+constexpr int kNCHWDims = 4;  // expected rank of the first input's shape
+constexpr int kNCHWCDim = 2;  // index of the input shape (nhwc_input_dims) that is stored into w_
+constexpr int kPrintTimes = 100;  // value passed to the LossMonitor constructor
+constexpr int kSaveSteps = 1000;  // value passed to the CkptSaver constructor
+constexpr float kLearningRate = 0.7f;  // second argument of StepLRLambda; NOTE(review): confirm its meaning there
 class Rescaler : public mindspore::session::TrainLoopCallBack {
  public:
   explicit Rescaler(float scale) : scale_(scale) {
-    if (scale_ == 0) scale_ = 1.0;
+    if (scale_ == 0) {
+      scale_ = 1.0;
+    }
   }
   ~Rescaler() override = default;
   void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override {
@@ -67,7 +77,7 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
   for (size_t i = 0; i < after_inputs.size(); i++) {
     int num2p = (after_inputs.at(i)->ElementsNum());
     printf("in%zu(%d): ", i, num2p);
-    if (num2p > 10) num2p = 10;
+    if (num2p > kPrintNum) num2p = kPrintNum;
     if (after_inputs.at(i)->data_type() == mindspore::kNumberTypeInt32) {
       auto d = reinterpret_cast<int *>(after_inputs.at(i)->MutableData());
       for (int j = 0; j < num2p; j++) printf("%d, ", d[j]);
@@ -100,7 +110,7 @@ void NetRunner::InitAndFigureInputs() {
   context.thread_num_ = 2;
 
   session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
-  MS_ASSERT(nullptr != session_);
+  MS_ASSERT(session_ != nullptr);
   loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
 
   acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
@@ -110,10 +120,10 @@ void NetRunner::InitAndFigureInputs() {
   auto inputs = session_->GetInputs();
   MS_ASSERT(inputs.size() > 1);
   auto nhwc_input_dims = inputs.at(0)->shape();
-  MS_ASSERT(nhwc_input_dims.size() == 4);
+  MS_ASSERT(nhwc_input_dims.size() == kNCHWDims);
   batch_size_ = nhwc_input_dims.at(0);
   h_ = nhwc_input_dims.at(1);
-  w_ = nhwc_input_dims.at(2);
+  w_ = nhwc_input_dims.at(kNCHWCDim);
 }
 
 float NetRunner::CalculateAccuracy(int max_tests) {
@@ -126,7 +136,7 @@ float NetRunner::CalculateAccuracy(int max_tests) {
   test_ds_ = test_ds_->Map({&typecast}, {"label"});
   test_ds_ = test_ds_->Batch(batch_size_, true);
 
-  Rescaler rescale(255.0);
+  Rescaler rescale(kScalePoint);
 
   loop_->Eval(test_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale});
   std::cout << "Eval Accuracy is " << acc_metrics_->Eval() << std::endl;
@@ -144,7 +154,7 @@ int NetRunner::InitDB() {
   TypeCast typecast("int32");
   train_ds_ = train_ds_->Map({&typecast}, {"label"});
 
-  train_ds_ = train_ds_->Shuffle(2);
+  train_ds_ = train_ds_->Shuffle(kBatchSize);
   train_ds_ = train_ds_->Batch(batch_size_, true);
 
   if (verbose_) {
@@ -159,13 +169,13 @@ int NetRunner::InitDB() {
 }
 
 int NetRunner::TrainLoop() {
-  struct mindspore::lite::StepLRLambda step_lr_lambda(1, 0.7);
+  struct mindspore::lite::StepLRLambda step_lr_lambda(1, kLearningRate);
   mindspore::lite::LRScheduler step_lr_sched(mindspore::lite::StepLRLambda, static_cast<void *>(&step_lr_lambda), 1);
 
-  mindspore::lite::LossMonitor lm(100);
+  mindspore::lite::LossMonitor lm(kPrintTimes);
   mindspore::lite::ClassificationTrainAccuracyMonitor am(1);
-  mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
-  Rescaler rescale(255.0);
+  mindspore::lite::CkptSaver cs(kSaveSteps, std::string("lenet"));
+  Rescaler rescale(kScalePoint);
 
   loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched});
   return 0;
diff --git a/mindspore/lite/examples/transfer_learning/src/dataset.cc b/mindspore/lite/examples/transfer_learning/src/dataset.cc
index 7a0669b..2f8d3c2 100644
--- a/mindspore/lite/examples/transfer_learning/src/dataset.cc
+++ b/mindspore/lite/examples/transfer_learning/src/dataset.cc
@@ -50,6 +50,10 @@ float CH_MEAN[3] = {0.485, 0.456, 0.406};
 float CH_STD[3] = {0.229, 0.224, 0.225};
 
 using LabelId = std::map<std::string, int>;
+constexpr int kClassNum = 10;  // number of class sub-directories ("0".."9") that ReadDir scans
+constexpr int kBGRDim = 2;  // channel offset 2: bitmap source for output channel 0, also the third output index
+constexpr float kRGBMAX = 255.0f;  // divisor applied to each pixel byte before mean/std normalisation
+constexpr int kRGBDims = 3;  // channels per pixel when sizing the float image buffer
 
 static char *ReadBitmapFile(const std::string &filename, size_t *size) {
   MS_ASSERT(size != nullptr);
@@ -78,7 +82,7 @@ static char *ReadBitmapFile(const std::string &filename, size_t *size) {
 
   ifs.read(reinterpret_cast<char *>(bmp_image), bitmap_header.image_size_bytes);
 
-  size_t buffer_size = bitmap_header.width * bitmap_header.height * 3;
+  size_t buffer_size = bitmap_header.width * bitmap_header.height * kRGBDims;
   float *hwc_bin_image = new (std::nothrow) float[buffer_size];
   if (hwc_bin_image == nullptr) {
     free(bmp_image);
@@ -95,14 +99,16 @@ static char *ReadBitmapFile(const std::string &filename, size_t *size) {
   for (int h = 0; h < bitmap_header.height; h++) {
     for (int w = 0; w < bitmap_header.width; w++) {
       hwc_bin_image[h * hStride + w * channels + 0] =
-        (((static_cast<float>(bmp_image[(height - h - 1) * hStride + w * channels + 2])) / 255.0) - CH_MEAN[0]) /
+        (((static_cast<float>(bmp_image[(height - h - 1) * hStride + w * channels + kBGRDim])) / kRGBMAX) -
+         CH_MEAN[0]) /
         CH_STD[0];
       hwc_bin_image[h * hStride + w * channels + 1] =
-        (((static_cast<float>(bmp_image[(height - h - 1) * hStride + w * channels + 1])) / 255.0) - CH_MEAN[1]) /
+        (((static_cast<float>(bmp_image[(height - h - 1) * hStride + w * channels + 1])) / kRGBMAX) - CH_MEAN[1]) /
         CH_STD[1];
-      hwc_bin_image[h * hStride + w * channels + 2] =
-        (((static_cast<float>(bmp_image[(height - h - 1) * hStride + w * channels + 0])) / 255.0) - CH_MEAN[2]) /
-        CH_STD[2];
+      hwc_bin_image[h * hStride + w * channels + kBGRDim] =
+        (((static_cast<float>(bmp_image[(height - h - 1) * hStride + w * channels + 0])) / kRGBMAX) -
+         CH_MEAN[kBGRDim]) /
+        CH_STD[kBGRDim];
     }
   }
 
@@ -190,7 +196,7 @@ void DataSet::InitializeBMPFoldersDatabase(std::string dpath) {
 std::vector<FileTuple> DataSet::ReadDir(const std::string dpath) {
   std::vector<FileTuple> vec;
   struct dirent *entry = nullptr;
-  num_of_classes_ = 10;
+  num_of_classes_ = kClassNum;
   for (int class_id = 0; class_id < num_of_classes_; class_id++) {
     std::string dirname = dpath + "/" + std::to_string(class_id);
     DIR *dp = opendir(dirname.c_str());
diff --git a/mindspore/lite/examples/transfer_learning/src/net_runner.cc b/mindspore/lite/examples/transfer_learning/src/net_runner.cc
index cc0db4a..56fa14b 100644
--- a/mindspore/lite/examples/transfer_learning/src/net_runner.cc
+++ b/mindspore/lite/examples/transfer_learning/src/net_runner.cc
@@ -15,7 +15,7 @@
  */
 
 #include "src/net_runner.h"
-#include <math.h>
+#include <cmath>
 #include <getopt.h>
 #include <algorithm>
 #include <cstring>
@@ -25,6 +25,9 @@
 #include "src/utils.h"
 
 static unsigned int seed = time(NULL);
+constexpr int kBatchNum = 20;  // test accuracy is evaluated every kBatchNum training iterations
+constexpr int kPrintNum = 10;  // cap on the elements printed per tensor in after_callback
+constexpr float kThreshold = 0.9f;  // TrainLoop returns early once test accuracy exceeds this value
 
 // Definition of callback function after forwarding operator.
 bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
@@ -34,7 +37,7 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
   for (size_t i = 0; i < after_inputs.size(); i++) {
     int num2p = (after_inputs.at(i)->ElementsNum());
     std::cout << "in" << i << "(" << num2p << "): ";
-    if (num2p > 10) num2p = 10;
+    if (num2p > kPrintNum) num2p = kPrintNum;
     if (after_inputs.at(i)->data_type() == mindspore::kNumberTypeInt32) {
       auto d = reinterpret_cast<int *>(after_inputs.at(i)->MutableData());
       for (int j = 0; j < num2p; j++) {
@@ -52,7 +55,7 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
     auto d = reinterpret_cast<float *>(after_outputs.at(i)->MutableData());
     int num2p = (after_outputs.at(i)->ElementsNum());
     std::cout << "ou" << i << "(" << num2p << "): ";
-    if (num2p > 10) num2p = 10;
+    if (num2p > kPrintNum) num2p = kPrintNum;
     for (int j = 0; j < num2p; j++) {
       std::cout << d[j] << ", ";
     }
@@ -71,7 +74,7 @@ void NetRunner::InitAndFigureInputs() {
   context.thread_num_ = 1;
 
   session_ = mindspore::session::TrainSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context);
-  MS_ASSERT(nullptr != session_);
+  MS_ASSERT(session_ != nullptr);
 
   auto inputs = session_->GetInputs();
   MS_ASSERT(inputs.size() > 1);
@@ -107,7 +110,8 @@ std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dat
   std::fill(labels, labels + inputs.at(label_index_)->ElementsNum(), 0.f);
   for (int i = 0; i < batch_size_; i++) {
     if (serially >= 0) {
-      idx = ++idx % total_size;
+      auto reminder = ++idx % total_size;
+      idx = reminder;
     } else {
       idx = rand_r(&seed) % total_size;
     }
@@ -190,12 +194,12 @@ int NetRunner::TrainLoop() {
       session_->SaveToFile(cpkt_fn);
     }
 
-    std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
-    if ((i + 1) % 20 == 0) {
+    std::cout << (i + 1) << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
+    if ((i + 1) % kBatchNum == 0) {
       float acc = CalculateAccuracy(ds_.test_data());
       if (max_acc < acc) max_acc = acc;
       std::cout << "accuracy on test data = " << acc << " max accuracy = " << max_acc << std::endl;
-      if (acc > 0.9) return 0;
+      if (acc > kThreshold) return 0;
     }
   }
   return 0;
diff --git a/mindspore/lite/src/dequant.cc b/mindspore/lite/src/dequant.cc
index 6987c9a..49ae0bb 100644
--- a/mindspore/lite/src/dequant.cc
+++ b/mindspore/lite/src/dequant.cc
@@ -38,7 +38,7 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first
   }
 }
 
-int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
+int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data, int data_len) {
   MS_ASSERT(input_tensor != nullptr);
   MS_ASSERT(unpack_int_data != nullptr);
   auto quant_params = input_tensor->quantParams();
@@ -50,7 +50,7 @@ int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_in
   if (enable_huffman_code) {
     std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end());
     auto huffman_decode = std::make_unique<lite::HuffmanDecode>();
-    auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data);
+    auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data, data_len);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "DoHuffmanDecode failed.";
       return ret;
@@ -121,5 +121,4 @@ void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, v
     tensor->set_data(data);
   }
 }
-
 }  // namespace mindspore::lite
diff --git a/mindspore/lite/src/dequant.h b/mindspore/lite/src/dequant.h
index 98072bf..3385f7f 100644
--- a/mindspore/lite/src/dequant.h
+++ b/mindspore/lite/src/dequant.h
@@ -31,7 +31,7 @@ class DequantUtil {
  public:
   static float *DequantWeight(lite::Tensor *input_tensor, bool);
 
-  static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
+  static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data, int weight_len);
 
   static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(OpParameter *op_param,
                                                                      const std::vector<Tensor *> &in_tensors,
diff --git a/mindspore/lite/src/huffman_decode.cc b/mindspore/lite/src/huffman_decode.cc
index 8432571..dd173d4 100644
--- a/mindspore/lite/src/huffman_decode.cc
+++ b/mindspore/lite/src/huffman_decode.cc
@@ -18,8 +18,7 @@
 
 namespace mindspore {
 namespace lite {
-
-STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
+STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data, int data_len) {
   if (decoded_data == nullptr) {
     MS_LOG(ERROR) << "decoded_data is nullptr.";
     return RET_ERROR;
@@ -58,8 +57,11 @@ STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decode
   }
 
   size_t len = huffman_decoded_str.length();
-  memcpy(decoded_data, huffman_decoded_str.c_str(), len);
-
+  if (data_len < 0 || static_cast<size_t>(data_len) < len) {  // explicit cast: a negative int must not pass as huge
+    delete root;  // release the Huffman tree on the error path too, as the success path does
+    return RET_ERROR;
+  }
+  memcpy(decoded_data, huffman_decoded_str.c_str(), len);
   delete root;
   return RET_OK;
 }
@@ -163,6 +165,5 @@ HuffmanDecode::~HuffmanDecode() {
   }
   this->huffman_nodes_.resize(0);
 }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/huffman_decode.h b/mindspore/lite/src/huffman_decode.h
index 9f15537..5a534d5 100644
--- a/mindspore/lite/src/huffman_decode.h
+++ b/mindspore/lite/src/huffman_decode.h
@@ -44,7 +44,7 @@ class HuffmanDecode {
 
   ~HuffmanDecode();
 
-  STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data);
+  STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data, int data_len);
 
  private:
   std::vector<HuffmanNodePtr> huffman_nodes_;
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index b78b974..ec5897e 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -119,7 +119,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
           return RET_NULL_PTR;
         }
         if (NeedUnPack()) {
-          auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
+          auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data, dst_tensor->Size());
           if (ret != RET_OK) {
             MS_LOG(ERROR) << "unpack to int failed.";
             return RET_NULL_PTR;
@@ -135,7 +135,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
             MS_LOG(ERROR) << "Data from tensor is nullptr";
             return RET_ERROR;
           }
-          auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
+          auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data, dst_tensor->Size());
           if (ret != RET_OK) {
             MS_LOG(ERROR) << "unpack to int failed.";
             return RET_ERROR;
diff --git a/mindspore/lite/src/ops/populate/arithmetic_populate.cc b/mindspore/lite/src/ops/populate/arithmetic_populate.cc
index f4bec62..8dae508 100644
--- a/mindspore/lite/src/ops/populate/arithmetic_populate.cc
+++ b/mindspore/lite/src/ops/populate/arithmetic_populate.cc
@@ -59,5 +59,6 @@ Registry g_floorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateAri
 Registry g_floorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic, SCHEMA_CUR);
 Registry g_modParameterRegistry(schema::PrimitiveType_Mod, PopulateArithmetic, SCHEMA_CUR);
 Registry g_squaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic, SCHEMA_CUR);
+Registry g_populateBiasGradParameterParameterRegistry(schema::PrimitiveType_BiasAddGrad, PopulateArithmetic,SCHEMA_CUR);
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/ops/populate/pooling_populate.cc b/mindspore/lite/src/ops/populate/pooling_populate.cc
index af36828..cf3c962 100644
--- a/mindspore/lite/src/ops/populate/pooling_populate.cc
+++ b/mindspore/lite/src/ops/populate/pooling_populate.cc
@@ -19,6 +19,34 @@
 namespace mindspore {
 namespace lite {
 namespace {
+// Maps the schema pad mode to pooling_param->pad_mode_; modes other than SAME/VALID become Pad_pad.
+void SetPoolingParamPadMod(schema::PadMode pad_mode, PoolingParameter *pooling_param) {
+  if (pad_mode == schema::PadMode_SAME) {
+    pooling_param->pad_mode_ = Pad_same;
+    return;
+  }
+  if (pad_mode == schema::PadMode_VALID) {
+    pooling_param->pad_mode_ = Pad_valid;
+    return;
+  }
+  // Every remaining value falls back to Pad_pad.
+  pooling_param->pad_mode_ = Pad_pad;
+}
+
+// Maps the schema rounding mode to pooling_param->round_mode_; modes other than FLOOR/CEIL become RoundMode_No.
+void SetPoolingParamRoundMod(schema::RoundMode round_mode, PoolingParameter *pooling_param) {
+  if (round_mode == schema::RoundMode_FLOOR) {
+    pooling_param->round_mode_ = RoundMode_Floor;
+    return;
+  }
+  if (round_mode == schema::RoundMode_CEIL) {
+    pooling_param->round_mode_ = RoundMode_Ceil;
+    return;
+  }
+  // Every remaining value falls back to RoundMode_No.
+  pooling_param->round_mode_ = RoundMode_No;
+}
+
 OpParameter *PopulateAvgPoolParameter(const void *primitive) {
   PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
   if (pooling_param == nullptr) {
@@ -45,17 +73,7 @@ OpParameter *PopulateAvgPoolParameter(const void *primitive) {
   }
 
   auto round_mode = pooling_primitive->round_mode();
-  switch (round_mode) {
-    case schema::RoundMode_FLOOR:
-      pooling_param->round_mode_ = RoundMode_Floor;
-      break;
-    case schema::RoundMode_CEIL:
-      pooling_param->round_mode_ = RoundMode_Ceil;
-      break;
-    default:
-      pooling_param->round_mode_ = RoundMode_No;
-      break;
-  }
+  SetPoolingParamRoundMod(round_mode, pooling_param);
 
   if (pooling_primitive->activation_type() == schema::ActivationType_RELU) {
     pooling_param->act_type_ = ActType_Relu;
@@ -64,18 +82,7 @@ OpParameter *PopulateAvgPoolParameter(const void *primitive) {
   } else {
     pooling_param->act_type_ = ActType_No;
   }
-
-  switch (pooling_primitive->pad_mode()) {
-    case schema::PadMode_SAME:
-      pooling_param->pad_mode_ = Pad_same;
-      break;
-    case schema::PadMode_VALID:
-      pooling_param->pad_mode_ = Pad_valid;
-      break;
-    default:
-      pooling_param->pad_mode_ = Pad_pad;
-      break;
-  }
+  SetPoolingParamPadMod(pooling_primitive->pad_mode(), pooling_param);
   return reinterpret_cast<OpParameter *>(pooling_param);
 }
 
@@ -105,18 +112,7 @@ OpParameter *PopulateMaxPoolParameter(const void *primitive) {
   }
 
   auto round_mode = max_pool_prim->round_mode();
-  switch (round_mode) {
-    case schema::RoundMode_FLOOR:
-      pooling_param->round_mode_ = RoundMode_Floor;
-      break;
-    case schema::RoundMode_CEIL:
-      pooling_param->round_mode_ = RoundMode_Ceil;
-      break;
-    default:
-      pooling_param->round_mode_ = RoundMode_No;
-      break;
-  }
-
+  SetPoolingParamRoundMod(round_mode, pooling_param);
   if (max_pool_prim->activation_type() == schema::ActivationType_RELU) {
     pooling_param->act_type_ = ActType_Relu;
   } else if (max_pool_prim->activation_type() == schema::ActivationType_RELU6) {
@@ -124,23 +120,67 @@ OpParameter *PopulateMaxPoolParameter(const void *primitive) {
   } else {
     pooling_param->act_type_ = ActType_No;
   }
+  SetPoolingParamPadMod(max_pool_prim->pad_mode(), pooling_param);
+  return reinterpret_cast<OpParameter *>(pooling_param);
+}
 
-  switch (max_pool_prim->pad_mode()) {
-    case schema::PadMode_SAME:
-      pooling_param->pad_mode_ = Pad_same;
-      break;
-    case schema::PadMode_VALID:
-      pooling_param->pad_mode_ = Pad_valid;
-      break;
-    default:
-      pooling_param->pad_mode_ = Pad_pad;
-      break;
+// Allocates and fills a PoolingParameter for MaxPoolGrad; returns nullptr on a malformed primitive or malloc failure.
+OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
+  auto primitive = static_cast<const schema::Primitive *>(prim);
+  auto value = (primitive == nullptr) ? nullptr : primitive->value_as_MaxPoolGrad();
+  if (value == nullptr || value->kernel_size() == nullptr || value->strides() == nullptr) {
+    MS_LOG(ERROR) << "MaxPoolGrad primitive is nullptr or lacks kernel_size/strides.";
+    return nullptr;
+  }
+  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
+  if (pooling_param == nullptr) {
+    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
+    return nullptr;
+  }
+  pooling_param->op_parameter_.type_ = primitive->value_type();
+  pooling_param->global_ = false;
+  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
+  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
+  pooling_param->pad_u_ = pooling_param->pad_d_ = pooling_param->pad_l_ = pooling_param->pad_r_ = 0;
+  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
+  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
+  pooling_param->round_mode_ = RoundMode_No;
+  pooling_param->pool_mode_ = PoolMode_MaxPool;
+  SetPoolingParamPadMod(value->pad_mode(), pooling_param);
+  return reinterpret_cast<OpParameter *>(pooling_param);
+}
+
+// Allocates and fills a PoolingParameter for AvgPoolGrad; returns nullptr on a malformed primitive or malloc failure.
+OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
+  auto primitive = static_cast<const schema::Primitive *>(prim);
+  auto value = (primitive == nullptr) ? nullptr : primitive->value_as_AvgPoolGrad();
+  if (value == nullptr || value->kernel_size() == nullptr || value->strides() == nullptr) {
+    MS_LOG(ERROR) << "AvgPoolGrad primitive is nullptr or lacks kernel_size/strides.";
+    return nullptr;
   }
+  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
+  if (pooling_param == nullptr) {
+    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
+    return nullptr;
+  }
+  pooling_param->op_parameter_.type_ = primitive->value_type();
+  pooling_param->global_ = false;
+  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
+  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
+  // Explicit pads are zeroed here; only the pad mode is taken from the primitive.
+  pooling_param->pad_u_ = pooling_param->pad_d_ = pooling_param->pad_l_ = pooling_param->pad_r_ = 0;
+  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
+  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
+  SetPoolingParamPadMod(value->pad_mode(), pooling_param);
+  pooling_param->round_mode_ = RoundMode_No;
+  pooling_param->pool_mode_ = PoolMode_AvgPool;
   return reinterpret_cast<OpParameter *>(pooling_param);
 }
 }  // namespace
 
 Registry g_avgPoolParameterRegistry(schema::PrimitiveType_AvgPoolFusion, PopulateAvgPoolParameter, SCHEMA_CUR);
 Registry g_maxPoolParameterRegistry(schema::PrimitiveType_MaxPoolFusion, PopulateMaxPoolParameter, SCHEMA_CUR);
+Registry g_avgPoolGradParameterRegistry(schema::PrimitiveType_AvgPoolGrad, PopulateAvgPoolGradParameter, SCHEMA_CUR);
+Registry g_maxPoolGradParameterRegistry(schema::PrimitiveType_MaxPoolGrad, PopulateMaxPoolGradParameter, SCHEMA_CUR);
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/accuracy_metrics.cc b/mindspore/lite/src/train/accuracy_metrics.cc
index 6b79088..6d6a5c5 100644
--- a/mindspore/lite/src/train/accuracy_metrics.cc
+++ b/mindspore/lite/src/train/accuracy_metrics.cc
@@ -22,7 +22,6 @@
 
 namespace mindspore {
 namespace lite {
-
 AccuracyMetrics::AccuracyMetrics(int accuracy_metrics, const std::vector<int> &input_indexes,
                                  const std::vector<int> &output_indexes)
     : Metrics() {
@@ -66,6 +65,5 @@ float AccuracyMetrics::Eval() {
 
   return (total_accuracy_ / total_steps_);
 }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/accuracy_monitor.cc b/mindspore/lite/src/train/accuracy_monitor.cc
index a9aabdc..0cb3b5d 100644
--- a/mindspore/lite/src/train/accuracy_monitor.cc
+++ b/mindspore/lite/src/train/accuracy_monitor.cc
@@ -29,7 +29,6 @@
 
 namespace mindspore {
 namespace lite {
-
 void AccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
   if (cb_data.epoch_ == 0) accuracies_.clear();
 }
@@ -40,6 +39,5 @@ int AccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
   accuracies_.push_back(std::make_pair(cb_data.epoch_, 0.0));
   return mindspore::session::RET_CONTINUE;
 }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/classification_train_accuracy_monitor.cc b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc
index b373fd1..7cf7da8 100644
--- a/mindspore/lite/src/train/classification_train_accuracy_monitor.cc
+++ b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc
@@ -27,7 +27,6 @@ using mindspore::WARNING;
 
 namespace mindspore {
 namespace lite {
-
 ClassificationTrainAccuracyMonitor::ClassificationTrainAccuracyMonitor(int print_every_n, int accuracy_metrics,
                                                                        const std::vector<int> &input_indexes,
                                                                        const std::vector<int> &output_indexes) {
@@ -60,7 +59,7 @@ void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCall
 int ClassificationTrainAccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
   if (cb_data.step_ > 0) accuracies_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
   if ((cb_data.epoch_ + 1) % print_every_n_ == 0) {
-    std::cout << "Epoch (" << cb_data.epoch_ + 1 << "):\tTraining Accuracy is " << accuracies_.at(cb_data.epoch_).second
+    std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tTraining Accuracy is " << accuracies_.at(cb_data.epoch_).second
               << std::endl;
   }
   return mindspore::session::RET_CONTINUE;
@@ -86,6 +85,5 @@ void ClassificationTrainAccuracyMonitor::StepEnd(const session::TrainLoopCallBac
   }
   accuracies_.at(cb_data.epoch_).second += accuracy;
 }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/loss_monitor.cc b/mindspore/lite/src/train/loss_monitor.cc
index 2230316..7e9eb3d 100644
--- a/mindspore/lite/src/train/loss_monitor.cc
+++ b/mindspore/lite/src/train/loss_monitor.cc
@@ -26,7 +26,6 @@
 
 namespace mindspore {
 namespace lite {
-
 void LossMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
   if (cb_data.epoch_ == 0) losses_.clear();
 }
@@ -42,7 +41,7 @@ void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
 int LossMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
   if (cb_data.step_ > 0) losses_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
   if (print_every_n_ > 0) {
-    std::cout << "Epoch (" << cb_data.epoch_ + 1 << "):\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl;
+    std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl;
   }
   return mindspore::session::RET_CONTINUE;
 }
@@ -54,12 +53,11 @@ void LossMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) {
       auto loss = reinterpret_cast<float *>(it->second->MutableData());
       losses_.at(cb_data.epoch_).second += loss[0];
       if ((cb_data.step_ + 1) % print_every_n_ == 0)
-        std::cout << cb_data.epoch_ + 1 << "." << cb_data.step_ + 1 << ":\tLoss is " << loss[0] << std::endl;
+        std::cout << (cb_data.epoch_ + 1) << "." << (cb_data.step_ + 1) << ":\tLoss is " << loss[0] << std::endl;
       return;
     }
   }
   MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1";
 }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/lr_scheduler.cc b/mindspore/lite/src/train/lr_scheduler.cc
index 461043d..60c07f9 100644
--- a/mindspore/lite/src/train/lr_scheduler.cc
+++ b/mindspore/lite/src/train/lr_scheduler.cc
@@ -29,7 +29,6 @@
 
 namespace mindspore {
 namespace lite {
-
 int MultiplicativeLRLambda(float *lr, int epoch, void *lr_cb_data) {
   if ((lr == nullptr) || (lr_cb_data == nullptr)) {
     MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
@@ -70,6 +69,5 @@ int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
   }
   return mindspore::session::RET_CONTINUE;
 }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/train_loop.cc b/mindspore/lite/src/train/train_loop.cc
index fd4b265..d8fe4dd 100644
--- a/mindspore/lite/src/train/train_loop.cc
+++ b/mindspore/lite/src/train/train_loop.cc
@@ -26,7 +26,6 @@
 
 namespace mindspore {
 namespace lite {
-
 using dataset::Dataset;
 using dataset::Iterator;
 using dataset::MSTensorVec;
@@ -133,28 +132,8 @@ int TrainLoop::LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTenso
   }
 
   for (unsigned int i = 0; i < num_of_inputs; i++) {
-    unsigned char *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
-    const unsigned char *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
-    auto data_size = row_vec->at(i).DataSize();
-    if (data_size != inputs.at(i)->Size()) {
-      MS_LOG(WARNING) << "Model Input tensor " << i << " size (" << inputs.at(i)->Size()
-                      << ") does not match dataset size (" << data_size << ")\n";
-      return RET_STOP_TRAINING;
-    }
-    std::copy(row_data, row_data + data_size, input_data);
-  }
-  return RET_OK;
-}
-
-int TrainLoop::LoadPartialData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *row_vec) {
-  auto num_of_inputs = inputs.size();
-  if ((num_of_inputs == 0) || (row_vec == nullptr) || (num_of_inputs < row_vec->size())) {
-    return RET_STOP_TRAINING;
-  }
-
-  for (unsigned int i = 0; i < row_vec->size(); i++) {
-    unsigned char *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
-    const unsigned char *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
+    auto *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
+    const auto *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
     auto data_size = row_vec->at(i).DataSize();
     if (data_size != inputs.at(i)->Size()) {
       MS_LOG(WARNING) << "Model Input tensor " << i << " size (" << inputs.at(i)->Size()
@@ -165,12 +144,10 @@ int TrainLoop::LoadPartialData(std::vector<tensor::MSTensor *> inputs, dataset::
   }
   return RET_OK;
 }
-
 }  // namespace lite
 
 session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::TrainSession *train_session) {
   auto loop = new (std::nothrow) lite::TrainLoop(train_session);
   return loop;
 }
-
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/train_loop.h b/mindspore/lite/src/train/train_loop.h
index 0098be3..f0a75fa 100644
--- a/mindspore/lite/src/train/train_loop.h
+++ b/mindspore/lite/src/train/train_loop.h
@@ -63,7 +63,6 @@ class TrainLoop : virtual public session::TrainLoop {
 
  protected:
   static int LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *dataset_vec);
-  static int LoadPartialData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *dataset_vec);
 
   session::TrainSession *train_session_ = nullptr;
   unsigned int epoch_ = 0;
diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc
index bab09e3..3b0b600 100644
--- a/mindspore/lite/src/train/train_model.cc
+++ b/mindspore/lite/src/train/train_model.cc
@@ -19,7 +19,6 @@
 #include "src/common/graph_util.h"
 
 namespace mindspore::lite {
-
 TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
   if (model_buf == nullptr) {
     MS_LOG(ERROR) << "The model buf is nullptr";
diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc
index 1d4a836b..7d7f155 100644
--- a/mindspore/lite/src/train/train_populate_parameter.cc
+++ b/mindspore/lite/src/train/train_populate_parameter.cc
@@ -31,7 +31,10 @@
 #include "nnacl/fp32_grad/resize_grad.h"
 namespace mindspore {
 namespace kernel {
-
+namespace {
+constexpr int kNHWCWDim = 2;
+constexpr int kNHWCCDim = 3;
+}  //  namespace
 OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
   SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
   if (p == nullptr) {
@@ -152,91 +155,6 @@ OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *prim) {
   return reinterpret_cast<OpParameter *>(sce_param);
 }
 
-OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
-  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
-  if (pooling_param == nullptr) {
-    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
-    return nullptr;
-  }
-  auto primitive = static_cast<const schema::Primitive *>(prim);
-  auto value = primitive->value_as_MaxPoolGrad();
-  pooling_param->op_parameter_.type_ = primitive->value_type();
-
-  pooling_param->global_ = false;
-  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
-  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
-
-  pooling_param->pad_u_ = 0;
-  pooling_param->pad_d_ = 0;
-  pooling_param->pad_l_ = 0;
-  pooling_param->pad_r_ = 0;
-  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
-  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
-  pooling_param->round_mode_ = RoundMode_No;
-  pooling_param->pool_mode_ = PoolMode_MaxPool;
-  switch (value->pad_mode()) {
-    case schema::PadMode_SAME:
-      pooling_param->pad_mode_ = Pad_same;
-      break;
-    case schema::PadMode_VALID:
-      pooling_param->pad_mode_ = Pad_valid;
-      break;
-    default:
-      pooling_param->pad_mode_ = Pad_pad;
-      break;
-  }
-
-  return reinterpret_cast<OpParameter *>(pooling_param);
-}
-
-OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
-  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
-  if (pooling_param == nullptr) {
-    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
-    return nullptr;
-  }
-  auto primitive = static_cast<const schema::Primitive *>(prim);
-  auto value = primitive->value_as_AvgPoolGrad();
-  pooling_param->op_parameter_.type_ = primitive->value_type();
-
-  pooling_param->global_ = false;
-  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
-  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
-
-  pooling_param->pad_u_ = 0;
-  pooling_param->pad_d_ = 0;
-  pooling_param->pad_l_ = 0;
-  pooling_param->pad_r_ = 0;
-  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
-  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
-
-  switch (value->pad_mode()) {
-    case schema::PadMode_SAME:
-      pooling_param->pad_mode_ = Pad_same;
-      break;
-    case schema::PadMode_VALID:
-      pooling_param->pad_mode_ = Pad_valid;
-      break;
-    default:
-      pooling_param->pad_mode_ = Pad_pad;
-      break;
-  }
-  pooling_param->round_mode_ = RoundMode_No;
-  pooling_param->pool_mode_ = PoolMode_AvgPool;
-  switch (value->pad_mode()) {
-    case schema::PadMode_SAME:
-      pooling_param->pad_mode_ = Pad_same;
-      break;
-    case schema::PadMode_VALID:
-      pooling_param->pad_mode_ = Pad_valid;
-      break;
-    default:
-      pooling_param->pad_mode_ = Pad_pad;
-      break;
-  }
-  return reinterpret_cast<OpParameter *>(pooling_param);
-}
-
 OpParameter *PopulateActivationGradParameter(const void *prim) {
   ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
   if (act_param == nullptr) {
@@ -251,6 +169,21 @@ OpParameter *PopulateActivationGradParameter(const void *prim) {
   return reinterpret_cast<OpParameter *>(act_param);
 }
 
+void SetConvParam(ConvParameter *param, const flatbuffers::Vector<int64_t> *kernel_size,
+                  const flatbuffers::Vector<int64_t> *stride, const flatbuffers::Vector<int64_t> *dilation,
+                  const flatbuffers::Vector<int64_t> *pad_list) {
+  param->kernel_h_ = kernel_size->Get(0);
+  param->kernel_w_ = kernel_size->Get(1);
+  param->stride_h_ = stride->Get(0);
+  param->stride_w_ = stride->Get(1);
+  param->dilation_h_ = dilation->Get(0);
+  param->dilation_w_ = dilation->Get(1);
+  param->pad_u_ = pad_list->Get(0);
+  param->pad_d_ = pad_list->Get(1);
+  param->pad_l_ = pad_list->Get(kNHWCWDim);
+  param->pad_r_ = pad_list->Get(kNHWCCDim);
+}
+
 OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
   ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
   if (param == nullptr) {
@@ -261,17 +194,7 @@ OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_Conv2DBackpropFilterFusion();
   param->op_parameter_.type_ = primitive->value_type();
-
-  param->kernel_h_ = value->kernel_size()->Get(0);
-  param->kernel_w_ = value->kernel_size()->Get(1);
-  param->stride_h_ = value->stride()->Get(0);
-  param->stride_w_ = value->stride()->Get(1);
-  param->dilation_h_ = value->dilation()->Get(0);
-  param->dilation_w_ = value->dilation()->Get(1);
-  param->pad_u_ = value->pad_list()->Get(0);
-  param->pad_d_ = value->pad_list()->Get(1);
-  param->pad_l_ = value->pad_list()->Get(2);
-  param->pad_r_ = value->pad_list()->Get(3);
+  SetConvParam(param, value->kernel_size(), value->stride(), value->dilation(), value->pad_list());
   param->group_ = value->group();
   param->act_type_ = ActType_No;
   switch (value->activation_type()) {
@@ -284,7 +207,6 @@ OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
     default:
       break;
   }
-
   return reinterpret_cast<OpParameter *>(param);
 }
 
@@ -297,17 +219,7 @@ OpParameter *PopulateConvolutionGradInputParameter(const void *prim) {
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_Conv2DBackpropInputFusion();
   param->op_parameter_.type_ = primitive->value_type();
-
-  param->kernel_h_ = value->kernel_size()->Get(0);
-  param->kernel_w_ = value->kernel_size()->Get(1);
-  param->stride_h_ = value->stride()->Get(0);
-  param->stride_w_ = value->stride()->Get(1);
-  param->dilation_h_ = value->dilation()->Get(0);
-  param->dilation_w_ = value->dilation()->Get(1);
-  param->pad_u_ = value->pad_list()->Get(0);
-  param->pad_d_ = value->pad_list()->Get(1);
-  param->pad_l_ = value->pad_list()->Get(2);
-  param->pad_r_ = value->pad_list()->Get(3);
+  SetConvParam(param, value->kernel_size(), value->stride(), value->dilation(), value->pad_list());
   param->group_ = value->group();
   param->act_type_ = ActType_No;
   switch (value->activation_type()) {
@@ -339,17 +251,6 @@ OpParameter *PopulatePowerGradParameter(const void *prim) {
   return reinterpret_cast<OpParameter *>(power_param);
 }
 
-OpParameter *PopulateBiasGradParameter(const void *prim) {
-  ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
-  if (arithmetic_param == nullptr) {
-    MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
-    return nullptr;
-  }
-  auto primitive = static_cast<const schema::Primitive *>(prim);
-  arithmetic_param->op_parameter_.type_ = primitive->value_type();
-  return reinterpret_cast<OpParameter *>(arithmetic_param);
-}
-
 OpParameter *PopulateBNGradParameter(const void *prim) {
   BNGradParameter *bnGrad_param = reinterpret_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter)));
   if (bnGrad_param == nullptr) {
@@ -430,32 +331,9 @@ OpParameter *PopulateResizeGradParameter(const void *prim) {
   return reinterpret_cast<OpParameter *>(resize_grad_param);
 }
 
-OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
-  StridedSliceParameter *strided_slice_param =
-    reinterpret_cast<StridedSliceParameter *>(malloc(sizeof(StridedSliceParameter)));
-  if (strided_slice_param == nullptr) {
-    MS_LOG(ERROR) << "malloc StridedSliceParameter failed.";
-    return nullptr;
-  }
-  memset(strided_slice_param, 0, sizeof(StridedSliceParameter));
-
-  auto primitive = static_cast<const schema::Primitive *>(prim);
-  auto value = primitive->value_as_StridedSliceGrad();
-  strided_slice_param->op_parameter_.type_ = primitive->value_type();
-
-  strided_slice_param->begins_mask_ = value->begin_mask();
-  strided_slice_param->ends_mask_ = value->end_mask();
-  strided_slice_param->ellipsisMask_ = value->ellipsis_mask();
-  strided_slice_param->newAxisMask_ = value->new_axis_mask();
-  strided_slice_param->shrinkAxisMask_ = value->shrink_axis_mask();
-  return reinterpret_cast<OpParameter *>(strided_slice_param);
-}
-
 void PopulateTrainParameters() {
   lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
                                                 lite::SCHEMA_CUR);
-  lite::Registry BiasGradParameterRegistry(schema::PrimitiveType_BiasAddGrad, PopulateBiasGradParameter,
-                                           lite::SCHEMA_CUR);
   lite::Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits,
                                                       PopulateSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
   lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
@@ -469,10 +347,6 @@ void PopulateTrainParameters() {
                                                    PopulateConvolutionGradFilterParameter, lite::SCHEMA_CUR);
   lite::Registry Conv2DGradInputParameterRegistry(schema::PrimitiveType_Conv2DBackpropInputFusion,
                                                   PopulateConvolutionGradInputParameter, lite::SCHEMA_CUR);
-  lite::Registry avgPoolParameterRegistry(schema::PrimitiveType_AvgPoolGrad, PopulateAvgPoolGradParameter,
-                                          lite::SCHEMA_CUR);
-  lite::Registry maxPoolParameterRegistry(schema::PrimitiveType_MaxPoolGrad, PopulateMaxPoolGradParameter,
-                                          lite::SCHEMA_CUR);
   lite::Registry PowerGradParameterRegistry(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter,
                                             lite::SCHEMA_CUR);
   lite::Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
@@ -508,8 +382,6 @@ void PopulateTrainParameters() {
                                                            lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
   lite::Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, lite::DefaultPopulateParameter,
                                               lite::SCHEMA_CUR);
-  lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad,
-                                                   PopulateStridedSliceGradParameter, lite::SCHEMA_CUR);
   lite::Registry SqrtGradParameterRegistry(schema::PrimitiveType_SqrtGrad, lite::DefaultPopulateParameter,
                                            lite::SCHEMA_CUR);
   lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
diff --git a/mindspore/lite/src/train/train_populate_parameter_v0.cc b/mindspore/lite/src/train/train_populate_parameter_v0.cc
index 44b19b1..8f52614 100644
--- a/mindspore/lite/src/train/train_populate_parameter_v0.cc
+++ b/mindspore/lite/src/train/train_populate_parameter_v0.cc
@@ -585,7 +585,6 @@ OpParameter *PopulateArithmeticGradParameter(const void *primitive) {
   }
   return reinterpret_cast<OpParameter *>(arithmetic_param);
 }
-
 }  // namespace
 
 void PopulateTrainV0Parameters() {
@@ -658,5 +657,4 @@ void PopulateTrainV0Parameters() {
   lite::Registry g_sigmoidCrossEntropyWithLogitsGradRegistry(
     schema::v0::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad, DefaultPopulateParameter, mindspore::lite::SCHEMA_V0);
 }
-
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index 8dd5fc5..ead56e7 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -38,7 +38,6 @@
 
 namespace mindspore {
 namespace lite {
-
 std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size) {
   std::ifstream ifs(filename);
   if (!ifs.good()) {
@@ -222,7 +221,7 @@ int TrainSession::SaveToFile(const std::string &filename) const {
     return lite::RET_NULL_PTR;
   }
   std::ofstream ofs(filename);
-  if ((true != ofs.good()) || (true != ofs.is_open())) {
+  if (!ofs.good() || !ofs.is_open()) {
     MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing";
     return RET_ERROR;
   }
@@ -276,31 +275,34 @@ void TrainSession::CompileEvalOutputs() {
   eval_output_tensor_map_.clear();
   eval_output_tensor_names_.clear();
   for (auto kernel : this->train_kernels_) {
-    if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) {
-      for (auto in_kernel : kernel->in_kernels()) {
-        if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue;
-        // insert if not already in
-        if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) {
-          auto *ms_tensor = in_kernel->out_tensors().at(0);
-          if (ms_tensor != nullptr) {
-            eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor);
-            auto index = TSFindTensor(tensors_, ms_tensor);
-            if (index != tensors_.size()) {
-              eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
-              if (!ms_tensor->tensor_name().empty()) {
-                eval_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
-              } else {
-                eval_output_tensor_names_.emplace_back(std::to_string(index));
-              }
-            }
-          }
+    if (!IsLossKernel(kernel) || IsGradKernel(kernel)) {
+      continue;
+    }
+    for (auto in_kernel : kernel->in_kernels()) {
+      if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue;
+      // insert if not already in
+      if (eval_output_node_map_.find(in_kernel->name()) != eval_output_node_map_.end()) {
+        continue;
+      }
+      auto *ms_tensor = in_kernel->out_tensors().at(0);
+      if (ms_tensor != nullptr) {
+        eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor);
+        auto index = TSFindTensor(tensors_, ms_tensor);
+        if (index == tensors_.size()) {
+          continue;
+        }
+        eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
+        if (!ms_tensor->tensor_name().empty()) {
+          eval_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
+        } else {
+          eval_output_tensor_names_.emplace_back(std::to_string(index));
         }
       }
     }
   }
-  if (eval_output_node_map_.size() == 0) eval_output_node_map_ = orig_output_node_map_;
-  if (eval_output_tensor_map_.size() == 0) eval_output_tensor_map_ = orig_output_tensor_map_;
-  if (eval_output_tensor_names_.size() == 0) eval_output_tensor_names_ = orig_output_tensor_names_;
+  if (eval_output_node_map_.empty()) eval_output_node_map_ = orig_output_node_map_;
+  if (eval_output_tensor_map_.empty()) eval_output_tensor_map_ = orig_output_tensor_map_;
+  if (eval_output_tensor_names_.empty()) eval_output_tensor_names_ = orig_output_tensor_names_;
 }
 
 void TrainSession::CompileTrainOutputs() {
@@ -328,9 +330,9 @@ void TrainSession::CompileTrainOutputs() {
       }
     }
   }
-  if (train_output_node_map_.size() == 0) train_output_node_map_ = orig_output_node_map_;
-  if (train_output_tensor_map_.size() == 0) train_output_tensor_map_ = orig_output_tensor_map_;
-  if (train_output_tensor_names_.size() == 0) train_output_tensor_names_ = orig_output_tensor_names_;
+  if (train_output_node_map_.empty()) train_output_node_map_ = orig_output_node_map_;
+  if (train_output_tensor_map_.empty()) train_output_tensor_map_ = orig_output_tensor_map_;
+  if (train_output_tensor_names_.empty()) train_output_tensor_names_ = orig_output_tensor_names_;
 }
 
 void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) {
@@ -363,7 +365,7 @@ void TrainSession::CompileInferenceKernels() {
     auto kernel = TSFindKernel(train_kernels_, kernel_name);
     BuildInferenceKernelsRecursive(kernel, &inference_kernels_);
   }
-  if (inference_kernels_.size() == 0) {
+  if (inference_kernels_.empty()) {
     inference_kernels_ = this->train_kernels_;
   }
 }
@@ -574,5 +576,4 @@ session::TrainSession *session::TrainSession::CreateSession(const std::string &f
   }
   return session::TrainSession::CreateSession(buf.get(), size, context, train_mode);
 }
-
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/train_utils.cc b/mindspore/lite/src/train/train_utils.cc
index bf3ac8a..f0cf23f 100644
--- a/mindspore/lite/src/train/train_utils.cc
+++ b/mindspore/lite/src/train/train_utils.cc
@@ -22,9 +22,12 @@
 
 namespace mindspore {
 namespace lite {
-
+namespace {
+constexpr int kMatrixDims = 2;
+}  //  namespace
 float CalculateSparseClassification(tensor::MSTensor *input, tensor::MSTensor *output) {
-  if ((input->shape().size() != 1) || (input->data_type() != kNumberTypeInt32) || (output->shape().size() != 2)) {
+  if ((input->shape().size() != 1) || (input->data_type() != kNumberTypeInt32) ||
+      (output->shape().size() != kMatrixDims)) {
     MS_LOG(WARNING) << "SparceClassification got a " << input->shape() << "-D input tensor, " << output->shape()
                     << "-D output tensor";
     return 0.0;
@@ -50,7 +53,7 @@ float CalculateSparseClassification(tensor::MSTensor *input, tensor::MSTensor *o
 }
 
 float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *output) {
-  if ((input->shape().size() != 2) || (output->shape().size() != 2)) {
+  if ((input->shape().size() != kMatrixDims) || (output->shape().size() != kMatrixDims)) {
     MS_LOG(WARNING) << "OneHotClassification got a " << input->shape() << "-D input tensor, " << output->shape()
                     << "-D output tensor";
     return 0.0;
@@ -76,10 +79,11 @@ float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *o
         label = c;
       }
     }
-    if (label == max_idx) accuracy += 1.0;
+    if (label == max_idx) {
+      accuracy += 1.0;
+    }
   }
   return accuracy / (static_cast<float>(batch_size));
 }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc
index 7be49ca..54e02ea 100644
--- a/mindspore/lite/src/train/transfer_session.cc
+++ b/mindspore/lite/src/train/transfer_session.cc
@@ -26,18 +26,16 @@
 #include "src/common/utils.h"
 #include "src/tensor.h"
 #include "src/train/loss_kernel.h"
-#include "src/train/optimizer_kernel.h"
-#include "src/sub_graph_kernel.h"
-#include "src/train/train_populate_parameter.h"
-#include "src/runtime/runtime_api.h"
 #include "src/executor.h"
-#include "src/kernel_registry.h"
-#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
 #include "nnacl/fp32/pack_fp32.h"
 
 namespace mindspore {
 namespace lite {
-
+namespace {
+constexpr int kNHWCHDim = 2;
+constexpr int kNHWCCDim = 3;
+constexpr int kNHWCDims = 4;
+}  //  namespace
 TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, lite::Context *context)
     : is_valid_(false) {
   lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
@@ -84,23 +82,23 @@ int TransferSession::CompileTransferGraph() {
             break;
           }
         }
-        if (match == false && input->shape().size() == 4) {
+        if (!match && input->shape().size() == kNHWCDims) {
           int nchw2nhwc_mask[4] = {0, 3, 1, 2};
-          nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, 4);
+          nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, kNHWCDims);
           match = nchw2nhwc_;
         }
-        if (true == match) {
+        if (match) {
           break;
         }
       }
     }
-    if (true == match) {
+    if (match) {
       backbone_head_map_.push_back(std::make_pair(input, output));
     } else {
       combined_inputs_.push_back(input);
     }
   }
-  if (0 == backbone_head_map_.size()) {
+  if (backbone_head_map_.empty()) {
     ret = RET_ERROR;
   }
   return ret;
@@ -110,7 +108,7 @@ mindspore::tensor::MSTensor *TransferSession::GetInputsByTensorName(const std::s
   /* First look in backbone netwok */
   auto ret = backbone_session_->GetInputsByTensorName(tensor_name);
   /* If not found look in head network */
-  if (nullptr == ret) {
+  if (ret == nullptr) {
     ret = TrainSession::GetInputsByTensorName(tensor_name);
   }
   return ret;
@@ -142,9 +140,9 @@ int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack
     char *input_data = reinterpret_cast<char *>(input->MutableData());
     char *output_data = reinterpret_cast<char *>(output->MutableData());
     if (nchw2nhwc_) {
-      int plane = input->shape().at(1) * input->shape().at(2);
+      int plane = input->shape().at(1) * input->shape().at(kNHWCHDim);
       int batch = input->shape().at(0);
-      int channel = input->shape().at(3);
+      int channel = input->shape().at(kNHWCCDim);
       PackNCHWToNHWCFp32(output_data, input_data, batch, plane, channel, 0, 1);
     } else {
       std::copy(output_data, output_data + output->Size(), input_data);
@@ -153,7 +151,6 @@ int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack
   ret = lite::TrainSession::RunGraph(before, after);
   return ret;
 }
-
 }  // namespace lite
 
 session::TrainSession *session::TrainSession::CreateTransferSession(const char *model_buf_backbone,
@@ -239,5 +236,4 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const std::s
   return session::TrainSession::CreateTransferSession(buf_backbone.get(), size_backbone, buf_head.get(), size_head,
                                                       context, train_mode);
 }
-
 }  // namespace mindspore
