In [1]:
import torch, torchvision, gpytorch
import torchvision.transforms as transforms
from convgp.model import ConvClassificationModel
from tqdm.notebook import tqdm
import numpy as np
from numpy.random import choice

## Load in the CIFAR10 data

In [2]:
# Mini-batch size shared by the train and test loaders below.
batch_size = 5

# The only preprocessing step is the conversion to tensors
# (no augmentation and no normalisation is configured here).
transform = transforms.Compose([transforms.ToTensor()])

# Training split: reshuffled by the loader, fetched by two worker processes.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Held-out split: iterated in a fixed order (shuffle=False).
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable names for the ten labels, indexed by class id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


## Creating the model, learning hyperparameters (and variational parameters)

In [3]:
img_shape = (32, 32)
patch_shape = (10, 10)
n_channels = 3
n_classes = 10
n_epochs = 1

# Initialise the inducing points from randomly chosen training images.
# The number of inducing points is tied to `batch_size` here.
# `replace=False` is deliberate: numpy's `choice` samples WITH replacement by
# default, which could pick the same image twice and yield two identical
# inducing points.
rand_idx = choice(len(trainset), batch_size, replace=False)
inducing_points = torch.cat([trainset[i][0][None, :] for i in rand_idx]).cuda()
model = ConvClassificationModel(inducing_points, patch_shape, n_classes).cuda()

likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=n_classes, num_classes=n_classes).cuda()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(trainset), beta=0.2)

# Hand every trainable tensor to the optimiser exactly once.
# NOTE(review): if `covar_module` is a registered submodule of `model`
# (presumably it is), its parameters already appear in `model.parameters()`;
# listing them again makes PyTorch warn about duplicate parameters. Keying on
# id() keeps the first occurrence of each tensor and still picks up any
# covar_module parameter that `model.parameters()` does not report.
trainable_params = list({
    id(param): param
    for module in (model, model.covar_module, likelihood)
    for param in module.parameters()
}.values())
optim = torch.optim.Adam(trainable_params)

In [4]:
# Put both modules into training mode before optimising the ELBO.
model.train()
likelihood.train()

with gpytorch.settings.lazily_evaluate_kernels(False):
    for epoch in range(n_epochs):
        # Build a fresh progress bar per epoch: a tqdm bar closes itself once
        # its iterable is exhausted, so re-iterating a single bar would leave
        # every epoch after the first without progress reporting.
        train_iter = tqdm(trainloader, desc=f"epoch {epoch + 1}/{n_epochs}")
        for x, y in train_iter:
            x = x.cuda(); y = y.cuda()
            optim.zero_grad()
            output = model(x)
            # `mll` is the variational ELBO (to be maximised), so minimise
            # its negation.
            loss = -mll(output, y)
            loss.backward()
            train_iter.set_postfix(loss=loss.item())
            optim.step()
            # Drop the per-batch references and release cached GPU blocks
            # before the next batch is moved to the device.
            del x, y, output, loss
            torch.cuda.empty_cache()

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [5]:
# Persist the trained model weights (same file name and format as before).
torch.save(model.state_dict(), 'model.pth')
# The likelihood was optimised jointly with the model, so save its state too;
# otherwise a reloaded model would be paired with an untrained likelihood.
torch.save(likelihood.state_dict(), 'likelihood.pth')

In [6]:
# Switch both modules to evaluation mode for prediction.
model.eval()
likelihood.eval()

correct = 0   # number of correctly classified test images so far
n_seen = 0    # number of test images evaluated so far
test_iter = tqdm(testloader)
with torch.no_grad(), gpytorch.settings.lazily_evaluate_kernels(False), gpytorch.settings.num_likelihood_samples(20):
    for x, y in test_iter:
        x = x.cuda(); y = y.cuda()
        pred = likelihood(model(x))
        # Average the class probabilities over dim 0 (presumably the
        # likelihood-sample dimension -- TODO confirm), then take the most
        # probable class per image.
        pred = pred.probs.mean(0).argmax(-1)
        correct += int(pred.eq(y.view_as(pred)).cpu().sum())
        # Count the samples actually present in this batch; assuming every
        # batch holds `batch_size` items is wrong when the last one is short.
        n_seen += y.numel()
        test_iter.set_postfix(acc=(correct / n_seen))

# Keep the final accuracy available to later cells (NaN if nothing was evaluated).
test_accuracy = correct / n_seen if n_seen else float('nan')

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


