In [11]:
import torch 
import gpytorch

import sys
sys.path.append('..')
import data
import utils
import models
import swag

import copy

In [12]:
from gpytorch.lazy import RootLazyTensor, DiagLazyTensor, AddedDiagLazyTensor
from gpytorch.distributions import MultivariateNormal


In [13]:
def flatten(lst):
    tmp = [i.contiguous().view(-1,1) for i in lst]
    return torch.cat(tmp).view(-1)

In [14]:
model_cfg = getattr(models, 'VGG16')

loaders, num_classes = data.loaders(
    'CIFAR10',
    '../../../../Documents/datasets/',
    128,
    4,
    model_cfg.transform_train,
    model_cfg.transform_test,
    use_validation=not True,
    split_classes=None
)

Files already downloaded and verified
You are going to run models on the test set. Are you sure?
Files already downloaded and verified


In [15]:
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model = model.cuda()

swag_model = swag.SWAG(model_cfg.base, no_cov_mat=False, max_num_models=20, loading = True, 
                       *model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
swag_model = swag_model.cuda()

RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /opt/conda/conda-bld/pytorch-nightly_1540809146283/work/aten/src/THC/generic/THCTensorCopy.cpp:20

In [6]:
checkpoint = torch.load('../../../nfs01_tesla/swa_uncertainties/exps/vgg16_cifar10_0618_1/swag-300.pt')

swag_model.load_state_dict(checkpoint['state_dict'])

Now that I've loaded the model I'll go ahead and define the gpytorch structs necessary to compute LLs.

In [7]:
mean_list = []
var_list = []
cov_mat_root_list = []
for module, name in swag_model.params:
    mean = module.__getattr__('%s_mean' % name)
    sq_mean = module.__getattr__('%s_sq_mean' % name)
    cov_mat_sqrt = module.__getattr__('%s_cov_mat_sqrt' % name)
    
    mean_list.append(copy.deepcopy(mean))
    var_list.append(copy.deepcopy(sq_mean - mean ** 2.0))
    cov_mat_root_list.append(copy.deepcopy(cov_mat_sqrt))

In [26]:
cov_mat_root = torch.cat(cov_mat_root_list,dim=1)
mean_vector = flatten(mean_list)
var_vector = flatten(var_list)

In [27]:
cov_mat_root = cov_mat_root[:, :5000]
mean_vector = mean_vector[:5000]
var_vector = var_vector[:5000]

In [28]:
cov_mat_lt = RootLazyTensor(cov_mat_root.t())
var_lt = DiagLazyTensor(var_vector)
covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)

In [29]:
print(cov_mat_lt.size(), var_lt.size(), covar_lt.size())

torch.Size([5000, 5000]) (5000, 5000) (5000, 5000)


In [30]:
#swag_model.sample(scale=1.0, cov=True, seed=1107)
#param_list = [getattr(param, name) for param, name in swag_model.params]
#param_vector = flatten(param_list)
#param_vector = param_vector[:10000]
param_vector = torch.randn(5000).cuda()

In [57]:
with gpytorch.settings.num_trace_samples(1) and gpytorch.settings.max_preconditioner_size(20):
    qdist = MultivariateNormal(mean_vector, covar_lt)
    print(qdist.log_prob(param_vector))

tensor(nan, device='cuda:0')


In [32]:
diff = param_vector - mean_vector
exp_part = -0.5 * diff.dot(covar_lt.inv_matmul(param_vector))
print(exp_part)

tensor(-1.2573e+08, device='cuda:0')


In [58]:
covar_lt.log_det()

RuntimeError: Lapack Error syev : 14 off-diagonal elements didn't converge to zero at /opt/conda/conda-bld/pytorch-nightly_1540809146283/work/aten/src/TH/generic/THTensorLapack.cpp:395

In [34]:
probe_vectors = torch.empty(covar_lt.size(0), 50).cuda()
probe_vectors.bernoulli_().mul_(2).add_(-1)
solves, t_mat = gpytorch.utils.linear_cg(covar_lt.matmul, rhs = probe_vectors, 
                                         n_tridiag = 50, 
                                         max_tridiag_iter = 20)

In [35]:
covar_lt_true = covar_lt.evaluate()

In [36]:
true_dist = torch.distributions.MultivariateNormal(mean_vector, covar_lt_true)
true_dist.log_prob(param_vector)

tensor(-1.2592e+08, device='cuda:0')

## Blockwise LL Calculation

In [8]:
swag_model.sample(scale=1.0, cov=True, seed=1107)
param_list = [getattr(param, name) for param, name in swag_model.params]

In [9]:
def compute_ll_for_block(vec, mean, var, cov_mat_root):
    vec = flatten(vec)
    mean = flatten(mean)
    var = flatten(var)
    print(vec.size(), mean.size(), var.size())
    
    cov_mat_lt = RootLazyTensor(cov_mat_root.t())
    var_lt = DiagLazyTensor(var)
    covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)
    qdist = MultivariateNormal(mean, covar_lt)
    with gpytorch.settings.max_preconditioner_size(20) and gpytorch.settings.max_cg_iterations(100):
        return qdist.log_prob(vec)
    

In [10]:
full_logprob = 0
for (param, mean, var, cov_mat_root) in zip(param_list, mean_list, var_list, cov_mat_root_list):
    block_ll = compute_ll_for_block(param, mean, var, cov_mat_root)
    print(block_ll)
    full_logprob += block_ll

torch.Size([1728]) torch.Size([1728]) torch.Size([1728])
tensor(5769.6831, device='cuda:0')
torch.Size([64]) torch.Size([64]) torch.Size([64])
tensor(156.4693, device='cuda:0')
torch.Size([36864]) torch.Size([36864]) torch.Size([36864])
tensor(nan, device='cuda:0')
torch.Size([64]) torch.Size([64]) torch.Size([64])
tensor(184.8898, device='cuda:0')
torch.Size([73728]) torch.Size([73728]) torch.Size([73728])
tensor(259847.6250, device='cuda:0')
torch.Size([128]) torch.Size([128]) torch.Size([128])
tensor(443.4670, device='cuda:0')
torch.Size([147456]) torch.Size([147456]) torch.Size([147456])


RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /opt/conda/conda-bld/pytorch-nightly_1540809146283/work/aten/src/THC/THCReduceAll.cuh:317