In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import mnist

In [2]:
# Project helper from utils; presumably returns (train, test) DataLoaders over MNIST
# yielding (data, target) batches — TODO confirm against utils.mnist.
train_loader, test_loader = mnist()

In [3]:
def my_log_softmax(x, dim=1, **kwargs):
    """Numerically stable log-softmax of ``x`` along ``dim``.

    Computes ``(x - max) - log(sum(exp(x - max)))`` so that ``exp`` never
    overflows; the result matches ``F.log_softmax(x, dim=dim)``.

    Args:
        x: input tensor of any rank.
        dim: dimension to normalize over (negative indices allowed).
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # keepdim=True keeps the reduced axis as size 1, so the subtraction
    # broadcasts correctly for any ``dim`` and rank (not only dim=1 on 2-D).
    ms, _ = torch.max(x, dim, keepdim=True)
    z = x - ms
    s = torch.log(torch.sum(torch.exp(z), dim, keepdim=True))
    return z - s

In [4]:
# Sanity-check my_log_softmax against PyTorch's reference on random logits.
# A local, seeded generator keeps the check reproducible without consuming
# the global RNG stream. The tensor is deliberately not called `test`, a name
# that is later rebound to the evaluation function.
gen = torch.Generator().manual_seed(0)
sample_logits = torch.randn(3, 5, generator=gen)
log_softmax_diff = F.log_softmax(sample_logits, dim=1) - my_log_softmax(sample_logits, dim=1)
assert torch.allclose(log_softmax_diff, torch.zeros_like(log_softmax_diff), atol=1e-6)
log_softmax_diff

tensor(1.00000e-07 *
       [[ 2.3842,  0.0000,  0.0000,  1.1921,  0.0000],
        [ 0.0000,  0.0000,  1.1921,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])

In [5]:
def my_NLLLoss(y_hat, y, **kwargs):
    """Mean negative log-likelihood of integer targets given raw scores.

    Despite the name, ``y_hat`` is un-normalized scores: ``my_log_softmax``
    is applied here, so the result equals
    ``F.nll_loss(F.log_softmax(y_hat, dim=1), y)`` (i.e. cross-entropy).

    Args:
        y_hat: (N, C) scores.
        y: (N,) int64 class indices in ``[0, C)``.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        Scalar tensor: loss averaged over the N samples.
    """
    N = y_hat.shape[0]
    log_probs = my_log_softmax(y_hat)
    # Select each row's target log-probability directly. The class count now
    # comes from y_hat rather than max(y) + 1, so a batch that happens not to
    # contain the highest class no longer causes a shape mismatch, and there
    # is no 0 * -inf = nan from a one-hot multiply.
    picked = log_probs.gather(1, y[:, None])
    return -torch.sum(picked) / N

In [6]:
# Check my_NLLLoss (which applies log-softmax itself) against the reference
# composition F.nll_loss(F.log_softmax(...)); seeded locally for reproducibility.
gen = torch.Generator().manual_seed(1)
x = torch.randn(3, 5, generator=gen)
y = torch.tensor([1, 2, 4])
nll_diff = F.nll_loss(F.log_softmax(x, dim=1), y) - my_NLLLoss(x, y)
assert abs(nll_diff.item()) < 1e-6
nll_diff

tensor(0.)

In [7]:
class Net(nn.Module):
    """Three-layer MLP classifier: 28*28 inputs -> 128 -> 128 -> 10 classes.

    ``forward`` returns log-probabilities, computed either with
    ``F.log_softmax`` (``log_softmax=True``) or with the hand-written
    ``my_log_softmax`` (``log_softmax=False``), so two instances can be
    trained side by side to compare the implementations.

    Each instance owns an Adam optimizer (``self.optim``) and keeps the value
    of its most recent ``loss`` call in ``self._loss`` for logging.
    """

    def __init__(self, log_softmax=False, lr=0.01):
        """
        Args:
            log_softmax: if True use ``F.log_softmax``; otherwise use
                ``my_log_softmax``.
            lr: Adam learning rate (default 0.01, the former hard-coded value).
        """
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.log_softmax = log_softmax
        # Built after the layers so that it sees all of their parameters.
        self.optim = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Map a batch to (batch, 10) log-probabilities.

        ``x`` is reshaped with ``view(-1, 28*28)``, so e.g. (N, 1, 28, 28)
        input becomes (N, 784).
        """
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        if self.log_softmax:
            return F.log_softmax(x, dim=1)
        return my_log_softmax(x)

    def loss(self, output, target, **kwargs):
        """NLL loss of ``output`` log-probabilities against ``target`` indices.

        Extra keyword arguments (e.g. ``reduction``) are forwarded to
        ``F.nll_loss``. The returned tensor is attached to the graph so the
        caller can ``backward()`` it; the copy stored in ``self._loss`` is
        detached so the logged value does not keep a reference to the graph.
        """
        loss = F.nll_loss(output, target, **kwargs)
        self._loss = loss.detach()
        return loss

In [8]:
def _format_progress(epoch, seen, total, pct, models):
    """Build one progress line: epoch, samples seen, percent done, each model's last loss."""
    line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(epoch, seen, total, pct)
    losses = ' '.join(['{}: {:.6f}'.format(i, m._loss.item()) for i, m in enumerate(models)])
    return line + losses


def train(epoch, models, loader=None, log_interval=200):
    """Run one training epoch, stepping every model on each batch.

    All models see the same batches in the same order, each using its own
    ``model.optim`` and ``model.loss``. A progress line is printed every
    ``log_interval`` batches and once more at the end of the epoch.

    Args:
        epoch: epoch number, used only for logging.
        models: list of models exposing ``optim``, ``loss`` and ``_loss``
            (see ``Net``).
        loader: iterable of ``(data, target)`` batches with ``len()`` and a
            ``.dataset``; defaults to the module-level ``train_loader``.
        log_interval: print progress every this many batches.
    """
    if loader is None:
        loader = train_loader
    total = len(loader.dataset)
    n_batches = len(loader)
    for model in models:
        model.train()  # training-mode behavior, in case layers like dropout are added

    seen = 0  # samples processed so far; correct even if the last batch is smaller
    batch_idx = -1
    for batch_idx, (data, target) in enumerate(loader):
        for model in models:
            model.optim.zero_grad()
            output = model(data)
            loss = model.loss(output, target)
            loss.backward()
            model.optim.step()

        if batch_idx % log_interval == 0:
            # As before, report the number of samples seen *before* this batch.
            print(_format_progress(epoch, seen, total, 100. * batch_idx / n_batches, models))
        seen += len(data)

    if batch_idx < 0:
        return  # empty loader: no losses to report
    print(_format_progress(epoch, seen, total, 100. * (batch_idx + 1) / n_batches, models))

In [9]:
# Seed the global RNG so both models' initial weights (and the loader's
# shuffling order, if it shuffles) are reproducible under Restart & Run All.
torch.manual_seed(0)
# Model 0 uses the hand-written my_log_softmax; model 1 uses F.log_softmax.
models = [Net(log_softmax=False), Net(log_softmax=True)]

In [10]:
avg_lambda = lambda l: 'Loss: {:.4f}'.format(l)
acc_lambda = lambda c, p: 'Accuracy: {}/{} ({:.0f}%)'.format(c, len(test_loader.dataset), p)
line = lambda i, l, c, p: '{}: '.format(i) + avg_lambda(l) + '\t' + acc_lambda(c, p)

def test(models):
    test_loss = [0]*len(models)
    correct = [0]*len(models)
    with torch.no_grad():
        for data, target in test_loader:
            output = [m(data) for m in models]
            for i, m in enumerate(models):
                test_loss[i] += m.loss(output[i], target, size_average=False).item() # sum up batch loss
                pred = output[i].data.max(1, keepdim=True)[1] # get the index of the max log-probability
                correct[i] += pred.eq(target.data.view_as(pred)).cpu().sum()
    
    for i in range(len(models)):
        test_loss[i] /= len(test_loader.dataset)
    correct_pct = [100. * c / len(test_loader.dataset) for c in correct]
    lines = '\n'.join([line(i, test_loss[i], correct[i], correct_pct[i]) for i in range(len(models))]) + '\n'
    report = 'Test set:\n' + lines
    
    print(report)

In [11]:
N_EPOCHS = 5  # number of passes over the training set

# Alternate one training epoch and one evaluation pass; epochs are 1-based.
for epoch in range(1, N_EPOCHS + 1):
    train(epoch, models)
    test(models)

Test set:
0: Loss: 0.2131	Accuracy: 9400/10000 (94%)
1: Loss: 0.2073	Accuracy: 9414/10000 (94%)

Test set:
0: Loss: 0.2364	Accuracy: 9396/10000 (93%)
1: Loss: 0.2448	Accuracy: 9401/10000 (94%)

Test set:
0: Loss: 0.2605	Accuracy: 9437/10000 (94%)
1: Loss: 0.2193	Accuracy: 9427/10000 (94%)

Test set:
0: Loss: 0.1956	Accuracy: 9529/10000 (95%)
1: Loss: 0.2233	Accuracy: 9492/10000 (94%)

Test set:
0: Loss: 0.2046	Accuracy: 9517/10000 (95%)
1: Loss: 0.2086	Accuracy: 9508/10000 (95%)

