In [1]:
import math
from time import time

import mxnet as mx
import numpy as np
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.data import Sampler, RandomSampler
from mxnet.gluon.data.vision import datasets, transforms

In [2]:
class AdvSequentialSampler(Sampler):
    """Samples elements from ``[start_idx, start_idx + length)`` sequentially.

    Yields a contiguous window of indices that may begin at an offset. In this
    notebook it selects the tail of the MNIST training set as a validation split.

    Parameters
    ----------
    length : int
        Number of indices to yield.
    start_idx : int, default 0
        First index yielded. With the default of 0 the sampler covers
        ``[0, length)``.
    """
    def __init__(self, length, start_idx=0):
        self._length = length
        self._start_idx = start_idx

    def __iter__(self):
        return iter(range(self._start_idx, self._start_idx + self._length))

    def __len__(self):
        # Number of indices yielded, independent of the offset.
        return self._length

In [3]:
# Seed the RNGs so Restart & Run All reproduces the shuffle order and the
# weight initialization.
# NOTE(review): gluon's RandomSampler is expected to shuffle with NumPy's
# global RNG and parameter initializers to draw from MXNet's RNG; verify for
# the installed MXNet version. GPU kernels (e.g. cuDNN) may still be
# nondeterministic.
random_seed = 42
np.random.seed(random_seed)
mx.random.seed(random_seed)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31)])

org_train_data = datasets.MNIST(train=True)
valid_rate = 0.1
org_length = len(org_train_data)
batch_size = 128
# Indices [0, end_idx) are used for training (shuffled each epoch);
# the remaining [end_idx, org_length) tail is held out for validation.
end_idx = int(org_length * (1 - valid_rate))
train_dataloader = gluon.data.DataLoader(
    org_train_data.transform_first(transform), sampler=RandomSampler(end_idx),
    batch_size=batch_size, num_workers=4)
valid_dataloader = gluon.data.DataLoader(
    org_train_data.transform_first(transform), sampler=AdvSequentialSampler(org_length - end_idx, end_idx),
    batch_size=batch_size, num_workers=4)
test_dataloader = gluon.data.DataLoader(
    datasets.MNIST(train=False).transform_first(transform),
    batch_size=batch_size, num_workers=4)

In [4]:
# Use a GPU when one is visible, otherwise fall back to the CPU.
ctx = mx.cpu() if not mx.context.num_gpus() else mx.gpu()

# LeNet-5-like CNN: two conv + max-pool stages, then three dense layers
# ending in 10 outputs (one per digit class).
model = nn.HybridSequential()
with model.name_scope():
    # Layers are built inside name_scope so their parameters carry this
    # model's prefix; construction order matches the stacking order.
    lenet_layers = [
        nn.Conv2D(6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=(2, 2)),
        nn.Conv2D(16, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=(2, 2)),
        nn.Flatten(),
        nn.Dense(120, activation='relu'),
        nn.Dense(84, activation='relu'),
        nn.Dense(10),
    ]
    model.add(*lenet_layers)
model.hybridize()
model.initialize(ctx=ctx)

In [5]:
# Optimization setup: plain SGD starting at lr=0.1 with a factor schedule.
max_epoch = 10
# One optimizer update per mini-batch; assuming the DataLoader's default
# last_batch='keep', the train loader yields ceil(end_idx / batch_size)
# batches per epoch.
step = max_epoch * math.ceil(end_idx / batch_size)
# NOTE(review): `step` equals the total number of updates in the whole run,
# and FactorScheduler decays only once the update count exceeds `step`, so the
# 0.1 factor appears never to be applied. Confirm whether a per-epoch or
# mid-training decay was intended.
lr_sch = mx.lr_scheduler.FactorScheduler(step=step, factor=0.1)
sgd_optimizer = mx.optimizer.SGD(learning_rate=0.1, lr_scheduler=lr_sch)
trainer = gluon.Trainer(params=model.collect_params(), optimizer=sgd_optimizer)

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

In [6]:
def validate(model, dataloader, metric, mode='Validation', device=None):
    """Run ``model`` over every batch of ``dataloader`` and report the metric.

    Parameters
    ----------
    model : callable
        Network mapping a batch of samples to predictions (e.g. a gluon Block).
    dataloader : iterable of (samples, labels)
        Batches whose items provide ``as_in_context``.
    metric : metric object
        Reset before the first batch and updated with every batch;
        ``metric.get()[1]`` is the reported value.
    mode : str, default 'Validation'
        Label used in the printed line, e.g. ``'Test'``.
    device : context, optional
        Device each batch is moved to. When omitted, falls back to the
        notebook-level ``ctx`` so existing calls behave exactly as before.

    Returns
    -------
    The metric value after all batches (``metric.get()[1]``).
    """
    # Explicit argument wins; otherwise use the device chosen in the model cell.
    target = device if device is not None else ctx
    metric.reset()
    for sample_batch, labels in dataloader:
        sample_batch = sample_batch.as_in_context(target)
        labels = labels.as_in_context(target)
        # Forward pass only: no autograd.record(), so no graph is recorded.
        preds = model(sample_batch)
        metric.update(labels=labels, preds=preds)

    acc = metric.get()[1]
    print('\t\t{} accuracy {:2.5f}'.format(mode, acc))
    return acc

In [7]:
# Train for `max_epoch` epochs, exporting the network whenever validation
# accuracy beats the best value seen so far.
best_acc = 0
best_epoch = 0
base_name = 'model'
# NOTE(review): `ckpt_file_path` is not used by any visible cell; export()
# below derives its own file names from `base_name`.
ckpt_file_path = '{}.ckpt'.format(base_name)

for epoch in range(max_epoch):
    epoch_start = time()
    total_loss = .0
    for batch_data, batch_labels in train_dataloader:
        batch_data = batch_data.as_in_context(ctx)
        batch_labels = batch_labels.as_in_context(ctx)
        with mx.autograd.record():
            batch_loss = loss_fn(model(batch_data), batch_labels)
        batch_loss.backward()
        # Normalize by the actual batch size; the final batch may be smaller.
        trainer.step(len(batch_data))
        total_loss += batch_loss.sum().asscalar()

    # Average per-sample loss over the training split of `end_idx` samples.
    avg_loss = total_loss / end_idx
    elapsed = time() - epoch_start
    print('[Epoch {}]\tAvg. loss: {:2.4f}\tTime: {:2.2f}'.format(epoch + 1, avg_loss, elapsed))

    acc = validate(model, valid_dataloader, metric)
    if acc > best_acc:
        best_acc, best_epoch = acc, epoch
        # Produces the `<base_name>-symbol.json` / `<base_name>-<epoch:04d>.params`
        # pair that the reload cell reads back.
        model.export(base_name, epoch=epoch)

[Epoch 1]	Avg. loss: 0.7597	Time: 11.51
		Validation accuracy 0.94767
[Epoch 2]	Avg. loss: 0.1035	Time: 11.44
		Validation accuracy 0.97900
[Epoch 3]	Avg. loss: 0.0697	Time: 11.39
		Validation accuracy 0.98267
[Epoch 4]	Avg. loss: 0.0527	Time: 11.94
		Validation accuracy 0.98200
[Epoch 5]	Avg. loss: 0.0429	Time: 12.14
		Validation accuracy 0.98667
[Epoch 6]	Avg. loss: 0.0352	Time: 13.42
		Validation accuracy 0.98617
[Epoch 7]	Avg. loss: 0.0310	Time: 12.81
		Validation accuracy 0.98750
[Epoch 8]	Avg. loss: 0.0258	Time: 13.76
		Validation accuracy 0.98717
[Epoch 9]	Avg. loss: 0.0223	Time: 14.96
		Validation accuracy 0.98433
[Epoch 10]	Avg. loss: 0.0199	Time: 14.13
		Validation accuracy 0.98767


In [8]:
# Test accuracy of the weights left after the final epoch (not necessarily the
# best-validation checkpoint, which is reloaded separately).
validate(model, test_dataloader, metric, mode='Test')

		Test accuracy 0.98950


0.9895

In [9]:
# Rebuild the best-validation network from the exported graph + weights and
# evaluate it on the test set.
print('Loading with json and params files')
symbol_file = base_name + '-symbol.json'
params_file = '{}-{:04d}.params'.format(base_name, best_epoch)
# 'data' is the input name the exported graph is loaded with.
model = gluon.nn.SymbolBlock.imports(symbol_file, ['data'], params_file, ctx=ctx)
validate(model, test_dataloader, metric, 'Test')

Loading with json and params files
		Test accuracy 0.98950


0.9895