In [1]:
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
import torch
import torch.utils.data as dutils

import torchvision.datasets as dataset
import torchvision.transforms as dtrans

In [4]:
dataroot = '/home/dutta/data/'
device = 'cpu'
seed = 99

In [5]:
input_dim = 784
output_dim = 10
num_examples = 60000

batch_size = 32

torch.manual_seed(seed);

### Data setup

In [6]:
transform = dtrans.Compose([dtrans.Grayscale(num_output_channels=1),
                            dtrans.Resize((28,28)),
                            dtrans.ToTensor(), dtrans.Normalize((0.5,), (0.5,))])
train_set = dataset.MNIST(root=dataroot, download=False, train=True, transform=transform)
test_set = dataset.MNIST(root=dataroot, download=False, train=False, transform=transform)

train_loader = dutils.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = dutils.DataLoader(test_set, batch_size=batch_size, shuffle=False)

### Model setup

In [7]:
def l2_penalty(weight):
    penalty = torch.pow(weight, 2)
    penalty = torch.sqrt(penalty).mean()
    return penalty        

In [8]:
def cross_entropy(yhat, y):
    probs = torch.gather(yhat, 1, y.view(-1, 1))
    loss = -torch.log(probs + 1e-6)
    return loss.mean()

In [9]:
def softmax(y_linear):
    row_max = y_linear.max(dim=1)[0]
    row_max.unsqueeze_(1)
    row_max = row_max.repeat(1, y_linear.shape[1])
    exp = torch.exp(y_linear - row_max)
    norm = exp.sum(dim=1).unsqueeze(1).repeat(1, y_linear.shape[1])
    return exp / norm

In [9]:
# sample_y_linear = torch.randn(10, 10)
# sample_yhat = softmax(sample_y_linear)
# print(sample_y_linear)
# print(sample_yhat)

In [10]:
def SGD(params, lr):
    for param in params:
        param.data = param.data - lr * param.grad

In [11]:
class Net():
    def __init__(self, input_dim, out_dim, device):
        w = torch.randn(size=(input_dim, output_dim), device=device)
        b = torch.randn(output_dim, device=device)
        self.w = w / torch.norm(w)
        self.b = b /torch.norm(b)
        self.w.requires_grad = True
        self.b.requires_grad = True
#         print('Weight shape: ', self.w.shape)
#         print('Bias shape: ', self.b.shape)
        
    def forward(self, x):
        w = self.w
        bias = self.b
        out = torch.mm(x, w) + bias
        out = softmax(out)
        return out.squeeze()
    
    def parameters(self):
        return [self.w, self.b]
    
    def zero_grad(self):
        self.w.grad = None
        self.b.grad = None
        
    def __call__(self, x):
        return self.forward(x)

In [19]:
def test(net, test_loader, device):
    loss = 0
    error = 0
    for i, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, input_dim)
        with torch.no_grad():
            out = net(images)
            loss += cross_entropy(out, labels).item()
            error += (out.argmax(1) != labels).sum().item()
    loss /= i       
    print('[Test] Loss: %.4f, Error: %d' % (loss, error))

In [20]:
def train(net, train_loader, device):
    params = net.parameters()
    lr = 0.1
    
    loss = 0
    error = 0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 784)
        
        out = net(images)
        loss_batch = cross_entropy(out, labels)
        loss_batch = loss_batch + 2 * l2_penalty(net.w)
        
        net.zero_grad()
        loss_batch.backward()
#         SGD(params, lr)
        opt.step()
        
        loss += loss_batch.item()
        error += (out.argmax(1) != labels).sum().item()
    loss /= i       
    print('[TRAIN] Loss: %.4f, Error: %d' % (loss, error))

In [None]:
net = Net(input_dim, output_dim, device)
opt = torch.optim.Adam(net.parameters(), lr=0.003)

test(net, test_loader, device)
# print l2_penalty(net.w)
for epochs in range(50):
    train(net, train_loader, device)
    test(net, test_loader, device)
#     print l2_penalty(net.w)

[Test] Loss: 2.4217, Error: 9106
[TRAIN] Loss: 0.4995, Error: 7666
[Test] Loss: 0.3678, Error: 1000
[TRAIN] Loss: 0.4641, Error: 6532
[Test] Loss: 0.3362, Error: 969
[TRAIN] Loss: 0.4726, Error: 6465
[Test] Loss: 0.4701, Error: 1349
[TRAIN] Loss: 0.4775, Error: 6585
[Test] Loss: 0.3395, Error: 964
[TRAIN] Loss: 0.4648, Error: 6236
[Test] Loss: 0.3809, Error: 1056
[TRAIN] Loss: 0.4730, Error: 6435
[Test] Loss: 0.3509, Error: 931
[TRAIN] Loss: 0.4741, Error: 6451
[Test] Loss: 0.3439, Error: 1003
[TRAIN] Loss: 0.4668, Error: 6312
[Test] Loss: 0.4066, Error: 1101
[TRAIN] Loss: 0.4701, Error: 6263
[Test] Loss: 0.4384, Error: 1301
[TRAIN] Loss: 0.4729, Error: 6341
[Test] Loss: 0.4059, Error: 1147
[TRAIN] Loss: 0.4668, Error: 6216
[Test] Loss: 0.3806, Error: 1020
[TRAIN] Loss: 0.4690, Error: 6314
[Test] Loss: 0.3881, Error: 1070
[TRAIN] Loss: 0.4744, Error: 6400
[Test] Loss: 0.3828, Error: 1113
[TRAIN] Loss: 0.4666, Error: 6280
[Test] Loss: 0.3599, Error: 1007
[TRAIN] Loss: 0.4696, Error: 627

In [64]:
a = torch.randn(10, 5)
b = (torch.rand(10)*5).long()

print a
print b.view(-1, 1)

# torch.gather(a, 1, b.view(-1, 1))


tensor([[-1.4813,  0.3770,  1.1661,  1.1120,  0.7396],
        [-1.1656,  0.3580, -1.0823,  1.3389,  0.7177],
        [ 0.8475, -1.3822, -0.1256, -0.7188, -1.2978],
        [ 1.3189,  1.2414, -0.3685,  0.1967, -0.3643],
        [-0.6959,  0.3633, -0.3550, -0.6447, -0.6605],
        [-0.9395,  0.2178, -2.1892, -1.5369,  1.2575],
        [ 2.0529, -1.3377, -1.8684, -0.3506, -0.3648],
        [ 0.1550,  1.9691, -0.3749, -0.4108,  0.0268],
        [ 0.0091, -0.1698,  0.2955,  0.1400, -0.7759],
        [-0.7668,  0.9081,  0.9108, -2.1320,  1.0144]])
tensor([[0],
        [2],
        [0],
        [2],
        [0],
        [4],
        [0],
        [4],
        [2],
        [3]])
