In [12]:
from __future__ import print_function
import pickle 
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from sub import subMNIST

In [13]:
trainset_labeled = pickle.load(open("train_labeled.p", "rb")) 
train_loader  = torch.utils.data.DataLoader(trainset_labeled, batch_size=64, shuffle=True, num_workers=2)

validset = pickle.load(open("validation.p", "rb"))
valid_loader = torch.utils.data.DataLoader(validset, batch_size=64, shuffle=True, num_workers=2)

In [14]:
#utility functions
def rand(x, level = 1e-3):
    return x + torch.randn(x.size()) * level

def norm_weights_2d(size):
    return BatchNorm2d(nn.Linear(size))
def norm_weights_1d(size):
    return BatchNorm2d(nn.Linear(size))

def reluN(x, level = 1e-3):
    y = x + torch.randn(x.size()) * level

In [19]:
def norm(size, dim = 1):
    if dim == 1:
        return torch.nn.BatchNorm1d(size)
    else:
        return torch.nn.BatchNorm2d(size)
def linear(indim, outdim):
    return nn.Linear(indim, outdim)
def var(tens):
    return torch.autograd.Variable(tens)
def randn(size):
    return torch.randn(size)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = [1000, 600, 300, 10]
        layers = self.layers
        self.encs = [28*28, layers[0]]
        self.lats = []
        self.encnorms = []
        self.decs = []
        self.decnorms = []
        for idx, l in enumerate(layers[:-1]):
            self.lats.append(linear(l, l))
            self.encnorms.append(norm(l))
            self.encs.append(linear(l, layers[idx+1]))
            self.decs.append(linear(layers[idx+1], l))
            self.decnorms.append(norm(l))
        self.batch_size = 64
        self.weights = [self.encs, self.lats, self.decs]
        

    def forward(self, x, v = 0, noise=1e-3):
        self.eps = noise
        self.batch_size = x.size()[0]
        bs = self.batch_size
        #enc= F.relu(self.enc4(F.relu(self.enc3(F.relu(self.enc2(F.relu(self.enc1(x))))))))
        corrupted = []
        corruptedout = []
        for idx, l in enumerate(self.layers):
            if idx == 0:
                corrupted.append(self.encnorms[l](self.encs[l](x)))
            else:
                corrupted.append(self.encnorms[l](self.encs[l](corruptedout[-1])))
            corruptedout.append(F.relu(self.encs[l] + self.encnorms[l].weight.unsqueeze(0).expand(
                    bs, l) * var(randn(corrupted[-1].size()) * self.eps) +
                    self.encnorms[0].bias.unsqueeze(0).expand(bs, l)))
            
        encout = F.softmax(corruptedout[-1])
        
        clean = [x]
        for norm, enc in zip(self.encnorms, self.encs):
            clean.append(F.relu(norm(enc(clean[-1]))))
        
        decout = [encout]
        decin = []
        for idx in range(len(self.layers) - 1, -1, -1):
            dec = self.decs[idx]
            lat = self.lats[idx]
            decnorm = self.decnorms[idx+1]
            decin.append(dec(decout[-1]) + lat(corruptedout[idx]))
            decout.append(F.relu(decin[-1] + decnorm.weight.unsqueeze(0).expand(bs, self.layers[idx]) *
                    var(randn(decin[-1].size()) * self.eps) + \
                    decnorm.bias.unsqueeze(0).expend(bs, self.layers[idx])))
        weight_reg = 0
        for w_list in self.weights:
            for w in w_list:
                weight_reg += (w.weight**2).mean()/100
        
        yhat = F.log_softmax(corruptedout[-1])
        
        encode_err = 0
        enc_weight = 5
        enc_decay = .2
        for c, d in zip(clean[-2::-1], decin):
            encode_err += enc_weight * ((c - d)**2).mean()
        return yhat, weight_reg + encode_err
        

model = Net()
params = list(model.parameters())

if 1:
    print(model)
    print('Models has {} learnable paramater:'.format(len(params)))
    [print('parameter {} has a size of {}'.format(i+1, params[i].size())) for i in range(len(params))]

Net (
)
Models has 0 learnable paramater:


In [16]:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

IndexError: list index out of range

In [59]:
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target) # Wrap them in Variable 
        optimizer.zero_grad() # Zero the parameter gradients
        data = data.view(data.size()[0], 28*28)
        outputs = model(data,noise=1e-2) # Forward 
        output = outputs[0]
        US = outputs[1]
        loss = F.nll_loss(output, target) + US*.1
        loss.backward() 
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))

In [60]:
def test(epoch, valid_loader):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in valid_loader:

        data, target = Variable(data, volatile=True), Variable(target)
        data = data.view(data.size()[0], 28*28)
        outputs = model(data, noise = 0)
        output = outputs[0]
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1] # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(valid_loader) # loss function already averages over batch size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(valid_loader.dataset),
        100. * correct / len(valid_loader.dataset)))

In [61]:
for epoch in range(1, 30):
    train(epoch)
    test(epoch, valid_loader)


Test set: Average loss: 0.7336, Accuracy: 8942/10000 (89%)


Test set: Average loss: 0.4306, Accuracy: 9278/10000 (93%)


Test set: Average loss: 0.3471, Accuracy: 9367/10000 (94%)


Test set: Average loss: 0.2937, Accuracy: 9388/10000 (94%)


Test set: Average loss: 0.2662, Accuracy: 9413/10000 (94%)


Test set: Average loss: 0.2484, Accuracy: 9444/10000 (94%)


Test set: Average loss: 0.2410, Accuracy: 9438/10000 (94%)


Test set: Average loss: 0.2323, Accuracy: 9457/10000 (95%)


Test set: Average loss: 0.2325, Accuracy: 9420/10000 (94%)


Test set: Average loss: 0.2203, Accuracy: 9438/10000 (94%)


Test set: Average loss: 0.2252, Accuracy: 9423/10000 (94%)


Test set: Average loss: 0.2155, Accuracy: 9440/10000 (94%)


Test set: Average loss: 0.2329, Accuracy: 9377/10000 (94%)


Test set: Average loss: 0.2206, Accuracy: 9419/10000 (94%)


Test set: Average loss: 0.2237, Accuracy: 9393/10000 (94%)


Test set: Average loss: 0.2109, Accuracy: 9431/10000 (94%)


Test set: Average loss: