In [1]:
import gpytorch
import torch

import cnn_gp

In [2]:
import torchvision

In [3]:
import sys
sys.path.append('..')
from kernels.cnn_kernel import CNNGP_Kernel

In [4]:
### load dataset
mnist = torchvision.datasets.MNIST('/home/wesley/Documents/datasets/', download=True, 
                                   transform=torchvision.transforms.ToTensor())

In [5]:
### construct training and test datasets (batched for now)

iter_data_loader = iter(torch.utils.data.DataLoader(mnist, batch_size = 128))

inputs, targets = next(iter_data_loader)
targets = targets.float()

if torch.cuda.is_available():
    inputs, targets = inputs.cuda(), targets.cuda()
    
test_inputs, test_targets = next(iter_data_loader)
test_targets = test_targets.float()

if torch.cuda.is_available():
    test_inputs, test_targets = test_inputs.cuda(), test_targets.cuda()

In [6]:
### define model from resnet on mnist
in_channels = 1
out_channels = 10

var_bias = 0.86
var_weight = 0.79

layers = []
for _ in range(7):  # n_layers
    layers += [
        cnn_gp.Conv2d(kernel_size=7, padding="same", var_weight=var_weight,
               var_bias=var_bias),
        cnn_gp.ReLU(),
    ]
initial_model = cnn_gp.Sequential(
    *layers,
    cnn_gp.Conv2d(kernel_size=28, padding=0, var_weight=var_weight,
           var_bias=var_bias),
)

In [7]:
class DPClassificationModel(gpytorch.models.ExactGP):
    def __init__(self, cnn_model, input_shape, train_x, train_y, alpha_epsilon = 0.01):
        num_classes = train_y.max() + 1
        
        # set alpha = \alpha_\epsilon
        alpha = alpha_epsilon * torch.ones(train_x.shape[-2], num_classes, 
                                           device = train_x.device, dtype = train_x.dtype)
        
        # alpha[class_labels] = 1 + \alpha_\epsilon
        alpha[torch.arange(len(train_x)), train_y] = alpha[torch.arange(len(train_x)), train_y] + 1.
        
        # sigma^2 = log(1 / alpha + 1)
        sigma2_i = torch.log(1 / alpha + 1.)
        
        # y = log(alpha) - 0.5 * sigma^2
        transformed_targets = alpha.log() - 0.5 * sigma2_i
        
        likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(sigma2_i.t(), learn_additional_noise= True)
        
        super(DPClassificationModel, self).__init__(train_x, transformed_targets.t(), likelihood)
        self.transformed_targets = transformed_targets
        
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size((num_classes,)))
        self.covar_module = CNNGP_Kernel(cnn_model, input_shape, batch_shape=torch.Size((num_classes,)))
        
        self.likelihood = likelihood
    
    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

In [8]:
### define model
### TODO: check that our 2d reshaping is reasonable

#ikelihood = gpytorch.likelihoods.GaussianLikelihood()
model = DPClassificationModel(initial_model, (1, 28, 28), train_x = inputs.view(inputs.shape[0], 28*28), 
                              train_y=targets.long())
likelihood = model.likelihood

if torch.cuda.is_available():
    likelihood = likelihood.cuda()
    model = model.cuda()

In [9]:
%pdb

Automatic pdb calling has been turned ON


In [10]:
model.train_inputs[0].shape

torch.Size([128, 784])

In [11]:
model(inputs.view(inputs.shape[0], 28*28)).rsample()



tensor([[-0.3340, -0.3417, -0.3283,  ..., -0.3283, -0.3369, -0.3324],
        [ 0.8561,  0.8603,  0.8447,  ...,  0.8472,  0.8383,  0.8488],
        [-0.8619, -0.8622, -0.8483,  ..., -0.8499, -0.8484, -0.8495],
        ...,
        [ 0.1194,  0.1170,  0.1110,  ...,  0.1206,  0.1083,  0.1112],
        [-2.2980, -2.2827, -2.3017,  ..., -2.2921, -2.2944, -2.2938],
        [ 1.2592,  1.2544,  1.2709,  ...,  1.2519,  1.2423,  1.2565]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [12]:
from botorch.optim.fit import fit_gpytorch_torch

In [13]:
### define and fit any free parameters in the model

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
fit_gpytorch_torch(mll, options={"maxiter": 1000, "disp": True, "lr": 0.1})

Iter 10/1000: 26.3676700592041
Iter 20/1000: 24.937366485595703
Iter 30/1000: 24.218128204345703
Iter 40/1000: 23.868453979492188
Iter 50/1000: 23.693965911865234
Iter 60/1000: 23.605567932128906
Iter 70/1000: 23.56092071533203
Iter 80/1000: 23.53842544555664
Iter 90/1000: 23.526880264282227
Iter 100/1000: 23.520675659179688
Iter 110/1000: 23.51714324951172
Iter 120/1000: 23.51503562927246
Iter 130/1000: 23.513748168945312
Iter 140/1000: 23.51296615600586
Iter 150/1000: 23.512489318847656
Iter 160/1000: 23.512208938598633
Iter 170/1000: 23.512042999267578
Iter 180/1000: 23.511947631835938
Iter 190/1000: 23.511890411376953


(ExactMarginalLogLikelihood(
   (likelihood): FixedNoiseGaussianLikelihood(
     (noise_covar): FixedGaussianNoise()
     (second_noise_covar): HomoskedasticNoise(
       (raw_noise_constraint): GreaterThan(1.000E-04)
     )
   )
   (model): DPClassificationModel(
     (likelihood): FixedNoiseGaussianLikelihood(
       (noise_covar): FixedGaussianNoise()
       (second_noise_covar): HomoskedasticNoise(
         (raw_noise_constraint): GreaterThan(1.000E-04)
       )
     )
     (mean_module): ConstantMean()
     (covar_module): CNNGP_Kernel(
       (model): Sequential(
         (0): Conv2d()
         (1): ReLU()
         (2): Conv2d()
         (3): ReLU()
         (4): Conv2d()
         (5): ReLU()
         (6): Conv2d()
         (7): ReLU()
         (8): Conv2d()
         (9): ReLU()
         (10): Conv2d()
         (11): ReLU()
         (12): Conv2d()
         (13): ReLU()
         (14): Conv2d()
       )
     )
   )
 ),
 {'fopt': 23.511878967285156,
  'wall_time': 20.668388605117798

In [14]:
### set to test mode

model.eval()
likelihood.eval()

FixedNoiseGaussianLikelihood(
  (noise_covar): FixedGaussianNoise()
  (second_noise_covar): HomoskedasticNoise(
    (raw_noise_constraint): GreaterThan(1.000E-04)
  )
)

In [15]:
test_inputs.shape

torch.Size([128, 1, 28, 28])

In [16]:
inputs.shape

torch.Size([128, 1, 28, 28])

In [17]:
test_preds = model(test_inputs.view(test_inputs.shape[0], 28*28))

In [18]:
test_preds.mean

tensor([[-5.4599, -5.4597, -5.4597,  ..., -5.4596, -5.4595, -5.4596],
        [-5.3144, -5.3147, -5.3146,  ..., -5.3148, -5.3147, -5.3147],
        [-6.0704, -6.0704, -6.0704,  ..., -6.0703, -6.0704, -6.0704],
        ...,
        [-5.8353, -5.8354, -5.8353,  ..., -5.8353, -5.8353, -5.8353],
        [-6.1505, -6.1505, -6.1505,  ..., -6.1506, -6.1506, -6.1504],
        [-5.8352, -5.8353, -5.8352,  ..., -5.8353, -5.8353, -5.8353]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [30]:
test_preds.mean[:,0]

tensor([-5.4599, -5.3144, -6.0704, -5.8352, -5.8353, -6.3962, -5.8353, -5.8353,
        -6.1505, -5.8352], device='cuda:0', grad_fn=<SelectBackward>)

In [25]:
test_preds.mean.argmax(0)

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [26]:
test_targets

tensor([1., 6., 3., 4., 5., 9., 1., 3., 3., 8., 5., 4., 7., 7., 4., 2., 8., 5.,
        8., 6., 7., 3., 4., 6., 1., 9., 9., 6., 0., 3., 7., 2., 8., 2., 9., 4.,
        4., 6., 4., 9., 7., 0., 9., 2., 9., 5., 1., 5., 9., 1., 2., 3., 2., 3.,
        5., 9., 1., 7., 6., 2., 8., 2., 2., 5., 0., 7., 4., 9., 7., 8., 3., 2.,
        1., 1., 8., 3., 6., 1., 0., 3., 1., 0., 0., 1., 7., 2., 7., 3., 0., 4.,
        6., 5., 2., 6., 4., 7., 1., 8., 9., 9., 3., 0., 7., 1., 0., 2., 0., 3.,
        5., 4., 6., 5., 8., 6., 3., 7., 5., 8., 0., 9., 1., 0., 3., 1., 2., 2.,
        3., 3.], device='cuda:0')