In [1]:
import gpytorch
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision import transforms
from gpytorch.kernels import RBFKernel, GridInterpolationKernel


gpytorch.functions.use_toeplitz = False

In [2]:
train_omniglot = torchvision.datasets.ImageFolder('/scratch/bw462/', transform=transforms.Compose([
                        transforms.Scale((28,28)),
                        transforms.ToTensor()
                   ]))                                              
"""
test_mnist = torchvision.datasets.ImageFolder('/tmp', split='test',
                                        download=True, transform=transforms.Compose([
                       transforms.ToTensor()
                   ]))
"""

"\ntest_mnist = torchvision.datasets.ImageFolder('/tmp', split='test',\n                                        download=True, transform=transforms.Compose([\n                       transforms.ToTensor()\n                   ]))\n"

In [3]:
class FeatureExtractor(nn.Sequential):
    
    def __init__(self):
        super(FeatureExtractor, self).__init__(nn.Conv2d(1, 32, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(32),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2),
                                 nn.Conv2d(32, 64, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2))
        
class Bottleneck(nn.Sequential):
    
    def __init__(self):
        super(Bottleneck, self).__init__(nn.Linear(64*7*7, 128),
                                         nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128, 128),
                                 nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128,64),
                                 nn.BatchNorm1d(64))

class LeBottleNet(nn.Module):
    
    def __init__(self):
        super(LeBottleNet, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.bottleneck = Bottleneck()
    
    def forward(self, x):
        input_x = x[:,0,:,:].unsqueeze(1)
        features = self.feature_extractor(input_x)
        bottlenecked_features = self.bottleneck(features.view(-1, 64 * 7 * 7))
        return bottlenecked_features

class SimilarityCompare(gpytorch.Module):
    def __init__(self):
        super(SimilarityCompare, self).__init__()
        self.network = LeBottleNet()
        self.gp_layer = GPLayer()
    
    def forward(self, x1, x2):
        feature1 = self.network(x1)
        feature2 = self.network(x2)
        features_combined = torch.cat([feature1, feature2], dim=1)
        gp_output = self.gp_layer(features_combined)
        return gp_output

class LatentFunction(gpytorch.AdditiveGridInducingPointModule):
    def __init__(self):
        super(LatentFunction, self).__init__(grid_size=256, grid_bounds=[(-10, 10)],
                                             n_components=128, mixing_params=True)
        cov_module = RBFKernel()
        cov_module.initialize(log_lengthscale=2)
        self.cov_module = cov_module
        
    def forward(self, x):
        mean = Variable(x.data.new(len(x)).zero_())
        covar = self.cov_module(x)
        return gpytorch.random_variables.GaussianRandomVariable(mean, covar)

    
class GPLayer(gpytorch.GPModel):
    def __init__(self, n_dims=128):
        super(GPLayer, self).__init__(gpytorch.likelihoods.BernoulliLikelihood())
        self.latent_function = LatentFunction()
    
    def forward(self, x):
        res = self.latent_function(x)
        return res

In [4]:
bs=256
train_data_loader = torch.utils.data.DataLoader(train_omniglot, shuffle=True, pin_memory=True, batch_size=bs,
                                               drop_last=True)

In [5]:
criterion = nn.L1Loss().cuda()

In [6]:
similarity_model = SimilarityCompare().cuda()

In [7]:
"""
Now need to get training going, thinking setup is good. Here we need to pair up the samples from the batch
and then feed through with classifier of whether they're the same or not. Feel like might want to edit in some
way to up frequency of same class to prevent dominant strategy of always -1 (for diff class)
"""
similarity_model.train()
optimizer = torch.optim.Adam(similarity_model.parameters(), lr=0.1)
optimizer.n_iter = 0
num_epochs = 1
for i in range(num_epochs):
    for train_x_batch, train_y_batch in train_data_loader:
        optimizer.zero_grad()
        first_chunk_x, second_chunk_x = torch.chunk(train_x_batch, 2)
        first_chunk_y, second_chunk_y = torch.chunk(train_y_batch, 2)
        output = similarity_model(Variable(first_chunk_x).cuda(), 
                                  Variable(second_chunk_x).cuda())
        train_targets = Variable(torch.eq(first_chunk_y, second_chunk_y).float() * 2 - 1).cuda()
        loss = -similarity_model.gp_layer.marginal_log_likelihood(output, train_targets.float(),
                                                                  n_data=len(train_omniglot)/2)
        loss.backward()
        optimizer.step()
        optimizer.n_iter += 1
        # Do each with itself to get matching
        optimizer.zero_grad()
        var_train_x_batch = Variable(train_x_batch).cuda()
        output = similarity_model(var_train_x_batch, var_train_x_batch)
        train_targets = Variable(torch.ones(bs).float()).cuda()
        loss = -similarity_model.gp_layer.marginal_log_likelihood(output, train_targets,
                                                                  n_data=len(train_omniglot))
        loss.backward()
        optimizer.step()
        optimizer.n_iter += 1
        print('Iter %d/%d - Loss: %.3f' % (
            i + 1, num_epochs, loss.data[0],
        ))

Iter 1/1 - Loss: 106.028
Iter 1/1 - Loss: 54.375
Iter 1/1 - Loss: 30.812
Iter 1/1 - Loss: 83.890
Iter 1/1 - Loss: 22.491
Iter 1/1 - Loss: 23.212
Iter 1/1 - Loss: 81.265
Iter 1/1 - Loss: 19.551


KeyboardInterrupt: 

In [None]:
similarity_model.eval()
test_data_loader = train_data_loader#torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)

for test_batch_x, test_batch_y in test_data_loader:
    predictions = deep_kernel(Variable(test_batch_x).cuda()).representation().max(-1)[1]
    test_batch_y = Variable(test_batch_y).cuda()
    avg += torch.eq(predictions, test_batch_y).float().mean().data[0]
    i += 1.
print('Accuracy: %.4f' % (avg / i))
