In [4]:
import mxnet as mx
import os
import tarfile
import numpy as np
from multiprocessing import cpu_count
# CPU count reported by multiprocessing.cpu_count(); passed as `num_workers`
# to the DataLoaders created later in this notebook.
CPU_COUNT = cpu_count()


In [6]:
def transform(data, label):
    """Cast a sample to float32 and divide it by 255; the label is untouched.

    Args:
        data: array-like object exposing ``astype`` (the sample's pixel values).
        label: the sample's label, returned unchanged.

    Returns:
        A ``(scaled_data, label)`` tuple.
    """
    scaled = data.astype('float32') / 255
    return scaled, label

# Mini-batch size shared by the training and validation loaders.
batch_size = 32

# Fashion-MNIST splits; `transform` is handed to the dataset as its transform callback.
train_dataset = mx.gluon.data.vision.datasets.FashionMNIST(train=True, transform=transform)
valid_dataset = mx.gluon.data.vision.datasets.FashionMNIST(train=False, transform=transform)

# Only the training loader shuffles; both loaders use num_workers=CPU_COUNT.
train_data_loader = mx.gluon.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=CPU_COUNT)
valid_data_loader = mx.gluon.data.DataLoader(
    valid_dataset, batch_size=batch_size, num_workers=CPU_COUNT)

In [7]:
from mxnet import gluon, autograd, ndarray

def construct_net():
    """Build the three-layer fully connected network used in this notebook.

    Layers, in order: Dense(128, relu) -> Dense(64, relu) -> Dense(10) with no
    activation on the last layer.

    Returns:
        An uninitialized ``gluon.nn.HybridSequential`` containing the layers.
    """
    # (units, activation) for each Dense layer, in order.
    layer_specs = ((128, "relu"), (64, "relu"), (10, None))
    model = gluon.nn.HybridSequential()
    with model.name_scope():
        for units, activation in layer_specs:
            model.add(gluon.nn.Dense(units, activation=activation))
    return model

# Select the device: GPU when list_gpus() reports at least one, otherwise CPU.
if mx.test_utils.list_gpus():
    ctx = mx.gpu()
else:
    ctx = mx.cpu()

# Build the network, hybridize it, then initialize its parameters with Xavier on `ctx`.
net = construct_net()
net.hybridize()
net.initialize(init=mx.init.Xavier(), ctx=ctx)

# Softmax cross-entropy loss, optimized with SGD at learning rate 0.1.
criterion = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(
    params=net.collect_params(),
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.1},
)

In [8]:
epochs = 5
# Each sample is flattened to 28*28 = 784 values before being fed to the network.
num_features = 28 * 28

for epoch in range(epochs):
    # ---- Training pass: forward, backward, parameter update per batch ----
    # The running loss is kept as an NDArray on `ctx`; it is converted to a
    # Python scalar once per epoch (asscalar below) instead of once per batch.
    cumulative_train_loss = mx.nd.zeros(1, ctx=ctx)
    training_samples = 0
    for data, label in train_data_loader:
        data = data.as_in_context(ctx).reshape((-1, num_features))
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            loss = criterion(output, label)

        loss.backward()
        # Pass the actual batch size (the last batch may be smaller than batch_size).
        trainer.step(data.shape[0])
        cumulative_train_loss += loss.sum()
        training_samples += data.shape[0]
    train_loss = cumulative_train_loss.asscalar() / training_samples

    # ---- Validation pass: forward only, no autograd recording, no updates ----
    cumulative_valid_loss = mx.nd.zeros(1, ctx=ctx)
    valid_samples = 0
    for data, label in valid_data_loader:
        data = data.as_in_context(ctx).reshape((-1, num_features))
        label = label.as_in_context(ctx)
        output = net(data)
        loss = criterion(output, label)
        cumulative_valid_loss += loss.sum()
        valid_samples += data.shape[0]
    valid_loss = cumulative_valid_loss.asscalar() / valid_samples

    # Per-sample average losses for this epoch (epoch index is 0-based).
    print("Epoch {}, training loss: {:.2f}, validation loss: {:.2f}".format(epoch, train_loss, valid_loss))

Epoch 0, training loss: 0.55, validation loss: 0.47
Epoch 1, training loss: 0.40, validation loss: 0.37
Epoch 2, training loss: 0.36, validation loss: 0.37
Epoch 3, training loss: 0.33, validation loss: 0.35
Epoch 4, training loss: 0.32, validation loss: 0.32
