<a href="https://colab.research.google.com/github/thijsgelton/7-MxNet-Mnist/blob/main/mnist_mxnet_cgcv_group7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install requests mxnet-cu100

Collecting mxnet-cu100
  Downloading https://files.pythonhosted.org/packages/85/09/a13d45136ce70589cceee4081f485f8f47fc5eb716d07981d4c2547763df/mxnet_cu100-1.8.0.post0-py2.py3-none-manylinux2014_x86_64.whl (352.6MB)
     |████████████████████████████████| 352.6MB 51kB/s 
Collecting graphviz<0.9.0,>=0.8.1
  Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl
Installing collected packages: graphviz, mxnet-cu100
  Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-cu100-1.8.0.post0


In [None]:
import mxnet as mx

In [None]:
# Load MNIST through MXNet's test utilities. The result is dict-like; this notebook
# reads its 'train_data', 'train_label', 'test_data' and 'test_label' entries.
# NOTE(review): presumably images come back as (N, 1, 28, 28) arrays scaled to [0, 1] — confirm.
mnist = mx.test_utils.get_mnist()

In [None]:
# Mini-batch iterators over the MNIST arrays. Only the training iterator shuffles;
# the test split doubles as the validation set.
batch_size = 100
train_data = mx.io.NDArrayIter(
    data=mnist['train_data'],
    label=mnist['train_label'],
    batch_size=batch_size,
    shuffle=True,
)
val_data = mx.io.NDArrayIter(
    data=mnist['test_data'],
    label=mnist['test_label'],
    batch_size=batch_size,
)

In [None]:
from __future__ import print_function
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import mxnet.ndarray as F

In [None]:
class ConvNet(nn.Block):
  """LeNet-style CNN for digit classification.

  Two (conv -> tanh -> max-pool) stages are followed by a tanh hidden dense
  layer and a linear output layer. The network returns raw class scores
  (logits). The softmax is applied by ``gluon.loss.SoftmaxCrossEntropyLoss``
  during training, and the argmax of the logits is the predicted class.

  Args:
    num_classes: number of output units / classes (10 for MNIST).
    **kwargs: forwarded unchanged to ``nn.Block``.
  """

  def __init__(self, num_classes=10, **kwargs):
    super(ConvNet, self).__init__(**kwargs)
    self.conv1 = nn.Conv2D(20, kernel_size=(5, 5))
    # Max pooling has no learnable parameters, so one layer is reused after both convs.
    self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2))
    self.conv2 = nn.Conv2D(50, kernel_size=(5, 5))
    self.fc1 = nn.Dense(500)
    self.fc2 = nn.Dense(num_classes)

  def forward(self, x):
    """Map a batch of images to unnormalised class scores.

    Args:
      x: input batch for Conv2D; presumably NCHW (batch, 1, 28, 28) MNIST
        images — TODO confirm against the data loader.

    Returns:
      NDArray of shape (batch, num_classes) holding raw logits.
    """
    x = self.pool1(F.tanh(self.conv1(x)))
    x = self.pool1(F.tanh(self.conv2(x)))

    # Flatten all but the batch axis (0 keeps that dimension as-is).
    x = x.reshape((0, -1))
    x = F.tanh(self.fc1(x))
    # No activation on the output layer: SoftmaxCrossEntropyLoss expects
    # unbounded logits. Squashing them into [-1, 1] with tanh caps the softmax
    # confidence (max p ~= 0.45 over 10 classes) and keeps the loss from ever
    # approaching zero.
    return self.fc2(x)

In [None]:
# Build the network. Its parameters are not usable until conv_net.initialize(...) is called.
conv_net = ConvNet()

In [None]:
# Single-element device list: GPU 0 when any GPU is visible, otherwise the CPU.
ctx = [mx.gpu()] if mx.test_utils.list_gpus() else [mx.cpu()]

In [None]:
# Smoke test: allocate a small NDArray directly on the chosen device.
# NOTE: despite the name, this lands on the CPU when no GPU was detected.
gpu_array = mx.nd.ones((2, 3), ctx=ctx[0])

In [None]:
# Display which device the smoke-test array was placed on.
gpu_array.context

gpu(0)

In [None]:
# Seed MXNet's RNG so the Xavier weight initialisation is repeatable on Restart & Run All.
# NOTE(review): the iterator shuffle may draw from a different RNG (e.g. NumPy's) — seed it separately if needed.
SEED = 42
mx.random.seed(SEED)

# Initialise on every device in `ctx`, not just ctx[0], so the parameters exist
# wherever split_and_load sends data in the training and evaluation loops.
conv_net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)
# Plain SGD with L2 weight decay over all network parameters.
trainer = gluon.Trainer(conv_net.collect_params(), 'sgd', {'learning_rate': 0.03, 'wd': 0.001})

In [None]:
# Number of full passes over the training set (the training loop iterates range(epoch)).
epoch = 20

In [None]:
# Train for `epoch` passes, printing the running training accuracy after each one.
metric = mx.metric.Accuracy()
criterion = gluon.loss.SoftmaxCrossEntropyLoss()

for i in range(epoch):
  train_data.reset()  # rewind the iterator for a fresh pass
  for batch in train_data:
    # Split the batch across the devices in `ctx` (one shard when there is one device).
    x_shards = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx)
    y_shards = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx)

    shard_preds = []
    with ag.record():
      for shard_x, shard_y in zip(x_shards, y_shards):
        shard_pred = conv_net(shard_x)
        shard_loss = criterion(shard_pred, shard_y)

        shard_loss.backward()  # compute gradients for this shard
        shard_preds.append(shard_pred)
    metric.update(y_shards, shard_preds)
    # Pass the full batch size so the update is normalised per sample.
    trainer.step(batch.data[0].shape[0])

  name, acc = metric.get()
  metric.reset()
  print(f"Training accuracy at epoch {i}: {name}={acc:.2f}")


Training accuracy at epoch 0: accuracy=0.85
Training accuracy at epoch 1: accuracy=0.94
Training accuracy at epoch 2: accuracy=0.95
Training accuracy at epoch 3: accuracy=0.96
Training accuracy at epoch 4: accuracy=0.97
Training accuracy at epoch 5: accuracy=0.97
Training accuracy at epoch 6: accuracy=0.97
Training accuracy at epoch 7: accuracy=0.98
Training accuracy at epoch 8: accuracy=0.98
Training accuracy at epoch 9: accuracy=0.98
Training accuracy at epoch 10: accuracy=0.98
Training accuracy at epoch 11: accuracy=0.98
Training accuracy at epoch 12: accuracy=0.98
Training accuracy at epoch 13: accuracy=0.98
Training accuracy at epoch 14: accuracy=0.98
Training accuracy at epoch 15: accuracy=0.98
Training accuracy at epoch 16: accuracy=0.98
Training accuracy at epoch 17: accuracy=0.98
Training accuracy at epoch 18: accuracy=0.98
Training accuracy at epoch 19: accuracy=0.99


In [None]:
# Evaluate on the held-out split (forward passes only, no autograd recording)
# and require better than 98% accuracy.
metric = mx.metric.Accuracy()
val_data.reset()
for batch in val_data:
  x_shards = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx)
  y_shards = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx)
  shard_preds = [conv_net(shard_x) for shard_x in x_shards]
  metric.update(y_shards, shard_preds)

val_name, val_acc = metric.get()
print('validation acc: %s=%f' % (val_name, val_acc))
assert val_acc > 0.98

validation acc: accuracy=0.985300
