In [34]:
import gpytorch
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision import transforms

gpytorch.functions.use_toeplitz = False

In [35]:
class FeatureExtractor(nn.Sequential):
    """Convolutional front end of the LeNet-style classifier.

    Two identical stages of 5x5 convolution -> batch norm -> ReLU -> 2x2 max
    pool, mapping a 1-channel image to 32 and then 64 feature maps. The
    convolutions use padding=2, so only the pooling layers shrink the
    spatial dimensions (each by a factor of two).
    """

    def __init__(self):
        first_stage = [
            nn.Conv2d(1, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        second_stage = [
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        # Layers are registered positionally, so sub-module names stay "0".."7".
        super(FeatureExtractor, self).__init__(*(first_stage + second_stage))
        
class Bottleneck(nn.Sequential):
    """Fully connected stack that compresses flattened conv features to 64 dims.

    Linear(64*7*7 -> 128) -> BN -> ReLU -> Linear(128 -> 128) -> BN -> ReLU
    -> Linear(128 -> 64) -> BN. The stack ends on a batch norm with no
    activation after it.
    """

    def __init__(self):
        flattened_size = 64*7*7
        layers = [
            nn.Linear(flattened_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
        ]
        super(Bottleneck, self).__init__(*layers)

class LeNet(nn.Module):
    """Classifier built from FeatureExtractor, Bottleneck and a linear head.

    The head applies a ReLU to the 64-dim bottleneck output and maps it to
    10 class scores.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.bottleneck = Bottleneck()
        self.final_layer = nn.Sequential(nn.ReLU(), nn.Linear(64, 10))

    def forward(self, x):
        """Return the 10 class scores for a batch of images `x`."""
        conv_maps = self.feature_extractor(x)
        # Flatten each example's 64 x 7 x 7 feature maps into a single vector.
        flat = conv_maps.view(-1, 64 * 7 * 7)
        return self.final_layer(self.bottleneck(flat))
        

In [36]:
# Base task: MNIST restricted to the examples whose label is below 7,
# for both the training and the test split.
train_mnist = torchvision.datasets.MNIST(
    '/mnt/bigboi/datasets/mnist/',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# Positions of the training examples to keep, then subset data and labels alike.
train_indices = torch.nonzero(train_mnist.train_labels.lt(7)).squeeze()
train_mnist.train_data = torch.index_select(train_mnist.train_data, 0, train_indices)
train_mnist.train_labels = torch.index_select(train_mnist.train_labels, 0, train_indices)

test_mnist = torchvision.datasets.MNIST(
    '/mnt/bigboi/datasets/mnist/',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# Same filtering for the test split.
test_indices = torch.nonzero(test_mnist.test_labels.lt(7)).squeeze()
test_mnist.test_data = torch.index_select(test_mnist.test_data, 0, test_indices)
test_mnist.test_labels = torch.index_select(test_mnist.test_labels, 0, test_indices)

# Display the sizes of the two filtered splits.
len(train_mnist), len(test_mnist)

(41935, 6989)

In [37]:
# Shuffled mini-batches of 256 examples for training the LeNet classifier.
train_data_loader = torch.utils.data.DataLoader(
    train_mnist,
    batch_size=256,
    shuffle=True,
    pin_memory=True,
)

In [38]:
# Cross-entropy loss for the LeNet classifier, moved to the GPU.
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()

In [39]:
# Instantiate the classifier on the GPU and optimise all of its parameters
# with plain SGD.
model = LeNet()
model = model.cuda()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

In [40]:
# Train the LeNet classifier for `num_epochs` epochs and save its weights, or,
# when `num_epochs` is 0, restore the weights from the saved checkpoint.
num_epochs = 0
if num_epochs > 0:
    for i in range(num_epochs):
        for x, y in train_data_loader:
            optimizer.zero_grad()
            x = Variable(x.cuda())
            y = Variable(y.cuda())
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
        # Loss of the last mini-batch of the epoch.
        print("Loss: %.3f" % loss.data[0])
    torch.save(model.state_dict(), '/mnt/data/dkl/mnist/lenet_oneshot.dat')
else:
    model.load_state_dict(torch.load('/mnt/data/dkl/mnist/lenet_oneshot.dat'))

# Evaluate on the held-out split. Accuracy is accumulated as
# (number correct) / (number seen) rather than as a mean of per-batch
# accuracies: the final batch is usually smaller than 256, so averaging the
# batch means would give its examples too much weight.
model.eval()
test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)
n_correct = 0.
n_seen = 0.
for test_batch_x, test_batch_y in test_data_loader:
    # Index of the highest score along the last dimension = predicted class.
    predictions = model(Variable(test_batch_x).cuda()).max(-1)[1]
    test_batch_y = Variable(test_batch_y).cuda()
    n_correct += torch.eq(predictions, test_batch_y).float().sum().data[0]
    n_seen += test_batch_y.size(0)
print('Accuracy: %.4f' % (n_correct / n_seen))

Accuracy: 0.9953


In [41]:
# Fill the weight of the bottleneck's last layer (its final BatchNorm1d) with
# ones. The trailing semicolon keeps the notebook from echoing the result.
list(model.bottleneck.children())[-1].weight.data.fill_(1);

In [42]:
from gpytorch.kernels import RBFKernel, GridInterpolationKernel

class DeepKernel(gpytorch.Module):
    """Deep-kernel model: the network trunk of `model` followed by a GPLayer.

    The conv feature extractor and bottleneck are taken from `model` by
    reference (not copied), so they are shared with it.
    """

    def __init__(self, model):
        super(DeepKernel, self).__init__()
        self.feature_extractor = model.feature_extractor
        self.bottleneck = model.bottleneck
        self.gp_layer = GPLayer()

    def forward(self, x):
        """Run images `x` through the trunk and return the GP layer's output."""
        conv_maps = self.feature_extractor(x)
        # Flatten each example's 64 x 7 x 7 feature maps before the bottleneck.
        embedding = self.bottleneck(conv_maps.view(-1, 64 * 7 * 7))
        return self.gp_layer(embedding)
    
    
class LatentFunction(gpytorch.AdditiveGridInducingPointModule):
    """Additive grid-inducing-point latent function with an RBF kernel.

    Configured with 64 components, a 256-point grid on (-7, 7), no mixing
    parameters and un-summed output.
    """

    def __init__(self):
        super(LatentFunction, self).__init__(
            grid_size=256,
            grid_bounds=[(-7, 7)],
            n_components=64,
            mixing_params=False,
            sum_output=False,
        )
        # Initialise the kernel *before* assigning it: other code in this
        # notebook reads `cov_module.base_kernel_module`, so the stored
        # attribute is presumably a wrapper around this kernel.
        rbf = RBFKernel()
        rbf.initialize(log_lengthscale=2)
        self.cov_module = rbf

    def forward(self, x):
        """Return a Gaussian with zero mean and the kernel's covariance at `x`."""
        # Zero mean vector with one entry per row of x, same tensor type as x.
        zero_mean = Variable(x.data.new(len(x)).zero_())
        covariance = self.cov_module(x)
        return gpytorch.random_variables.GaussianRandomVariable(zero_mean, covariance)

    
class GPLayer(gpytorch.GPModel):
    """GP model pairing a LatentFunction with a softmax-Gaussian likelihood.

    The likelihood is built for 64 input features, 10 classes and rank 5.
    """

    def __init__(self, n_dims=64):
        # NOTE: `n_dims` is accepted but not used; the feature count is fixed
        # at 64 here.
        likelihood = gpytorch.likelihoods.SoftmaxGaussianLikelihood(
            n_features=64, n_classes=10, rank=5
        )
        super(GPLayer, self).__init__(likelihood)
        self.latent_function = LatentFunction()

    def forward(self, x):
        """Evaluate the latent function at `x` and return its output."""
        return self.latent_function(x)
    

In [43]:
# Sanity check: display the number of examples currently in `train_mnist`.
len(train_mnist)

41935

In [44]:
# Wrap the LeNet trunk in the deep-kernel GP model and move it to the GPU.
deep_kernel = DeepKernel(model).cuda()
# batch_size must be an integer: DataLoader documents it as an int and newer
# PyTorch releases reject a float such as `2048.` outright.
gp_data_loader = torch.utils.data.DataLoader(train_mnist, batch_size=2048, pin_memory=True, shuffle=True)

In [45]:
# Find optimal model hyperparameters
# Train only the GP layer's parameters for `n_epochs` epochs and save the whole
# deep-kernel model, or, when `n_epochs` is 0, restore it from the checkpoint.
n_epochs = 0
if n_epochs > 0:
    # NOTE(review): train() is called on the whole deep_kernel, which includes
    # the batch-norm layers of the trunk even though only gp_layer parameters
    # are optimised -- confirm that updating their running stats is intended.
    deep_kernel.train()
    optimizer = torch.optim.Adam(deep_kernel.gp_layer.parameters(), lr=0.01)
    optimizer.n_iter = 0
    # The test set is fixed and unshuffled, so build its loader once instead of
    # once per epoch.
    test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)
    for epoch in range(n_epochs):
        for j, (train_x_batch, train_y_batch) in enumerate(gp_data_loader):
            train_x_batch = Variable(train_x_batch).cuda()
            train_y_batch = Variable(train_y_batch).cuda()
            optimizer.zero_grad()
            output = deep_kernel(train_x_batch)
            # Objective: negative marginal log likelihood plus the likelihood's
            # KL term scaled by the dataset size.
            loss = -deep_kernel.gp_layer.marginal_log_likelihood(output, train_y_batch, n_data=len(train_mnist))
            kl = deep_kernel.gp_layer.likelihood.kl_div() / len(train_mnist)
            loss = loss + kl
            loss.backward()
            optimizer.n_iter += 1
            # Report the real epoch total and the batch index (the message
            # previously printed the epoch number against a hard-coded "/200").
            print('Epoch %d/%d - Batch %d - Loss: %.3f' % (
                epoch + 1, n_epochs, j + 1, loss.data[0],
            ))
            optimizer.step()

        # Per-epoch evaluation. Accuracy is (number correct) / (number seen) so
        # that a smaller final batch is not over-weighted, and the counters no
        # longer reuse the name of the epoch loop variable.
        deep_kernel.eval()
        n_correct = 0.
        n_seen = 0.
        for test_batch_x, test_batch_y in test_data_loader:
            predictions = deep_kernel(Variable(test_batch_x).cuda()).argmax()
            test_batch_y = Variable(test_batch_y).cuda()
            n_correct += torch.eq(predictions, test_batch_y).float().sum().data[0]
            n_seen += test_batch_y.size(0)

        print('Score')
        print(n_correct / n_seen)
        deep_kernel.train()

    torch.save(deep_kernel.state_dict(), '/mnt/data/dkl/mnist/gp_oneshot.dat')
else:
    deep_kernel.load_state_dict(torch.load('/mnt/data/dkl/mnist/gp_oneshot.dat'))

In [46]:
# One-shot task: the held-out digits 7, 8 and 9, relabelled to 0, 1 and 2.
train_mnist = torchvision.datasets.MNIST(
    '/mnt/bigboi/datasets/mnist/',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# Take the first five training examples of each of the three new digits.
train_indices = torch.cat([
    torch.nonzero(train_mnist.train_labels.eq(digit)).squeeze()[:5]
    for digit in (7, 8, 9)
])
train_mnist.train_data = torch.index_select(train_mnist.train_data, 0, train_indices)
# Shift labels 7..9 down to 0..2.
train_mnist.train_labels = torch.index_select(train_mnist.train_labels, 0, train_indices) - 7

test_mnist = torchvision.datasets.MNIST(
    '/mnt/bigboi/datasets/mnist/',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# Test on every test example labelled 7 or above, with the same label shift.
test_indices = torch.nonzero(test_mnist.test_labels.ge(7)).squeeze()
test_mnist.test_data = torch.index_select(test_mnist.test_data, 0, test_indices)
test_mnist.test_labels = torch.index_select(test_mnist.test_labels, 0, test_indices) - 7

# Display the sizes of the two splits.
len(train_mnist), len(test_mnist)

(15, 3011)

In [47]:
# Loader over the few-shot training set. batch_size must be an integer:
# DataLoader documents it as an int and newer PyTorch releases reject a float
# such as `2048.`. No shuffling, so the batch order is deterministic.
gp_data_loader = torch.utils.data.DataLoader(train_mnist, batch_size=2048, pin_memory=True)

In [48]:
class FeatureExtractor(nn.Module):
    """Trunk of a deep-kernel model: conv feature extractor plus bottleneck.

    Both sub-modules are taken from `gp_model` by reference. Note that this
    class reuses the name of the convolutional `FeatureExtractor` defined
    earlier in the notebook and replaces it from this point on.
    """

    def __init__(self, gp_model):
        super(FeatureExtractor, self).__init__()
        self.feature_extractor = gp_model.feature_extractor
        self.bottleneck = gp_model.bottleneck

    def forward(self, x):
        """Return the bottleneck embedding for a batch of images `x`."""
        conv_maps = self.feature_extractor(x)
        # Flatten each example's 64 x 7 x 7 feature maps before the bottleneck.
        return self.bottleneck(conv_maps.view(-1, 64 * 7 * 7))
    
# Frozen trunk used to embed images; eval mode so batch norm uses its running
# statistics.
feature_extractor = FeatureExtractor(deep_kernel).cuda()
feature_extractor.eval()

# Use the embeddings of the first batch of few-shot examples as inducing
# points. `next(iter(...))` is the portable way to pull one batch; calling
# `.next()` on the iterator relies on a Python 2 compatibility alias.
data, _ = next(iter(gp_data_loader))
# Add a trailing singleton dimension to the embeddings.
inducing_points = feature_extractor(Variable(data).cuda()).data.unsqueeze(-1)

In [56]:
class OneShotGP(gpytorch.GPModel):
    """GP model for the 3-class one-shot task, warm-started from `gp_model`.

    Pairs a OneShotLatent with a softmax-Gaussian likelihood built for 64
    features, 3 classes and rank 5.

    Args:
        gp_model: the pretrained deep-kernel model; its GP layer supplies the
            kernel lengthscale (via OneShotLatent) and the likelihood's
            feature mixing weights.
        n_dims: accepted for interface compatibility but currently unused;
            the feature count is fixed at 64.
    """

    def __init__(self, gp_model, n_dims=64):
        super(OneShotGP, self).__init__(
            gpytorch.likelihoods.SoftmaxGaussianLikelihood(n_features=64, n_classes=3, rank=5)
        )
        self.latent_function = OneShotLatent(gp_model)
        # Transfer the feature mixing weights learned on the base task. The
        # previous code copied this model's own weights onto themselves, which
        # did nothing. The copy is guarded by a shape check so that, if the
        # pretrained weights depend on the class count and do not fit, the
        # fresh initialisation is kept (the earlier effective behaviour).
        pretrained_weights = gp_model.gp_layer.likelihood.feature_mixing_weights.data
        own_weights = self.likelihood.feature_mixing_weights.data
        if pretrained_weights.size() == own_weights.size():
            own_weights.copy_(pretrained_weights)

    def forward(self, x):
        """Evaluate the latent function at the embeddings `x`."""
        # Add a trailing singleton dimension before the latent function.
        x = x.unsqueeze(-1)
        res = self.latent_function(x)
        return res
    
class OneShotLatent(gpytorch.AdditiveInducingPointModule):
    """Additive inducing-point latent function for the one-shot task.

    Uses the notebook-level `inducing_points` tensor, 64 components and
    un-summed output. Its RBF kernel starts from the lengthscale of the
    pretrained `gp_model`.
    """

    def __init__(self, gp_model):
        super(OneShotLatent, self).__init__(
            inducing_points=inducing_points,
            n_components=64,
            sum_output=False,
        )
        # Initialise the kernel before assigning it, mirroring LatentFunction;
        # the lengthscale is read from the pretrained model's base kernel.
        rbf = RBFKernel()
        pretrained_kernel = gp_model.gp_layer.latent_function.cov_module.base_kernel_module
        rbf.initialize(log_lengthscale=pretrained_kernel.log_lengthscale.data[0])
        self.cov_module = rbf

    def forward(self, x):
        """Return a Gaussian with zero mean and the kernel's covariance at `x`."""
        # Zeros shaped like x, with the trailing dimension squeezed away.
        zero_mean = Variable(x.data.new(x.size()).zero_()).squeeze(-1)
        covariance = self.cov_module(x)
        return gpytorch.random_variables.GaussianRandomVariable(zero_mean, covariance)
    
# Build the one-shot GP from the pretrained deep-kernel model and move it to
# the GPU.
oneshot_model = OneShotGP(deep_kernel)
oneshot_model = oneshot_model.cuda()

In [63]:
# Find optimal model hyperparameters
oneshot_model.train()

# The few-shot training set fits in one batch; embed it once with the frozen
# trunk. `next(iter(...))` replaces the Python 2 style `.__iter__().next()`.
train_x_batch, train_y_batch = next(iter(gp_data_loader))
train_x_batch = feature_extractor(Variable(train_x_batch).cuda())
train_y_batch = Variable(train_y_batch).cuda()

# Hand the optimizer an ordered, de-duplicated list of parameters. A `set`
# has no stable order (so runs are not reproducible) and newer PyTorch
# releases refuse sets altogether; de-duplicating by identity keeps whatever
# protection against repeated parameters the set provided.
unique_params = []
seen_param_ids = set()
for param in oneshot_model.parameters():
    if id(param) not in seen_param_ids:
        seen_param_ids.add(id(param))
        unique_params.append(param)

n_iters = 100
optimizer = torch.optim.Adam(unique_params, lr=0.01)
optimizer.n_iter = 0
for step in range(n_iters):
    optimizer.zero_grad()
    output = oneshot_model(train_x_batch)
    # Objective: negative marginal log likelihood plus the likelihood's KL
    # term scaled by the dataset size.
    loss = -oneshot_model.marginal_log_likelihood(output, train_y_batch, n_data=len(train_mnist))
    kl = oneshot_model.likelihood.kl_div() / len(train_mnist)
    loss = loss + kl
    loss.backward()
    optimizer.n_iter += 1
    # Report the real iteration total (the message used to say "/200" while
    # the loop ran 100 iterations).
    print('Iter %d/%d - Loss: %.3f' % (
        step + 1, n_iters, loss.data[0],
    ))
    optimizer.step()

oneshot_model.eval()
test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)

# Accuracy is (number correct) / (number seen) so that a smaller final batch
# is not over-weighted, as it would be when averaging per-batch accuracies.
n_correct = 0.
n_seen = 0.
for test_batch_x, test_batch_y in test_data_loader:
    predictions = oneshot_model(feature_extractor(Variable(test_batch_x).cuda())).argmax()
    test_batch_y = Variable(test_batch_y).cuda()
    n_correct += torch.eq(predictions, test_batch_y).float().sum().data[0]
    n_seen += test_batch_y.size(0)

print('Score')
print(n_correct / n_seen)

Iter 1/200 - Loss: 13.202
Iter 2/200 - Loss: 45.852
Iter 3/200 - Loss: 18.872
Iter 4/200 - Loss: 18.822
Iter 5/200 - Loss: 26.057
Iter 6/200 - Loss: 35.513
Iter 7/200 - Loss: 23.057
Iter 8/200 - Loss: 24.311
Iter 9/200 - Loss: 13.113
Iter 10/200 - Loss: 22.351
Iter 11/200 - Loss: 15.251
Iter 12/200 - Loss: 24.996
Iter 13/200 - Loss: 18.725
Iter 14/200 - Loss: 15.128
Iter 15/200 - Loss: 33.561
Iter 16/200 - Loss: 26.322
Iter 17/200 - Loss: 18.903
Iter 18/200 - Loss: 14.173
Iter 19/200 - Loss: 16.179
Iter 20/200 - Loss: 28.193
Iter 21/200 - Loss: 32.424
Iter 22/200 - Loss: 28.202
Iter 23/200 - Loss: 17.866
Iter 24/200 - Loss: 12.721
Iter 25/200 - Loss: 20.719
Iter 26/200 - Loss: 24.609
Iter 27/200 - Loss: 18.927
Iter 28/200 - Loss: 27.055
Iter 29/200 - Loss: 37.109
Iter 30/200 - Loss: 14.354
Iter 31/200 - Loss: 21.298
Iter 32/200 - Loss: 15.195
Iter 33/200 - Loss: 29.076
Iter 34/200 - Loss: 20.184
Iter 35/200 - Loss: 21.808
Iter 36/200 - Loss: 41.396
Iter 37/200 - Loss: 13.578
Iter 38/20