In [1]:
#!/usr/bin/env python3
# Copyright (c) 2020 Graphcore Ltd. All rights reserved.
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision
import poptorch
import torch.optim as optim

# The following is a workaround for pytorch issue #1938
from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)


def get_mnist_data():
    training_data = torch.utils.data.DataLoader(
                    torchvision.datasets.MNIST('~/.torch/datasets', train=True, download=True,
                                               transform=torchvision.transforms.Compose([
                                                torchvision.transforms.ToTensor(),
                                                torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),
                    batch_size=8 * 50, shuffle=True, drop_last=True)

    validation_data = torch.utils.data.DataLoader(
                      torchvision.datasets.MNIST('~/.torch/datasets', train=False, download=True,
                                                 transform=torchvision.transforms.Compose([
                                                    torchvision.transforms.ToTensor(),
                                                    torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),
                      batch_size=80, shuffle=True, drop_last=True)
    return training_data, validation_data


class Block(nn.Module):
    def __init__(self, in_channels, num_filters, kernel_size, pool_size):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(in_channels,
                              num_filters,
                              kernel_size=kernel_size)
        self.pool = nn.MaxPool2d(kernel_size=pool_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        x = self.relu(x)
        return x


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.layer1 = Block(1, 32, 3, 2)
        self.layer2 = Block(32, 64, 3, 2)
        self.layer3 = nn.Linear(1600, 128)
        self.layer3_act = nn.ReLU()
        self.layer3_dropout = torch.nn.Dropout(0.5)
        self.layer4 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        # Flatten layer
        x = x.view(-1, 1600)
        x = self.layer3_act(self.layer3(x))
        x = self.layer4(self.layer3_dropout(x))
        x = self.softmax(x)
        return x


class TrainingModelWithLoss(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, args, loss_inputs=None):
        output = self.model(args)
        if loss_inputs is None:
            return output
        else:
            loss = self.loss(output, loss_inputs)
            return output, loss


def accuracy(predictions, labels):
    _, ind = torch.max(predictions, 1)
    # provide labels only for samples, where prediction is available (during the training, not every samples prediction is returned for efficiency reasons)
    labels = labels[-predictions.size()[0]:]
    accuracy = torch.sum(torch.eq(ind, labels)).item() / labels.size()[0] * 100.0
    return accuracy


def train(training_model, training_data):
    nr_batches = len(training_data)
    for epoch in range(1, 10+1):
        print("Epoch {0}/{1}".format(epoch, 10))
        bar = tqdm(training_data, total=nr_batches)
        for data, labels in bar:
            preds, losses = training_model(data, labels)
            with torch.no_grad():
                mean_loss = torch.mean(losses).item()
                acc = accuracy(preds, labels)
            bar.set_description("Loss:{:0.4f} | Accuracy:{:0.2f}%".format(mean_loss, acc))


def test(inference_model, test_data):
    nr_batches = len(test_data)
    sum_acc = 0.0
    with torch.no_grad():
        for data, labels in tqdm(test_data, total=nr_batches):
            output = inference_model(data)
            sum_acc += accuracy(output, labels)
    print("Accuracy on test set: {:0.2f}%".format(sum_acc / len(test_data)))

In [3]:
# parser = argparse.ArgumentParser(description='MNIST training in PopTorch')
# parser.add_argument('--batch-size', type=int, default=8, help='batch size for training (default: 8)')
# parser.add_argument('--batches-per-step', type=int, default=50, help='device iteration (default:50)')
# parser.add_argument('--test-batch-size', type=int, default=80, help='batch size for testing (default: 80)')
# parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train (default: 10)')
# parser.add_argument('--lr', type=float, default=0.05, help='learning rate (default: 0.05)')
# opts = parser.parse_args()

training_data, test_data = get_mnist_data()
model = Network()
model_with_loss = TrainingModelWithLoss(model)
model_opts = poptorch.Options().deviceIterations(50)


# run training, on IPU
model_with_loss.train()  # Switch the model to training mode
# Models are initialised in training mode by default, so the line above will
# have no effect. Its purpose is to show how the mode can be set explicitly.
training_model = poptorch.trainingModel(model_with_loss, model_opts, optimizer=optim.SGD(model.parameters(), lr=0.05))
train(training_model, training_data)

# Update the weights in model by copying from the training IPU. This updates (model.parameters())
training_model.copyWeightsToHost()

# Check validation loss on IPU once trained. Because PopTorch will be compiled on first call the
# weights in model.parameters() will be copied implicitly. Subsequent calls will need to call
# inference_model.copyWeightsToDevice()
model.eval()  # Switch the model to inference mode
inference_model = poptorch.inferenceModel(model)
test(inference_model, test_data)

  0%|          | 0/150 [00:00<?, ?it/s]

Epoch 1/10



Graph compilation:   0%|          | 0/100 [00:00<?][A
Graph compilation:   3%|▎         | 3/100 [00:00<00:03][A
Graph compilation:   6%|▌         | 6/100 [00:07<01:13][A
Graph compilation:   7%|▋         | 7/100 [00:09<01:54][A
Graph compilation:  20%|██        | 20/100 [00:10<01:10][A
Graph compilation:  21%|██        | 21/100 [00:10<00:59][A
Graph compilation:  22%|██▏       | 22/100 [00:14<02:11][A
Graph compilation:  23%|██▎       | 23/100 [00:15<01:53][A
Graph compilation:  24%|██▍       | 24/100 [00:16<01:28][A
Graph compilation:  25%|██▌       | 25/100 [00:16<01:03][A
Graph compilation:  26%|██▌       | 26/100 [00:16<00:50][A
Graph compilation:  27%|██▋       | 27/100 [00:16<00:38][A
Graph compilation:  28%|██▊       | 28/100 [00:16<00:29][A
Graph compilation:  29%|██▉       | 29/100 [00:17<00:31][A
Graph compilation:  30%|███       | 30/100 [00:18<00:39][A
Graph compilation:  31%|███       | 31/100 [00:18<00:30][A
Graph compilation:  33%|███▎      | 33/100 [00:

Epoch 2/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.54it/s]
Loss:1.4612 | Accuracy:100.00%:   1%|▏         | 2/150 [00:00<00:09, 15.26it/s]

Epoch 3/10


Loss:1.5864 | Accuracy:87.50%: 100%|██████████| 150/150 [00:09<00:00, 15.59it/s] 
Loss:1.4612 | Accuracy:100.00%:   1%|▏         | 2/150 [00:00<00:10, 14.48it/s]

Epoch 4/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.44it/s]
Loss:1.4612 | Accuracy:100.00%:   1%|▏         | 2/150 [00:00<00:10, 13.69it/s]

Epoch 5/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.43it/s]
Loss:1.5555 | Accuracy:87.50%:   1%|▏         | 2/150 [00:00<00:09, 15.16it/s] 

Epoch 6/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.45it/s]
Loss:1.4612 | Accuracy:100.00%:   1%|▏         | 2/150 [00:00<00:08, 16.82it/s]

Epoch 7/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.62it/s]
Loss:1.4612 | Accuracy:100.00%:   1%|▏         | 2/150 [00:00<00:09, 16.30it/s]

Epoch 8/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.32it/s]
Loss:1.4615 | Accuracy:100.00%:   1%|▏         | 2/150 [00:00<00:09, 15.07it/s]

Epoch 9/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.15it/s]
Loss:1.4612 | Accuracy:100.00%:   1%|▏         | 2/150 [00:00<00:10, 13.55it/s]

Epoch 10/10


Loss:1.4612 | Accuracy:100.00%: 100%|██████████| 150/150 [00:09<00:00, 15.53it/s]
  0%|          | 0/125 [00:00<?, ?it/s]
Graph compilation:   0%|          | 0/100 [00:00<?][A
Graph compilation:   4%|▍         | 4/100 [00:00<00:04][A
Graph compilation:   6%|▌         | 6/100 [00:05<01:23][A
Graph compilation:   7%|▋         | 7/100 [00:06<01:25][A
Graph compilation:  22%|██▏       | 22/100 [00:07<00:50][A
Graph compilation:  27%|██▋       | 27/100 [00:09<00:44][A
Graph compilation:  31%|███       | 31/100 [00:10<00:33][A
Graph compilation:  34%|███▍      | 34/100 [00:10<00:23][A
Graph compilation:  42%|████▏     | 42/100 [00:11<00:16][A
Graph compilation:  45%|████▌     | 45/100 [00:11<00:11][A
Graph compilation:  50%|█████     | 50/100 [00:11<00:08][A
Graph compilation:  52%|█████▏    | 52/100 [00:12<00:08][A
Graph compilation:  55%|█████▌    | 55/100 [00:13<00:11][A
Graph compilation:  63%|██████▎   | 63/100 [00:13<00:06][A
Graph compilation:  67%|██████▋   | 67/100 [00

Accuracy on test set: 97.91%



