In [1]:
from tqdm import tqdm
import torch
import torch.optim
import torchnet as tnt
from torchvision.datasets.mnist import MNIST
from torchnet.engine import Engine
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.init import kaiming_normal

In [2]:
def get_iterator(mode, batch_size=128, num_workers=4):
    """Build a batched MNIST loader suitable for torchnet's Engine.

    Args:
        mode: True selects the training split (and enables shuffling),
            False selects the test split. Passed straight to
            ``MNIST(train=...)``.
        batch_size: samples per batch (default 128, the original value).
        num_workers: loader worker processes (default 4, the original value).

    Returns:
        The object produced by ``TensorDataset.parallel``; its batches are
        indexed as ``[0]`` = images and ``[1]`` = labels by ``h`` in ``main``.
    """
    # Downloads into the current directory on first use.
    ds = MNIST(root='./', download=True, train=mode)
    if hasattr(ds, 'targets'):
        # Newer torchvision keeps the selected split in .data / .targets;
        # the train_*/test_* names are deprecated aliases that warn.
        data, labels = ds.data, ds.targets
    else:
        # Older torchvision only exposes split-specific attribute names.
        data = getattr(ds, 'train_data' if mode else 'test_data')
        labels = getattr(ds, 'train_labels' if mode else 'test_labels')
    tds = tnt.dataset.TensorDataset([data, labels])
    return tds.parallel(batch_size=batch_size, num_workers=num_workers,
                        shuffle=mode)


def conv_init(ni, no, k):
    """Return a Kaiming-normal initialised conv weight of shape (no, ni, k, k).

    Args:
        ni: number of input channels.
        no: number of output channels.
        k: side length of the square kernel.
    """
    # In-place kaiming_normal_ on a fresh torch.empty replaces the deprecated
    # out-of-place nn.init.kaiming_normal (which warns on every call); the
    # default mode='fan_in' / nonlinearity='leaky_relu' (a=0) are unchanged.
    return torch.nn.init.kaiming_normal_(torch.empty(no, ni, k, k))


def linear_init(ni, no):
    """Return a Kaiming-normal initialised linear weight of shape (no, ni).

    Args:
        ni: number of input features.
        no: number of output features.
    """
    # Same deprecation fix as conv_init: use the in-place kaiming_normal_
    # on an explicitly allocated tensor; defaults (fan_in, gain sqrt(2))
    # are unchanged.
    return torch.nn.init.kaiming_normal_(torch.empty(no, ni))


def f(params, inputs, mode):
    """Functional forward pass: two strided convs followed by two linears.

    Args:
        params: dict of weight/bias tensors keyed 'conv0.*', 'conv1.*',
            'linear2.*' and 'linear3.*'.
        inputs: batch whose rows each reshape to a single 1x28x28 image.
        mode: not used by this network; accepted so callers can pass the
            train/test flag uniformly.

    Returns:
        Un-normalised scores, one row per input; the width equals the number
        of rows in ``params['linear3.weight']``.
    """
    batch = inputs.size(0)
    x = inputs.view(batch, 1, 28, 28)
    # Both convolutions share the same stride-2 + ReLU pattern.
    for layer in ('conv0', 'conv1'):
        x = F.relu(F.conv2d(x, params[layer + '.weight'],
                            params[layer + '.bias'], stride=2))
    x = x.view(batch, -1)
    x = F.relu(F.linear(x, params['linear2.weight'], params['linear2.bias']))
    return F.linear(x, params['linear3.weight'], params['linear3.bias'])

In [3]:
def main():
    """Train the functional MNIST CNN for 10 epochs with torchnet's Engine.

    Parameters live in a plain dict of leaf tensors that ``f`` consumes and
    SGD (lr=0.01, momentum=0.9, weight_decay=5e-4) updates. After every
    training epoch the train loss/accuracy is printed, the meters are reset,
    the test split is evaluated via ``engine.test`` and its loss/accuracy is
    printed.
    """
    params = {
        'conv0.weight': conv_init(1, 50, 5), 'conv0.bias': torch.zeros(50),
        'conv1.weight': conv_init(50, 50, 5), 'conv1.bias': torch.zeros(50),
        # 800 = 50 channels x 4 x 4: a 28x28 input shrinks to 12x12 and then
        # 4x4 after the two 5x5 stride-2 convolutions in f.
        'linear2.weight': linear_init(800, 512), 'linear2.bias': torch.zeros(512),
        'linear3.weight': linear_init(512, 10), 'linear3.bias': torch.zeros(10),
    }
    params = {k: Variable(v, requires_grad=True) for k, v in params.items()}

    optimizer = torch.optim.SGD(
        list(params.values()), lr=0.01, momentum=0.9, weight_decay=0.0005)

    engine = Engine()
    meter_loss = tnt.meter.AverageValueMeter()
    classerr = tnt.meter.ClassErrorMeter(accuracy=True)

    def to_float(t):
        """Return the Python float held by a one-element loss tensor.

        PyTorch >= 0.4 returns a 0-dim loss, for which ``.data[0]`` raises
        IndexError; ``.item()`` is the supported accessor there. The
        ``.data[0]`` path is kept only for old releases without ``.item()``.
        """
        return t.item() if hasattr(t, 'item') else t.data[0]

    def h(sample):
        """Engine network closure: return (loss, scores) for one batch."""
        # Pixels are divided by 255 (presumably uint8 MNIST bytes -> [0, 1]).
        inputs = Variable(sample[0].float() / 255.0)
        targets = Variable(torch.LongTensor(sample[1]))
        # sample[2] is the train flag appended by on_sample.
        o = f(params, inputs, sample[2])
        return F.cross_entropy(o, targets), o

    def reset_meters():
        classerr.reset()
        meter_loss.reset()

    def on_sample(state):
        # Attach the train/test flag so h can forward it to f.
        state['sample'].append(state['train'])

    def on_forward(state):
        classerr.add(state['output'].data,
                     torch.LongTensor(state['sample'][1]))
        meter_loss.add(to_float(state['loss']))

    def on_start_epoch(state):
        reset_meters()
        # torchnet's Engine keeps one state dict for the whole run, so wrapping
        # state['iterator'] directly would stack another tqdm layer every epoch.
        # Remember the original loader once and wrap that fresh each epoch.
        raw_iterator = state.setdefault('raw_iterator', state['iterator'])
        state['iterator'] = tqdm(raw_iterator)

    def on_end_epoch(state):
        print('Training loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0]))
        # do validation at the end of each epoch
        reset_meters()
        engine.test(h, get_iterator(False))
        print('Testing loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0]))

    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.train(h, get_iterator(True), maxepoch=10, optimizer=optimizer)

In [4]:
if __name__ == "__main__":
    # A Jupyter kernel runs cells with __name__ == "__main__", so executing
    # this cell starts training immediately.
    main()

100%|██████████| 469/469 [00:34<00:00, 13.73it/s]

Training loss: 0.2502, accuracy: 92.43%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0971, accuracy: 96.96%


100%|██████████| 469/469 [00:36<00:00, 12.97it/s]

Training loss: 0.0873, accuracy: 97.39%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0668, accuracy: 97.85%


100%|██████████| 469/469 [00:36<00:00, 12.85it/s]

Training loss: 0.0619, accuracy: 98.14%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0658, accuracy: 97.93%


100%|██████████| 469/469 [00:37<00:00, 12.67it/s]

Training loss: 0.0483, accuracy: 98.53%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0490, accuracy: 98.49%


100%|██████████| 469/469 [00:36<00:00, 12.95it/s]

Training loss: 0.0397, accuracy: 98.74%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0447, accuracy: 98.55%


100%|██████████| 469/469 [00:37<00:00, 12.64it/s]

Training loss: 0.0325, accuracy: 99.03%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0495, accuracy: 98.26%


100%|██████████| 469/469 [00:35<00:00, 13.27it/s]

Training loss: 0.0289, accuracy: 99.13%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0443, accuracy: 98.59%


100%|██████████| 469/469 [00:35<00:00, 13.29it/s]

Training loss: 0.0237, accuracy: 99.34%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0433, accuracy: 98.56%


100%|██████████| 469/469 [00:35<00:00, 13.24it/s]

Training loss: 0.0210, accuracy: 99.39%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0425, accuracy: 98.61%


100%|██████████| 469/469 [00:37<00:00, 12.67it/s]

Training loss: 0.0192, accuracy: 99.46%





Testing loss: 0.0402, accuracy: 98.60%
