In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from utils import mnist

In [2]:
train_loader, test_loader = mnist()

In [3]:
class Net(nn.Module):
    def __init__(self, log_softmax=False):
        super(Net, self).__init__()
        innerLayerShape1 = 128
        self.fc1 = nn.Linear(28*28, innerLayerShape1)
        self.fc2 = nn.Linear(innerLayerShape1, 10)
        self.log_softmax = log_softmax
        self.optim = optim.SGD(self.parameters(), lr=0.1)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.sigmoid(self.fc1(x))
        x = self.fc2(x)
        if self.log_softmax:
            x = F.log_softmax(x, dim=1)
        else:
            x = torch.log(F.softmax(x, dim=1))
        return x
    
    def loss(self, output, target, **kwargs):
        self._loss = F.nll_loss(output, target, **kwargs)
        return self._loss

In [4]:
def train(epoch, models):
    train_loss = [0]*len(models)
    train_loss_count = [0]*len(models)
    for batch_idx, (data, target) in enumerate(train_loader):
        for model in models:
            model.optim.zero_grad()
            output = model(data)
            loss = model.loss(output, target)
            loss.backward()
            model.optim.step()
            
        if batch_idx % 200 == 0:
            line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader))
            lossesArr = []
            for i, m in enumerate(models):
                lossModel = m._loss.item()
                train_loss[i] += lossModel
                train_loss_count[i] += 1
                lossesArr.append('{}: {:.6f}'.format(i, lossModel))
            losses = ' '.join(lossesArr)
            print(line + losses)
            
    else:
        batch_idx += 1
        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader))
        lossesArr = []
        for i, m in enumerate(models):
            lossModel = m._loss.item()
            train_loss[i] += lossModel
            train_loss_count[i] += 1
            lossesArr.append('{}: {:.6f}'.format(i, lossModel))
        losses = ' '.join(lossesArr)
        print(line + losses)
    for i in range(len(models)):
        train_loss[i] /= train_loss_count[i]
        print('Loss: {:.4f}'.format(train_loss[i]))

In [5]:
models = [Net(), Net(True)]

In [6]:
avg_lambda = lambda l: 'Loss: {:.4f}'.format(l)
acc_lambda = lambda c, p: 'Accuracy: {}/{} ({:.0f}%)'.format(c, len(test_loader.dataset), p)
line = lambda i, l, c, p: '{}: '.format(i) + avg_lambda(l) + '\t' + acc_lambda(c, p)

def test(models):
    test_loss = [0]*len(models)
    correct = [0]*len(models)
    with torch.no_grad():
        for data, target in test_loader:
            output = [m(data) for m in models]
            for i, m in enumerate(models):
                test_loss[i] += m.loss(output[i], target, reduction='sum').item() # sum up batch loss
                pred = output[i].data.max(1, keepdim=True)[1] # get the index of the max log-probability
                correct[i] += pred.eq(target.data.view_as(pred)).cpu().sum()
    
    for i in range(len(models)):
        test_loss[i] /= len(test_loader.dataset)
    correct_pct = [100. * c / len(test_loader.dataset) for c in correct]
    lines = '\n'.join([line(i, test_loss[i], correct[i], correct_pct[i]) for i in range(len(models))]) + '\n'
    report = 'Test set:\n' + lines
    
    print(report)

In [None]:
for epoch in range(1, 10):
    train(epoch, models)
    test(models)

Loss: 0.6626
Loss: 0.6592
Test set:
0: Loss: 0.2606	Accuracy: 9246/10000 (92%)
1: Loss: 0.2606	Accuracy: 9246/10000 (92%)

Loss: 0.2462
Loss: 0.2490
Test set:
0: Loss: 0.1993	Accuracy: 9421/10000 (94%)
1: Loss: 0.2013	Accuracy: 9419/10000 (94%)

Loss: 0.1557
Loss: 0.1643
Test set:
0: Loss: 0.1630	Accuracy: 9534/10000 (95%)
1: Loss: 0.1650	Accuracy: 9517/10000 (95%)

Loss: 0.1600
Loss: 0.1564
Test set:
0: Loss: 0.1431	Accuracy: 9587/10000 (96%)
1: Loss: 0.1440	Accuracy: 9580/10000 (96%)

Loss: 0.1476
Loss: 0.1501
Test set:
0: Loss: 0.1264	Accuracy: 9643/10000 (96%)
1: Loss: 0.1280	Accuracy: 9625/10000 (96%)

Loss: 0.1641
Loss: 0.1616
Test set:
0: Loss: 0.1163	Accuracy: 9672/10000 (97%)
1: Loss: 0.1184	Accuracy: 9650/10000 (96%)

Loss: 0.1284
Loss: 0.1321
Test set:
0: Loss: 0.1086	Accuracy: 9679/10000 (97%)
1: Loss: 0.1114	Accuracy: 9663/10000 (97%)

Loss: 0.1166
Loss: 0.1233
Test set:
0: Loss: 0.1018	Accuracy: 9712/10000 (97%)
1: Loss: 0.1030	Accuracy: 9675/10000 (97%)

