In [1]:
import os

import gpytorch
import torch

import cnn_gp

In [2]:
import torchvision

In [3]:
import sys
sys.path.append('..')
from kernels.cnn_kernel import CNNGP_Kernel

In [6]:
### load dataset
# Root directory for the MNIST files. It can be overridden with the MNIST_ROOT
# environment variable so the notebook is not tied to one machine's home
# directory; when the variable is unset the original location is used.
mnist_root = os.environ.get('MNIST_ROOT', '/home/wesley/Documents/datasets/')
mnist = torchvision.datasets.MNIST(mnist_root, download=True,
                                   transform=torchvision.transforms.ToTensor())

In [7]:
### construct training and test datasets (batched for now)

def _next_batch(batch_iter):
    """Fetch the next (images, labels) pair from ``batch_iter``.

    Labels are cast to float, and both tensors are moved to the GPU when
    CUDA is available.
    """
    images, labels = next(batch_iter)
    labels = labels.float()
    if torch.cuda.is_available():
        images, labels = images.cuda(), labels.cuda()
    return images, labels


iter_data_loader = iter(torch.utils.data.DataLoader(mnist, batch_size=128))

# first batch from the loader is used for training, the second one for testing
inputs, targets = _next_batch(iter_data_loader)
test_inputs, test_targets = _next_batch(iter_data_loader)

In [4]:
### define the cnn_gp network for mnist:
### seven (Conv2d -> ReLU) pairs followed by one final Conv2d
in_channels = 1
out_channels = 10

var_bias = 0.86
var_weight = 0.79

# every convolution in the network shares the same weight/bias variances
conv_hyperparams = dict(var_weight=var_weight, var_bias=var_bias)

layers = [
    layer
    for _ in range(7)  # n_layers
    for layer in (
        cnn_gp.Conv2d(kernel_size=7, padding="same", **conv_hyperparams),
        cnn_gp.ReLU(),
    )
]

# last layer: kernel_size=28 with no padding
output_layer = cnn_gp.Conv2d(kernel_size=28, padding=0, **conv_hyperparams)
initial_model = cnn_gp.Sequential(*layers, output_layer)

In [5]:
### define model

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a zero mean and a ``CNNGP_Kernel`` covariance."""

    def __init__(self, cnn_model, input_shape, train_x, train_y, likelihood):
        """Set up the mean and covariance modules.

        ``train_x``, ``train_y`` and ``likelihood`` are forwarded to
        ``gpytorch.models.ExactGP``; ``cnn_model`` and ``input_shape`` are
        forwarded to ``CNNGP_Kernel``.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = CNNGP_Kernel(cnn_model, input_shape)

    def forward(self, x):
        """Return the MultivariateNormal built from the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

In [8]:
### define model
### TODO: check that our 2d reshaping is reasonable

likelihood = gpytorch.likelihoods.GaussianLikelihood()

# The GP is given one flat row per image. Using -1 lets torch infer the row
# length from `inputs` itself instead of hard-coding 28*28, so the flattening
# stays in sync with the un-flattened per-example shape (inputs.shape[1:])
# that is passed alongside it.
model = ExactGPModel(
    initial_model,
    inputs.shape[1:],
    inputs.view(inputs.shape[0], -1),
    targets,
    likelihood,
)

if torch.cuda.is_available():
    likelihood = likelihood.cuda()
    model = model.cuda()

In [9]:
from botorch.optim.fit import fit_gpytorch_torch

In [10]:
### define and fit any free parameters in the model

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

# Explicitly put the likelihood and the model (both are submodules of mll) in
# training mode. On a fresh kernel they already are, but if this cell is re-run
# after the modules were switched to eval mode the fit would otherwise be run
# on modules that are in prediction mode.
mll.train()

# last expression of the cell: the (mll, info) result is echoed by the notebook
fit_gpytorch_torch(mll, options={"maxiter": 1000, "disp": True, "lr": 0.1})

Iter 10/1000: 4.816862106323242
Iter 20/1000: 3.7636613845825195
Iter 30/1000: 3.343903064727783
Iter 40/1000: 3.142294406890869
Iter 50/1000: 3.0279245376586914
Iter 60/1000: 2.9538207054138184
Iter 70/1000: 2.9008588790893555
Iter 80/1000: 2.8603687286376953
Iter 90/1000: 2.828003406524658
Iter 100/1000: 2.801354169845581
Iter 110/1000: 2.7789578437805176
Iter 120/1000: 2.7598509788513184
Iter 130/1000: 2.7433600425720215
Iter 140/1000: 2.728991985321045
Iter 150/1000: 2.716373920440674
Iter 160/1000: 2.705216884613037
Iter 170/1000: 2.6952929496765137
Iter 180/1000: 2.6864194869995117
Iter 190/1000: 2.6784472465515137
Iter 200/1000: 2.671255111694336
Iter 210/1000: 2.6647419929504395
Iter 220/1000: 2.658823013305664
Iter 230/1000: 2.6534271240234375
Iter 240/1000: 2.648494243621826
Iter 250/1000: 2.643972396850586
Iter 260/1000: 2.639817237854004
Iter 270/1000: 2.635991096496582
Iter 280/1000: 2.6324596405029297
Iter 290/1000: 2.6291942596435547
Iter 300/1000: 2.6261696815490723
Ite

(ExactMarginalLogLikelihood(
   (likelihood): GaussianLikelihood(
     (quadrature): GaussHermiteQuadrature1D()
     (noise_covar): HomoskedasticNoise(
       (raw_noise_constraint): GreaterThan(1.000E-04)
     )
   )
   (model): ExactGPModel(
     (likelihood): GaussianLikelihood(
       (quadrature): GaussHermiteQuadrature1D()
       (noise_covar): HomoskedasticNoise(
         (raw_noise_constraint): GreaterThan(1.000E-04)
       )
     )
     (mean_module): ZeroMean()
     (covar_module): CNNGP_Kernel(
       (model): Sequential(
         (0): Conv2d()
         (1): ReLU()
         (2): Conv2d()
         (3): ReLU()
         (4): Conv2d()
         (5): ReLU()
         (6): Conv2d()
         (7): ReLU()
         (8): Conv2d()
         (9): ReLU()
         (10): Conv2d()
         (11): ReLU()
         (12): Conv2d()
         (13): ReLU()
         (14): Conv2d()
       )
     )
   )
 ),
 {'fopt': 2.580012798309326,
  'wall_time': 108.10976529121399,
  'iterations': [OptimizationIterati

In [11]:
### set to test mode
# Calls .eval() on the GP model and on its likelihood before predicting.
# likelihood.eval() is the last expression of the cell, so its return value
# (the likelihood module) is what the notebook echoes as this cell's output.

model.eval()
likelihood.eval()

GaussianLikelihood(
  (quadrature): GaussHermiteQuadrature1D()
  (noise_covar): HomoskedasticNoise(
    (raw_noise_constraint): GreaterThan(1.000E-04)
  )
)

In [12]:
# Prediction only: run the model call under torch.no_grad() so no autograd
# graph is recorded for it, and flatten each test image with -1 (size inferred
# from the tensor) rather than a hard-coded 28*28.
with torch.no_grad():
    test_preds = model(test_inputs.view(test_inputs.shape[0], -1))

In [13]:
# mean of the distribution returned by model(...) for the held-out batch;
# compare against test_targets to judge the fit
test_preds.mean

tensor([3.8670, 3.8671, 3.8670, 3.8672, 3.8670, 3.8672, 3.8671, 3.8671, 3.8669,
        3.8671, 3.8669, 3.8672, 3.8671, 3.8672, 3.8672, 3.8671, 3.8670, 3.8670,
        3.8671, 3.8673, 3.8672, 3.8670, 3.8671, 3.8671, 3.8669, 3.8672, 3.8671,
        3.8671, 3.8669, 3.8671, 3.8670, 3.8671, 3.8671, 3.8669, 3.8672, 3.8672,
        3.8671, 3.8672, 3.8673, 3.8673, 3.8672, 3.8670, 3.8672, 3.8671, 3.8672,
        3.8670, 3.8670, 3.8670, 3.8672, 3.8670, 3.8669, 3.8670, 3.8670, 3.8671,
        3.8669, 3.8673, 3.8670, 3.8672, 3.8670, 3.8668, 3.8671, 3.8670, 3.8669,
        3.8670, 3.8669, 3.8672, 3.8672, 3.8672, 3.8671, 3.8671, 3.8670, 3.8671,
        3.8670, 3.8670, 3.8671, 3.8671, 3.8672, 3.8670, 3.8668, 3.8671, 3.8670,
        3.8667, 3.8669, 3.8670, 3.8671, 3.8671, 3.8671, 3.8670, 3.8670, 3.8672,
        3.8671, 3.8670, 3.8669, 3.8671, 3.8670, 3.8672, 3.8671, 3.8671, 3.8672,
        3.8673, 3.8670, 3.8668, 3.8672, 3.8670, 3.8670, 3.8671, 3.8669, 3.8670,
        3.8670, 3.8672, 3.8671, 3.8670, 

In [None]:
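### yikes :( -- the predicted means shown above are all ~3.867, i.e. essentially
### constant across test images, so the GP is not discriminating between inputs.
### TODO: inspect the kernel matrix from CNNGP_Kernel (is it close to constant?)
### and the fitted noise level before trusting any predictions.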