In [1]:
# Display the value of every bare top-level expression in a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Automatically reload edited local modules (e.g. utils, nets.*) before each cell executes,
# so changes to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


import utils
from nets.NetOneLayer import NetOneLayer
from nets.NetOneLayerLowRank import NetOneLayerLowRank

In [3]:
# Mini-batch sizes for the training loader and the evaluation loader, respectively.
batch_size, batch_size_test = 128, 1000

# utils.load_mnist is project-local; the unpacking below assumes it returns
# (train_loader, test_loader) in that order — TODO confirm against utils.py.
(train_loader,
 test_loader) = utils.load_mnist(batch_size, batch_size_test)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [4]:
def train(model, train_loader, optimizer, epoch, log_interval=200):
    """Run one training epoch of ``model`` over ``train_loader``.

    For every batch: zero the gradients, run the forward pass, compute
    ``F.cross_entropy`` on the raw model output, backpropagate, and take an
    optimizer step.

    Args:
        model: ``nn.Module`` whose output is fed to ``F.cross_entropy``, i.e. it
            is treated as unnormalized per-class scores (logits).
        train_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` and ``len()`` (e.g. a ``torch.utils.data.DataLoader``).
        optimizer: optimizer over ``model.parameters()``.
        epoch: epoch index; only used in the progress message.
        log_interval: print a progress line every ``log_interval`` batches
            (batch 0 always logs). Defaults to 200, the previously hard-coded
            value. Pass ``0``/``None`` to disable progress output.
    """
    model.train()  # switch layers such as dropout/batch-norm to training behaviour
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        # Guard on log_interval first so a falsy value disables logging instead
        # of raising ZeroDivisionError from `% 0`.
        if log_interval and batch_idx % log_interval == 0:
            # batch_idx * len(data) approximates samples seen so far
            # (exact while all previous batches were full-sized).
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), n_samples,
                100. * batch_idx / n_batches, loss.item()))
            
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [5]:
# SGD hyper-parameters, run length, model width and RNG seed for this experiment.
lr = 0.02
momentum = 0.9
n_epochs = 100
n_hidden = 2**8
seed = 0

# Seed torch's global RNG so weight initialisation (and batch order, if the
# loader shuffles) is reproducible across Restart & Run All. The returned
# Generator is assigned so "all" interactivity does not echo it.
torch_generator = torch.manual_seed(seed)

# Low-rank model under study. For the dense baseline, swap in:
#   model = NetOneLayer(n_hidden=n_hidden)
model = NetOneLayerLowRank(n_hidden=n_hidden, d=2, K=2)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# One pass over the training set, then a full evaluation, per epoch.
for epoch in range(n_epochs):
    train(model, train_loader, optimizer, epoch)
    test(model, test_loader)


Test set: Average loss: -0.2874, Accuracy: 2948/10000 (29%)


Test set: Average loss: -0.3667, Accuracy: 3715/10000 (37%)


Test set: Average loss: -0.4059, Accuracy: 4120/10000 (41%)


Test set: Average loss: -0.4335, Accuracy: 4372/10000 (44%)


Test set: Average loss: -0.4526, Accuracy: 4570/10000 (46%)


Test set: Average loss: -0.4652, Accuracy: 4687/10000 (47%)


Test set: Average loss: -0.4736, Accuracy: 4770/10000 (48%)


Test set: Average loss: -0.4806, Accuracy: 4838/10000 (48%)


Test set: Average loss: -0.4855, Accuracy: 4879/10000 (49%)


Test set: Average loss: -0.4910, Accuracy: 4949/10000 (49%)


Test set: Average loss: -0.4939, Accuracy: 4970/10000 (50%)


Test set: Average loss: -0.4970, Accuracy: 5003/10000 (50%)


Test set: Average loss: -0.4984, Accuracy: 5014/10000 (50%)


Test set: Average loss: -0.5020, Accuracy: 5053/10000 (51%)


Test set: Average loss: -0.5046, Accuracy: 5083/10000 (51%)


Test set: Average loss: -0.5047, Accuracy: 5074/10000 (51%)


Test se


Test set: Average loss: -0.6273, Accuracy: 6305/10000 (63%)


Test set: Average loss: -0.6275, Accuracy: 6309/10000 (63%)


Test set: Average loss: -0.6285, Accuracy: 6318/10000 (63%)


Test set: Average loss: -0.6296, Accuracy: 6323/10000 (63%)


Test set: Average loss: -0.6294, Accuracy: 6326/10000 (63%)


Test set: Average loss: -0.6302, Accuracy: 6332/10000 (63%)


Test set: Average loss: -0.6303, Accuracy: 6333/10000 (63%)


Test set: Average loss: -0.6310, Accuracy: 6338/10000 (63%)


Test set: Average loss: -0.6321, Accuracy: 6342/10000 (63%)


Test set: Average loss: -0.6317, Accuracy: 6342/10000 (63%)


Test set: Average loss: -0.6319, Accuracy: 6347/10000 (63%)


Test set: Average loss: -0.6320, Accuracy: 6347/10000 (63%)


Test set: Average loss: -0.6332, Accuracy: 6364/10000 (64%)


Test set: Average loss: -0.6331, Accuracy: 6360/10000 (64%)


Test set: Average loss: -0.6331, Accuracy: 6353/10000 (64%)


Test set: Average loss: -0.6335, Accuracy: 6367/10000 (64%)


Test se


Test set: Average loss: -0.7032, Accuracy: 7057/10000 (71%)


Test set: Average loss: -0.7047, Accuracy: 7071/10000 (71%)


Test set: Average loss: -0.7050, Accuracy: 7074/10000 (71%)


Test set: Average loss: -0.7052, Accuracy: 7075/10000 (71%)


Test set: Average loss: -0.7051, Accuracy: 7066/10000 (71%)


Test set: Average loss: -0.7056, Accuracy: 7085/10000 (71%)


Test set: Average loss: -0.7060, Accuracy: 7086/10000 (71%)


Test set: Average loss: -0.7056, Accuracy: 7084/10000 (71%)


Test set: Average loss: -0.7073, Accuracy: 7095/10000 (71%)


Test set: Average loss: -0.7067, Accuracy: 7091/10000 (71%)


Test set: Average loss: -0.7075, Accuracy: 7097/10000 (71%)


Test set: Average loss: -0.7081, Accuracy: 7102/10000 (71%)


Test set: Average loss: -0.7076, Accuracy: 7099/10000 (71%)


Test set: Average loss: -0.7074, Accuracy: 7088/10000 (71%)


Test set: Average loss: -0.7085, Accuracy: 7106/10000 (71%)


Test set: Average loss: -0.7086, Accuracy: 7104/10000 (71%)


Test se