<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit → Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
diff --git a/mindspore/lite/examples/export_models/models/densenet_train_export.py b/mindspore/lite/examples/export_models/models/densenet_train_export.py
index 20bd76f..df785e8 100644
--- a/mindspore/lite/examples/export_models/models/densenet_train_export.py
+++ b/mindspore/lite/examples/export_models/models/densenet_train_export.py
@@ -17,14 +17,14 @@
 import sys
 import os
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
 from mindspore.train.serialization import export
 
 sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/')
 #pylint: disable=wrong-import-position
 from official.cv.densenet121.src.network.densenet import DenseNet121
 
 
 
@@ -35,7 +35,7 @@ n = DenseNet121(num_classes=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.SGD(n.trainable_params(), learning_rate=0.001, momentum=0.9, dampening=0.0, weight_decay=0.0,
                    nesterov=True, loss_scale=0.9)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 batch = 2
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
@@ -43,4 +43,4 @@ label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/densenet_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "densenet", x, label, n, net)
+    save_inout(sys.argv[1] + "densenet", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/effnet.py b/mindspore/lite/examples/export_models/models/effnet.py
index 4971757..fc49872 100755
--- a/mindspore/lite/examples/export_models/models/effnet.py
+++ b/mindspore/lite/examples/export_models/models/effnet.py
@@ -20,6 +20,7 @@ from mindspore.ops import operations as P
 from mindspore.common.initializer import TruncatedNormal
 from mindspore import Tensor
 
+
 def weight_variable():
     """weight initial"""
     return TruncatedNormal(0.02)
@@ -40,6 +41,7 @@ def _make_value_divisible(value, factor, min_value=None):
         new_value += factor
     return new_value
 
+
 class Swish(nn.Cell):
     def __init__(self):
         super().__init__()
@@ -58,7 +60,8 @@ class AdaptiveAvgPool(nn.Cell):
         self.output_size = output_size
 
     def construct(self, x):
-        return self.mean(x, (2, 3)) ## This is not a general case
+        return self.mean(x, (2, 3)) # This is not a general case
+
 
 class SELayer(nn.Cell):
     """SELayer"""
@@ -74,24 +77,27 @@ class SELayer(nn.Cell):
         self.act2 = nn.Sigmoid()
 
     def construct(self, x):
-        o = self.avg_pool(x) #.view(b,c)
+        o = self.avg_pool(x) # .view(b,c)
         o = self.conv_reduce(o)
         o = self.act1(o)
         o = self.conv_expand(o)
-        o = self.act2(o) #.view(b, c, 1,1)
+        o = self.act2(o) # .view(b, c, 1,1)
         return x * o
 
+
 class DepthwiseSeparableConv(nn.Cell):
     """DepthwiseSeparableConv"""
     def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, noskip=False, se_ratio=0.0, drop_connect_rate=0.0):
         super().__init__()
-        assert stride in [1, 2]
+        if stride not in [1, 2]:
+            raise ValueError("DepthwiseSeparableConv: stride must be 1 or 2, "
+                             "got {}".format(stride))
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.drop_connect_rate = drop_connect_rate
 
         self.conv_dw = nn.Conv2d(in_channels=in_chs, out_channels=in_chs, kernel_size=dw_kernel_size, stride=stride,
                                  pad_mode="pad", padding=1, has_bias=False, group=in_chs)
-        self.bn1 = nn.BatchNorm2d(in_chs, eps=0.001) #,momentum=0.1)
+        self.bn1 = nn.BatchNorm2d(in_chs, eps=0.001) # momentum=0.1)
         self.act1 = Swish()
 
        # Squeeze-and-excitation
@@ -101,7 +107,7 @@ class DepthwiseSeparableConv(nn.Cell):
             print("ERRRRRORRRR -- not prepared for this one\n")
 
         self.conv_pw = nn.Conv2d(in_channels=in_chs, out_channels=out_chs, kernel_size=1, stride=stride, has_bias=False)
-        self.bn2 = nn.BatchNorm2d(out_chs, eps=0.001) #,momentum=0.1)
+        self.bn2 = nn.BatchNorm2d(out_chs, eps=0.001) # momentum=0.1)
 
     def construct(self, x):
         """construct"""
@@ -120,12 +126,13 @@ class DepthwiseSeparableConv(nn.Cell):
             x += residual
         return x
 
+
 def conv_3x3_bn(inp, oup, stride):
     weight = weight_variable()
     return nn.SequentialCell([
         nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride, padding=1, weight_init=weight,
                   has_bias=False, pad_mode='pad'),
-        nn.BatchNorm2d(oup, eps=0.001),  #, momentum=0.1),
+        nn.BatchNorm2d(oup, eps=0.001),  # momentum=0.1),
         nn.HSwish()])
 
 
@@ -142,7 +149,9 @@ class InvertedResidual(nn.Cell):
     """InvertedResidual"""
     def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio):
         super().__init__()
-        assert stride in [1, 2]
+        if stride not in [1, 2]:
+            raise ValueError("InvertedResidual: stride must be 1 or 2, "
+                             "got {}".format(stride))
         mid_chs: int = _make_value_divisible(in_chs * expansion, 1)
         self.has_residual = (in_chs == out_chs and stride == 1)
         self.drop_connect_rate = 0
@@ -210,7 +219,7 @@ class EfficientNet(nn.Cell):
 
         self.conv_stem = nn.Conv2d(in_channels=3, out_channels=stem_size, kernel_size=3, stride=2, has_bias=False)
 
-        self.bn1 = nn.BatchNorm2d(stem_size, eps=0.001) #momentum=0.1)
+        self.bn1 = nn.BatchNorm2d(stem_size, eps=0.001) # momentum=0.1)
         self.act1 = Swish()
         in_chs = stem_size
 
@@ -240,7 +249,7 @@ class EfficientNet(nn.Cell):
         self.blocks = nn.SequentialCell(layers)
 
         self.conv_head = nn.Conv2d(in_channels=320, out_channels=self.num_features_, kernel_size=1)
-        self.bn2 = nn.BatchNorm2d(self.num_features_, eps=0.001) #,momentum=0.1)
+        self.bn2 = nn.BatchNorm2d(self.num_features_, eps=0.001) # momentum=0.1)
         self.act2 = Swish()
         self.global_pool = AdaptiveAvgPool(output_size=(1, 1))
         self.classifier = nn.Dense(self.num_features_, num_classes)
diff --git a/mindspore/lite/examples/export_models/models/effnet_train_export.py b/mindspore/lite/examples/export_models/models/effnet_train_export.py
index bf341f2..3384cc2 100644
--- a/mindspore/lite/examples/export_models/models/effnet_train_export.py
+++ b/mindspore/lite/examples/export_models/models/effnet_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from effnet import effnet
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -28,11 +28,11 @@ n = effnet(num_classes=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=0.0,
                    nesterov=True, loss_scale=1.0)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 x = Tensor(np.random.randn(2, 3, 224, 224), mstype.float32)
 label = Tensor(np.zeros([2, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/effnet_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "effnet", x, label, n, net)
+    save_inout(sys.argv[1] + "effnet", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/effnet_tune_train_export.py b/mindspore/lite/examples/export_models/models/effnet_tune_train_export.py
index 2b21ee8..3e61b44 100644
--- a/mindspore/lite/examples/export_models/models/effnet_tune_train_export.py
+++ b/mindspore/lite/examples/export_models/models/effnet_tune_train_export.py
@@ -17,7 +17,7 @@
 import sys
 from os import path
 import numpy as np
-from train_utils import TrainWrap, SaveT
+from train_utils import train_wrap, save_t
 from effnet import effnet
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -26,11 +26,13 @@ from mindspore.common.parameter import ParameterTuple
 
 context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
 
+
 class TransferNet(nn.Cell):
     def __init__(self, backbone, head):
         super().__init__(TransferNet)
         self.backbone = backbone
         self.head = head
+
     def construct(self, x):
         x = self.backbone(x)
         x = self.head(x)
@@ -56,7 +58,7 @@ trainable_weights_list.extend(n.head.trainable_params())
 trainable_weights = ParameterTuple(trainable_weights_list)
 sgd = nn.SGD(trainable_weights, learning_rate=0.01, momentum=0.9,
              dampening=0.01, weight_decay=0.0, nesterov=False, loss_scale=1.0)
-net = TrainWrap(n, optimizer=sgd, weights=trainable_weights)
+net = train_wrap(n, optimizer=sgd, weights=trainable_weights)
 
 BATCH_SIZE = 8
 X = Tensor(np.random.randn(BATCH_SIZE, 3, 224, 224), mstype.float32)
@@ -66,10 +68,10 @@ export(net, X, label, file_name="mindir/effnet_tune_train", file_format='MINDIR'
 if len(sys.argv) > 1:
     name_prefix = sys.argv[1] + "effnet_tune"
     x_name = name_prefix + "_input1.bin"
-    SaveT(Tensor(X.asnumpy().transpose(0, 2, 3, 1)), x_name)
+    save_t(Tensor(X.asnumpy().transpose(0, 2, 3, 1)), x_name)
 
     l_name = name_prefix + "_input2.bin"
-    SaveT(label, l_name)
+    save_t(label, l_name)
 
     #train network
     n.head.set_train(True)
@@ -80,4 +82,4 @@ if len(sys.argv) > 1:
     n.set_train(False)
     y = n(X)
     y_name = name_prefix + "_output1.bin"
-    SaveT(y, y_name)
+    save_t(y, y_name)
diff --git a/mindspore/lite/examples/export_models/models/googlenet_train_export.py b/mindspore/lite/examples/export_models/models/googlenet_train_export.py
index c2ddcc2..91a0062 100644
--- a/mindspore/lite/examples/export_models/models/googlenet_train_export.py
+++ b/mindspore/lite/examples/export_models/models/googlenet_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.googlenet.src.googlenet import GoogleNet
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -28,7 +28,7 @@ n = GoogleNet(num_classes=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=5e-4,
                    nesterov=True, loss_scale=0.9)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 batch = 2
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
@@ -36,4 +36,4 @@ label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/googlenet_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "googlenet", x, label, n, net)
+    save_inout(sys.argv[1] + "googlenet", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/lenet_train_export.py b/mindspore/lite/examples/export_models/models/lenet_train_export.py
index 1b7dfda..4e03aab 100644
--- a/mindspore/lite/examples/export_models/models/lenet_train_export.py
+++ b/mindspore/lite/examples/export_models/models/lenet_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.lenet.src.lenet import LeNet5
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -28,11 +28,11 @@ n = LeNet5()
 loss_fn = nn.MSELoss()
 optimizer = nn.Adam(n.trainable_params(), learning_rate=1e-2, beta1=0.5, beta2=0.7, eps=1e-2, use_locking=True,
                     use_nesterov=False, weight_decay=0.0, loss_scale=0.3)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 x = Tensor(np.random.randn(32, 1, 32, 32), mstype.float32)
 label = Tensor(np.zeros([32, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/lenet_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "lenet", x, label, n, net, sparse=False)
+    save_inout(sys.argv[1] + "lenet", x, label, n, net, sparse=False)
diff --git a/mindspore/lite/examples/export_models/models/mini_alexnet.py b/mindspore/lite/examples/export_models/models/mini_alexnet.py
index 9a8b828..e6008fa 100644
--- a/mindspore/lite/examples/export_models/models/mini_alexnet.py
+++ b/mindspore/lite/examples/export_models/models/mini_alexnet.py
@@ -17,13 +17,16 @@
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 
+
 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid", has_bias=True):
     return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                      has_bias=has_bias, pad_mode=pad_mode)
 
+
 def fc_with_initialize(input_channels, out_channels, has_bias=True):
     return nn.Dense(input_channels, out_channels, has_bias=has_bias)
 
+
 class AlexNet(nn.Cell):
     """
     Alexnet
diff --git a/mindspore/lite/examples/export_models/models/mini_alexnet_train_export.py b/mindspore/lite/examples/export_models/models/mini_alexnet_train_export.py
index 544daad..1b9a82d 100644
--- a/mindspore/lite/examples/export_models/models/mini_alexnet_train_export.py
+++ b/mindspore/lite/examples/export_models/models/mini_alexnet_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from mini_alexnet import AlexNet
 from mindspore import context, Tensor, nn
 from mindspore.train.serialization import export
@@ -31,11 +31,11 @@ n = AlexNet(phase='test')
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.Adam(n.trainable_params(), learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                     use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 x = Tensor(np.ones([batch, 1, 32, 32]).astype(np.float32) * 0.01)
 label = Tensor(np.zeros([batch, number_of_classes]).astype(np.float32))
 export(net, x, label, file_name="mindir/mini_alexnet_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "mini_alexnet", x, label, n, net, sparse=False)
+    save_inout(sys.argv[1] + "mini_alexnet", x, label, n, net, sparse=False)
diff --git a/mindspore/lite/examples/export_models/models/mobilenetv1_train_export.py b/mindspore/lite/examples/export_models/models/mobilenetv1_train_export.py
index f668a96..b3b26de 100644
--- a/mindspore/lite/examples/export_models/models/mobilenetv1_train_export.py
+++ b/mindspore/lite/examples/export_models/models/mobilenetv1_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.mobilenetv1.src.mobilenet_v1 import MobileNetV1
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -28,7 +28,7 @@ n = MobileNetV1(10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.SGD(n.trainable_params(), learning_rate=1e-2, momentum=0.9, dampening=0.1, weight_decay=0.0,
                    nesterov=False, loss_scale=1.0)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 batch = 2
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
@@ -37,4 +37,4 @@ label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/mobilenetv1_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "mobilenetv1", x, label, n, net)
+    save_inout(sys.argv[1] + "mobilenetv1", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/mobilenetv2_train_export.py b/mindspore/lite/examples/export_models/models/mobilenetv2_train_export.py
index 8f1d543..0433063 100644
--- a/mindspore/lite/examples/export_models/models/mobilenetv2_train_export.py
+++ b/mindspore/lite/examples/export_models/models/mobilenetv2_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.mobilenetv2.src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -31,11 +31,11 @@ n = mobilenet_v2(backbone_net, head_net)
 
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.Momentum(n.trainable_params(), 0.01, 0.9, use_nesterov=False)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
 label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/mobilenetv2_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "mobilenetv2", x, label, n, net, sparse=False)
+    save_inout(sys.argv[1] + "mobilenetv2", x, label, n, net, sparse=False)
diff --git a/mindspore/lite/examples/export_models/models/mobilenetv3_train_export.py b/mindspore/lite/examples/export_models/models/mobilenetv3_train_export.py
index 29d618d..f6743b6 100644
--- a/mindspore/lite/examples/export_models/models/mobilenetv3_train_export.py
+++ b/mindspore/lite/examples/export_models/models/mobilenetv3_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.mobilenetv3.src.mobilenetV3 import mobilenet_v3_small
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -28,7 +28,7 @@ n = mobilenet_v3_small(num_classes=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')
 optimizer = nn.Adam(n.trainable_params(), learning_rate=1e-2, beta1=0.5, beta2=0.7, eps=1e-2, use_locking=True,
                     use_nesterov=False, weight_decay=0.1, loss_scale=0.3)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 batch = 2
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
@@ -36,4 +36,4 @@ label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/mobilenetv3_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "mobilenetv3", x, label, n, net, sparse=False)
+    save_inout(sys.argv[1] + "mobilenetv3", x, label, n, net, sparse=False)
diff --git a/mindspore/lite/examples/export_models/models/nin_train_export.py b/mindspore/lite/examples/export_models/models/nin_train_export.py
index 72ccc5e..786f739 100644
--- a/mindspore/lite/examples/export_models/models/nin_train_export.py
+++ b/mindspore/lite/examples/export_models/models/nin_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from NetworkInNetwork import NiN
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -28,7 +28,7 @@ n = NiN(num_classes=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
 optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=5e-4,
                    nesterov=True, loss_scale=0.9)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 batch = 2
 x = Tensor(np.random.randn(batch, 3, 32, 32), mstype.float32)
@@ -36,4 +36,4 @@ label = Tensor(np.zeros([batch]).astype(np.int32))
 export(net, x, label, file_name="mindir/nin_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "nin", x, label, n, net)
+    save_inout(sys.argv[1] + "nin", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/resnet_train_export.py b/mindspore/lite/examples/export_models/models/resnet_train_export.py
index c0dbe90..c18bcf3 100644
--- a/mindspore/lite/examples/export_models/models/resnet_train_export.py
+++ b/mindspore/lite/examples/export_models/models/resnet_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.resnet.src.resnet import resnet50
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -29,11 +29,11 @@ n = resnet50(class_num=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=0.0,
                    nesterov=True, loss_scale=1.0)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
 label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/resnet_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "resnet", x, label, n, net)
+    save_inout(sys.argv[1] + "resnet", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/shufflenetv2_train_export.py b/mindspore/lite/examples/export_models/models/shufflenetv2_train_export.py
index 97aa4ec..bf76d48 100644
--- a/mindspore/lite/examples/export_models/models/shufflenetv2_train_export.py
+++ b/mindspore/lite/examples/export_models/models/shufflenetv2_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.shufflenetv2.src.shufflenetv2 import ShuffleNetV2
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -28,7 +28,7 @@ n = ShuffleNetV2(n_class=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.Momentum(n.trainable_params(), 0.01, 0.9, use_nesterov=False)
 
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 batch = 2
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
@@ -36,4 +36,4 @@ label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/shufflenetv2_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "shufflenetv2", x, label, n, net)
+    save_inout(sys.argv[1] + "shufflenetv2", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/train_utils.py b/mindspore/lite/examples/export_models/models/train_utils.py
index e32fda1..2fa6690 100644
--- a/mindspore/lite/examples/export_models/models/train_utils.py
+++ b/mindspore/lite/examples/export_models/models/train_utils.py
@@ -16,9 +16,11 @@
 
+import os
 from mindspore import nn, Tensor
 from mindspore.common.parameter import ParameterTuple
 
-def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
-    """TrainWrap"""
+
+def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
+    """train_wrap"""
     if loss_fn is None:
         loss_fn = nn.SoftmaxCrossEntropyWithLogits()
     loss_net = nn.WithLossCell(net, loss_fn)
@@ -32,22 +34,22 @@ def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
     return train_net
 
 
-def SaveT(t, file):
+def save_t(t, file):
     x = t.asnumpy()
     x.tofile(file)
 
 
-def SaveInOut(name, x, l, net, net_train, sparse=False, epoch=1):
-    """SaveInOut"""
+def save_inout(name, x, l, net, net_train, sparse=False, epoch=1):
+    """save_inout"""
     x_name = name + "_input1.bin"
     if sparse:
         x_name = name + "_input2.bin"
-    SaveT(Tensor(x.asnumpy().transpose(0, 2, 3, 1)), x_name)
+    save_t(Tensor(x.asnumpy().transpose(0, 2, 3, 1)), x_name)
 
     l_name = name + "_input2.bin"
     if sparse:
         l_name = name + "_input1.bin"
-    SaveT(l, l_name)
+    save_t(l, l_name)
 
     net.set_train(False)
     y = net(x)
@@ -62,10 +64,12 @@ def SaveInOut(name, x, l, net, net_train, sparse=False, epoch=1):
     if isinstance(y, tuple):
         i = 1
         for t in y:
-            with open(name + "_output" + str(i) + ".bin", 'w') as f:
+            # os.fdopen() takes a file descriptor, so create the file with os.open() first.
+            out_fd = os.open(name + "_output" + str(i) + ".bin", os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
+            with os.fdopen(out_fd, 'w') as f:
                 for j in t.asnumpy().flatten():
                     f.write(str(j)+' ')
             i = i + 1
     else:
         y_name = name + "_output1.bin"
-        SaveT(y, y_name)
+        save_t(y, y_name)
diff --git a/mindspore/lite/examples/export_models/models/vgg_train_export.py b/mindspore/lite/examples/export_models/models/vgg_train_export.py
index 0078252..c18b33c 100644
--- a/mindspore/lite/examples/export_models/models/vgg_train_export.py
+++ b/mindspore/lite/examples/export_models/models/vgg_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.vgg16.src.vgg import vgg16
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -29,11 +29,11 @@ batch = 2
 n = vgg16(num_classes=10)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.Momentum(n.trainable_params(), 0.01, 0.9, use_nesterov=False)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
 label = Tensor(np.zeros([batch, 10]).astype(np.float32))
 export(net, x, label, file_name="mindir/vgg_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "vgg", x, label, n, net)
+    save_inout(sys.argv[1] + "vgg", x, label, n, net)
diff --git a/mindspore/lite/examples/export_models/models/xception_train_export.py b/mindspore/lite/examples/export_models/models/xception_train_export.py
index 6b82b3b..e544d7e 100644
--- a/mindspore/lite/examples/export_models/models/xception_train_export.py
+++ b/mindspore/lite/examples/export_models/models/xception_train_export.py
@@ -16,7 +16,7 @@
 
 import sys
 import numpy as np
-from train_utils import SaveInOut, TrainWrap
+from train_utils import save_inout, train_wrap
 from official.cv.xception.src.Xception import Xception
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
@@ -31,7 +31,7 @@ n.dropout = nn.Dropout(keep_prob=1.0)
 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=0.0,
                    nesterov=True, loss_scale=1.0)
-net = TrainWrap(n, loss_fn, optimizer)
+net = train_wrap(n, loss_fn, optimizer)
 
 batch = 2
 x = Tensor(np.random.randn(batch, 3, 299, 299), mstype.float32)
@@ -39,4 +39,4 @@ label = Tensor(np.zeros([batch, 1000]).astype(np.float32))
 export(net, x, label, file_name="mindir/xception_train", file_format='MINDIR')
 
 if len(sys.argv) > 1:
-    SaveInOut(sys.argv[1] + "xception", x, label, n, net)
+    save_inout(sys.argv[1] + "xception", x, label, n, net)
diff --git a/mindspore/lite/examples/train_lenet/model/lenet_export.py b/mindspore/lite/examples/train_lenet/model/lenet_export.py
index 8a9cd7c..c774887 100644
--- a/mindspore/lite/examples/train_lenet/model/lenet_export.py
+++ b/mindspore/lite/examples/train_lenet/model/lenet_export.py
@@ -19,7 +19,7 @@ from mindspore import context, Tensor
 import mindspore.common.dtype as mstype
 from mindspore.train.serialization import export
 from lenet import LeNet5
-from train_utils import TrainWrap
+from train_utils import train_wrap
 
 n = LeNet5()
 n.set_train()
@@ -28,7 +28,7 @@ context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs
 BATCH_SIZE = 32
 x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
 label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
-net = TrainWrap(n)
+net = train_wrap(n)
 export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
 
 print("finished exporting")
diff --git a/mindspore/lite/examples/train_lenet/model/train_utils.py b/mindspore/lite/examples/train_lenet/model/train_utils.py
index 9e8e3fa..9e3ad76 100644
--- a/mindspore/lite/examples/train_lenet/model/train_utils.py
+++ b/mindspore/lite/examples/train_lenet/model/train_utils.py
@@ -17,9 +17,10 @@
 import mindspore.nn as nn
 from mindspore.common.parameter import ParameterTuple
 
-def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
+
+def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
     """
-    TrainWrap
+    train_wrap
     """
     if loss_fn is None:
         loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)
diff --git a/mindspore/lite/examples/transfer_learning/model/effnet.py b/mindspore/lite/examples/transfer_learning/model/effnet.py
index 8ed066f..eba29b5 100755
--- a/mindspore/lite/examples/transfer_learning/model/effnet.py
+++ b/mindspore/lite/examples/transfer_learning/model/effnet.py
@@ -44,6 +44,7 @@ class Swish(nn.Cell):
         m = x*s
         return m
 
+
 class AdaptiveAvgPool(nn.Cell):
     def __init__(self, output_size=None):
         super().__init__(AdaptiveAvgPool)
@@ -53,6 +54,7 @@ class AdaptiveAvgPool(nn.Cell):
     def construct(self, x):
         return self.mean(x, (2, 3))
 
+
 class SELayer(nn.Cell):
     """
     SELayer
@@ -77,6 +79,7 @@ class SELayer(nn.Cell):
         o = self.act2(o)
         return x * o
 
+
 class DepthwiseSeparableConv(nn.Cell):
     """
     DepthwiseSeparableConv
@@ -84,7 +87,9 @@ class DepthwiseSeparableConv(nn.Cell):
     def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                  stride=1, noskip=False, se_ratio=0.0, drop_connect_rate=0.0):
         super().__init__(DepthwiseSeparableConv)
-        assert stride in [1, 2]
+        if stride not in [1, 2]:
+            raise ValueError("DepthwiseSeparableConv: stride must be 1 or 2, "
+                             "got {}".format(stride))
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.drop_connect_rate = drop_connect_rate
 
@@ -117,6 +122,7 @@ class DepthwiseSeparableConv(nn.Cell):
             x += residual
         return x
 
+
 def conv_3x3_bn(inp, oup, stride):
     weight = weight_variable()
     return nn.SequentialCell([
@@ -125,6 +131,7 @@ def conv_3x3_bn(inp, oup, stride):
         nn.BatchNorm2d(oup, eps=0.001),  # , momentum=0.1),
         nn.HSwish()])
 
+
 def conv_1x1_bn(inp, oup):
     weight = weight_variable()
     return nn.SequentialCell([
@@ -133,13 +140,16 @@ def conv_1x1_bn(inp, oup):
         nn.BatchNorm2d(oup, eps=0.001),
         nn.HSwish()])
 
+
 class InvertedResidual(nn.Cell):
     """
     InvertedResidual
     """
     def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio):
         super().__init__(InvertedResidual)
-        assert stride in [1, 2]
+        if stride not in [1, 2]:
+            raise ValueError("InvertedResidual: stride must be 1 or 2, "
+                             "got {}".format(stride))
         mid_chs: int = _make_divisible(in_chs * expansion, 1)
         self.has_residual = (in_chs == out_chs and stride == 1)
         self.drop_connect_rate = 0
@@ -194,6 +204,7 @@ class InvertedResidual(nn.Cell):
             x += residual
         return x
 
+
 class EfficientNet(nn.Cell):
     """
     EfficientNet
@@ -295,6 +306,7 @@ class EfficientNet(nn.Cell):
             elif isinstance(m, nn.Dense):
                 init_linear_weight(m)
 
+
 def effnet(**kwargs):
     """
     Constructs a EfficientNet model
diff --git a/mindspore/lite/include/train/accuracy_metrics.h b/mindspore/lite/include/train/accuracy_metrics.h
index e3822fd..9dfa451 100644
--- a/mindspore/lite/include/train/accuracy_metrics.h
+++ b/mindspore/lite/include/train/accuracy_metrics.h
@@ -41,6 +41,7 @@ class AccuracyMetrics : public Metrics {
   std::vector<int> output_indexes_ = {0};
   float total_accuracy_ = 0.0;
   float total_steps_ = 0.0;
+  friend class ClassificationTrainAccuracyMonitor;  // reads input_indexes_/output_indexes_ in its StepEnd
 };
 
 }  // namespace lite
diff --git a/mindspore/lite/include/train/classification_train_accuracy_monitor.h b/mindspore/lite/include/train/classification_train_accuracy_monitor.h
index 5c85592..0c461c5 100644
--- a/mindspore/lite/include/train/classification_train_accuracy_monitor.h
+++ b/mindspore/lite/include/train/classification_train_accuracy_monitor.h
@@ -44,9 +44,7 @@ class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
 
  private:
   std::vector<GraphPoint> accuracies_;
-  int accuracy_metrics_ = METRICS_CLASSIFICATION;
-  std::vector<int> input_indexes_ = {1};
-  std::vector<int> output_indexes_ = {0};
+  std::shared_ptr<AccuracyMetrics> accuracy_metrics_;  // built in the ctor from metric type + input/output indexes
   int print_every_n_ = 0;
 };
 
diff --git a/mindspore/lite/src/dequant.h b/mindspore/lite/src/dequant.h
index 919b388..1e554c9 100644
--- a/mindspore/lite/src/dequant.h
+++ b/mindspore/lite/src/dequant.h
@@ -40,19 +40,11 @@ class DequantUtil {
   static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
 
+  // If input_tensor is quantized per batch (rank kPerBatch, one quant param per batch), dequantizes
+  // quant_datas into dequant_datas and returns true; otherwise writes nothing and returns false.
   template <typename ST, typename DT = float>
-  static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
-    const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
-    if (quant_datas == nullptr) {
-      MS_LOG(ERROR) << "Get quant tensor failed.";
-      return nullptr;
-    }
-    DT *dequant_datas = static_cast<DT *>(malloc(input_tensor->ElementsNum() * sizeof(DT)));
-    if (dequant_datas == nullptr) {
-      MS_LOG(ERROR) << "Malloc failed.";
-      return nullptr;
-    }
+  static bool IsPerBatch(lite::Tensor *input_tensor, const ST *quant_datas, DT *dequant_datas) {
     if (input_tensor->shape().size() == kPerBatch &&
         input_tensor->quant_params().size() == static_cast<size_t>(input_tensor->shape().at(0))) {  // per batch matmul
       auto per_batch_size = input_tensor->shape().at(0);
       auto quant_param = input_tensor->quant_params();
       for (int i = 0; i < per_batch_size; i++) {
@@ -64,42 +56,71 @@ class DequantUtil {
           dequant_datas[i * matrix_size + j] = static_cast<DT>((quant_datas[i * matrix_size + j] - zero_point) * scale);
         }
       }
-    } else if (input_tensor->quant_params().size() != kPerTensor) {
-      auto channels = static_cast<size_t>(input_tensor->Batch());
-      if (!channel_first) {
-        if (input_tensor->shape().size() != 2) {
-          MS_LOG(ERROR) << "unexpected shape size: " << input_tensor->shape().size();
-          free(dequant_datas);
-          return nullptr;
-        }
-        channels = input_tensor->shape()[1];
-      }
-      if (input_tensor->quant_params().size() != channels) {
-        MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels;
+      return true;
+    }
+    return false;
+  }
+
+  // Per-channel dequantization: channel count is Batch(), or shape()[1] of a 2-D tensor when !channel_first.
+  // Applies the var_corr/mean_corr correction. On error frees dequant_datas and returns nullptr.
+  template <typename ST, typename DT = float>
+  static DT *DequantPerChannel(lite::Tensor *input_tensor, const ST *quant_datas, DT *dequant_datas,
+                               bool channel_first) {
+    auto channels = static_cast<size_t>(input_tensor->Batch());
+    if (!channel_first) {
+      if (input_tensor->shape().size() != 2) {
+        MS_LOG(ERROR) << "unexpected shape size: " << input_tensor->shape().size();
         free(dequant_datas);
         return nullptr;
       }
-      size_t per_channel_size = input_tensor->ElementsNum() / channels;
-      auto quant_param = input_tensor->quant_params();
-      for (size_t i = 0; i < channels; i++) {
-        auto param = quant_param.at(i);
-        auto scale = param.scale;
-        auto zero_point = param.zeroPoint;
-        auto var_corr = param.var_corr;
-        auto mean_corr = param.mean_corr;
-        if (var_corr < 0 || var_corr > 10) {
-          MS_LOG(WARNING) << "unexpected var_corr: " << var_corr;
-          var_corr = 1;
-        }
-        for (size_t j = 0; j < per_channel_size; j++) {
-          auto index = per_channel_size * i + j;
-          if (!channel_first) {
-            index = channels * j + i;
-          }
-          auto dequant_data = (quant_datas[index] - zero_point) * scale;
-          dequant_datas[index] = static_cast<DT>(dequant_data * var_corr + mean_corr);
+      channels = input_tensor->shape()[1];
+    }
+    if (input_tensor->quant_params().size() != channels) {
+      MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels;
+      free(dequant_datas);
+      return nullptr;
+    }
+    size_t per_channel_size = input_tensor->ElementsNum() / channels;
+    auto quant_param = input_tensor->quant_params();
+    for (size_t i = 0; i < channels; i++) {
+      auto param = quant_param.at(i);
+      auto scale = param.scale;
+      auto zero_point = param.zeroPoint;
+      auto var_corr = param.var_corr;
+      auto mean_corr = param.mean_corr;
+      if (var_corr < 0 || var_corr > 10) {
+        MS_LOG(WARNING) << "unexpected var_corr: " << var_corr;
+        var_corr = 1;
+      }
+      for (size_t j = 0; j < per_channel_size; j++) {
+        auto index = per_channel_size * i + j;
+        if (!channel_first) {
+          index = channels * j + i;
         }
+        auto dequant_data = (quant_datas[index] - zero_point) * scale;
+        dequant_datas[index] = static_cast<DT>(dequant_data * var_corr + mean_corr);
       }
+    }
+    return dequant_datas;
+  }
+
+  template <typename ST, typename DT = float>
+  static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
+    const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
+    if (quant_datas == nullptr) {
+      MS_LOG(ERROR) << "Get quant tensor failed.";
+      return nullptr;
+    }
+    DT *dequant_datas = static_cast<DT *>(malloc(input_tensor->ElementsNum() * sizeof(DT)));
+    if (dequant_datas == nullptr) {
+      MS_LOG(ERROR) << "Malloc failed.";
+      return nullptr;
+    }
+    if (IsPerBatch(input_tensor, quant_datas, dequant_datas)) {
+      // Already fully dequantized; falling through to the per-tensor branch would overwrite it.
+      return dequant_datas;
+    } else if (input_tensor->quant_params().size() != kPerTensor) {
+      return DequantPerChannel(input_tensor, quant_datas, dequant_datas, channel_first);
     } else {
       auto quant_param = input_tensor->quant_params();
       auto quant_clusters = input_tensor->quant_clusters();
diff --git a/mindspore/lite/src/train/classification_train_accuracy_monitor.cc b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc
index 5e54e37..8523863 100644
--- a/mindspore/lite/src/train/classification_train_accuracy_monitor.cc
+++ b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc
@@ -27,20 +27,11 @@ using mindspore::WARNING;
 
 namespace mindspore {
 namespace lite {
+
 ClassificationTrainAccuracyMonitor::ClassificationTrainAccuracyMonitor(int print_every_n, int accuracy_metrics,
                                                                        const std::vector<int> &input_indexes,
                                                                        const std::vector<int> &output_indexes) {
-  if (input_indexes.size() == output_indexes.size()) {
-    input_indexes_ = input_indexes;
-    output_indexes_ = output_indexes;
-  } else {
-    MS_LOG(WARNING) << "input to output mapping vectors sizes do not match";
-  }
-  if (accuracy_metrics != METRICS_CLASSIFICATION) {
-    MS_LOG(WARNING) << "Only classification metrics is supported";
-  } else {
-    accuracy_metrics_ = accuracy_metrics;
-  }
+  accuracy_metrics_ = std::make_shared<AccuracyMetrics>(accuracy_metrics, input_indexes, output_indexes);
   print_every_n_ = print_every_n;
 }
 
@@ -59,8 +50,8 @@ void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCall
 int ClassificationTrainAccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
   if (cb_data.step_ > 0) accuracies_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
-  if ((cb_data.epoch_ + 1) % print_every_n_ == 0) {
-    std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tTraining Accuracy is "
-              << accuracies_.at(cb_data.epoch_).second << std::endl;
+  if (print_every_n_ > 0 && (cb_data.epoch_ + 1) % print_every_n_ == 0) {  // '%' by 0 is undefined behavior
+    std::cout << "Epoch (" << cb_data.epoch_ + 1 << "):\tTraining Accuracy is " << accuracies_.at(cb_data.epoch_).second
+              << std::endl;
   }
   return mindspore::session::RET_CONTINUE;
 }
@@ -70,20 +61,23 @@ void ClassificationTrainAccuracyMonitor::StepEnd(const session::TrainLoopCallBac
   auto outputs = cb_data.session_->GetPredictions();
 
   float accuracy = 0.0;
-  for (unsigned int i = 0; i < input_indexes_.size(); i++) {
-    if ((inputs.size() <= static_cast<unsigned int>(input_indexes_[i])) ||
-        (outputs.size() <= static_cast<unsigned int>(output_indexes_[i]))) {
-      MS_LOG(WARNING) << "indices " << input_indexes_[i] << "/" << output_indexes_[i]
+  const auto &input_indexes = accuracy_metrics_->input_indexes_;  // by reference: no vector copy per step
+  const auto &output_indexes = accuracy_metrics_->output_indexes_;
+  for (unsigned int i = 0; i < input_indexes.size(); i++) {
+    if ((inputs.size() <= static_cast<unsigned int>(input_indexes[i])) ||
+        (outputs.size() <= static_cast<unsigned int>(output_indexes[i]))) {
+      MS_LOG(WARNING) << "indices " << input_indexes[i] << "/" << output_indexes[i]
                       << " is outside of input/output range";
       return;
     }
-    if (inputs.at(input_indexes_[i])->data_type() == kNumberTypeInt32) {
-      accuracy += CalculateSparseClassification(inputs.at(input_indexes_[i]), outputs.at(output_indexes_[i]));
+    if (inputs.at(input_indexes[i])->data_type() == kNumberTypeInt32) {
+      accuracy += CalculateSparseClassification(inputs.at(input_indexes[i]), outputs.at(output_indexes[i]));
     } else {
-      accuracy += CalculateOneHotClassification(inputs.at(input_indexes_[i]), outputs.at(output_indexes_[i]));
+      accuracy += CalculateOneHotClassification(inputs.at(input_indexes[i]), outputs.at(output_indexes[i]));
     }
   }
   accuracies_.at(cb_data.epoch_).second += accuracy;
 }
+
 }  // namespace lite
 }  // namespace mindspore
