# 载入数据

In [1]:
import sys
sys.path.append('E:/zlab/')
from gluonx.cifar import Cifar
from gluonx.loader import Transforms
# ---------------------------------------------
import time
import numpy as np
import mxnet as mx

from mxnet import gluon, nd, init
from mxnet import autograd as ag
from mxnet.gluon import nn

from gluoncv.model_zoo import get_model
from gluoncv.utils import makedirs, TrainingHistory

# Mini-batch size handed to the Cifar data helper.
batch_size = 8
# Cifar comes from the local gluonx package (its parent directory is put on
# sys.path above). Presumably it exposes `trainset` / `testset` batch
# iterables, since train_op is called with those later — confirm against
# gluonx.cifar.
cifar = Cifar(batch_size)

  from ._conv import register_converters as _register_converters


In [None]:
class SemiModel(nn.HybridBlock):
    """Backbone feature extractor followed by a 10-unit dense head.

    Parameters
    ----------
    features : Block
        Backbone whose output feeds the dense head. It is stored as the
        ``features`` child so that callers can replace it later.
    **kwargs
        Passed through unchanged to ``nn.HybridBlock``.
    """

    def __init__(self, features, **kwargs):
        super().__init__(**kwargs)
        # Registration order (features, then output) is kept deliberately so
        # the collected parameter order is unchanged.
        self.features = features
        self.output = nn.Dense(10)  # 10 output units (presumably CIFAR-10 classes)

    def hybrid_forward(self, F, inputs):
        """Run ``inputs`` through the backbone, then the dense head."""
        return self.output(self.features(inputs))


class TrainX(Transforms):
    """Fine-tune ``net`` on top of a pretrained ``cifar_resnet20_v1`` backbone.

    ``net`` must expose ``features`` and ``output`` child blocks (for example
    ``SemiModel``). On construction, ``net.features`` is replaced by the
    pretrained GluonCV ``cifar_resnet20_v1`` features. Only ``net.output`` is
    freshly initialized. The whole network is then moved to ``ctx`` and
    hybridized.

    The augmentations ``self.train_aug`` / ``self.test_aug`` are expected to
    be provided by the ``Transforms`` base class (not visible here).

    Parameters
    ----------
    ctx : mx.Context
        Device that parameters and batches are placed on.
    net : gluon.Block
        Model to train; it is mutated in place.
    *args, **kwargs
        Forwarded to ``Transforms.__init__``.
    """

    def __init__(self, ctx, net, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): routes attribute storage through the instance itself,
        # which presumably means Transforms is a dict subclass (attribute-dict
        # pattern) — confirm against gluonx.loader.
        self.__dict__ = self
        self.ctx = ctx
        # Network setup: swap in the pretrained backbone and initialize only
        # the new classification head.
        pretrained = get_model('cifar_resnet20_v1', pretrained=True)
        net.features = pretrained.features
        net.output.initialize(init.Xavier(magnitude=0.24))
        net.collect_params().reset_ctx(self.ctx)
        net.hybridize()
        self.net = net
        self._opt(self.net)

    def get_data(self, batch, aug):
        """Augment one batch and move it to ``self.ctx``.

        ``batch`` is indexed as ``(samples, labels)``. ``aug`` is applied to
        each element of ``batch[0]``, and the results are stacked along a new
        leading axis.

        Returns
        -------
        tuple of NDArray
            ``(Xs, ys)``: augmented inputs and labels, both on ``self.ctx``.
        """
        Xs = nd.stack(*[aug(x) for x in batch[0]]).as_in_context(self.ctx)
        ys = batch[1].as_in_context(self.ctx)
        return Xs, ys

    def test_metric(self, net, val_data):
        """Return ``(name, accuracy)`` of ``net`` over all batches of ``val_data``.

        Batches are preprocessed with ``self.test_aug``.
        """
        metric = mx.metric.Accuracy()
        for batch in val_data:
            Xs, ys = self.get_data(batch, self.test_aug)
            metric.update(ys, net(Xs))
        return metric.get()

    def _opt(self, net):
        """Create ``self.trainer`` (NAG optimizer) and ``self.loss_fn`` for ``net``."""
        # Nesterov accelerated gradient descent
        optimizer = 'nag'
        optimizer_params = {'learning_rate': 0.1, 'wd': 0.0001, 'momentum': 0.9}
        self.trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        self.loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    def train_op(self, epochs, trainset, valset, lr_decay=0.1,
                 lr_decay_epoch=(80, 160, np.inf)):
        """Train ``self.net``, print per-epoch metrics and plot the error history.

        Parameters
        ----------
        epochs : int
            Number of passes over ``trainset``.
        trainset, valset : iterable
            Iterables of batches, each consumed by ``get_data``.
        lr_decay : float, optional
            Factor applied to the learning rate at each scheduled epoch.
        lr_decay_epoch : sequence of numbers, optional
            Ascending epochs at which the learning rate is decayed.
        """
        lr_decay_count = 0
        train_metric = mx.metric.Accuracy()
        train_history = TrainingHistory(['training-error', 'validation-error'])
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0.0
            num_samples = 0
            # Step decay. The bounds check keeps schedules without a trailing
            # `inf` sentinel from raising IndexError.
            if (lr_decay_count < len(lr_decay_epoch)
                    and epoch == lr_decay_epoch[lr_decay_count]):
                self.trainer.set_learning_rate(self.trainer.learning_rate * lr_decay)
                lr_decay_count += 1
            for batch in trainset:
                Xs, ys = self.get_data(batch, self.train_aug)
                with ag.record():
                    yhats = self.net(Xs)
                    loss = self.loss_fn(yhats, ys)
                loss.backward()
                # Normalize by the actual number of samples in this batch. The
                # final batch may be smaller than the nominal batch size, and
                # the loader object need not expose a `batch_size` attribute.
                batch_len = Xs.shape[0]
                self.trainer.step(batch_len)
                train_loss += loss.sum().asscalar()
                num_samples += batch_len
                train_metric.update(ys, yhats)
            _, acc = train_metric.get()
            _, val_acc = self.test_metric(self.net, valset)
            train_history.update([1 - acc, 1 - val_acc])
            # Mean per-sample loss, so the value does not scale with dataset size.
            mean_loss = train_loss / max(num_samples, 1)
            print('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                  (epoch, acc, val_acc, mean_loss, time.time() - tic))
        train_history.plot()

In [None]:
%pylab inline
#pretrain_net = get_model('resnet50_v2')
pretrain_net = get_model('cifar_resnet20_v1')
net = SemiModel(pretrain_net.features)
ctx = mx.gpu(0)
T = TrainX(ctx, net)

In [None]:
# Train for 100 epochs on the training split, validating on the test split.
T.train_op(epochs=100, trainset=cifar.trainset, valset=cifar.testset)