In [20]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.nn.utils import parameters_to_vector
from torch.optim import Adam
from preds.models import MLPS
from preds.likelihoods import GaussianLh
from preds.datasets import UCIClassificationDatasets
from preds.laplace import Laplace

In [21]:
width = 25
depth = 3
prior_prec = 0.1
lr = 5e-3
n_epochs = 1000
lh = GaussianLh(sigma_noise=0.3)  # fixed to true noise
uci_dataset = 'glass'
root_dir = '/Users/tamire1/Documents/GitHub/BNN-predictions/data/'
device = 'cpu'

Load training data

In [22]:
ds_train = UCIClassificationDatasets(train=True, data_set=uci_dataset, split_train_size=0.95, double=False, root=root_dir)
X_train, y_train = ds_train.data.to(device), ds_train.targets.to(device).unsqueeze(1)
train_loader = [(X_train, y_train)]  # mock

Load validation data

In [23]:
ds_val = UCIClassificationDatasets(train=False,valid=True, data_set=uci_dataset, split_train_size=0.95, double=False, root=root_dir)
X_val, y_val = ds_val.data.to(device), ds_val.targets.to(device).unsqueeze(1)
val_loader = [(X_val, y_val)] 

In [24]:
model = MLPS(X_train.shape[1], [width]*depth, 1, activation='tanh', flatten=False).to(device)
optim = Adam(model.parameters(), lr=lr)
losses = list()
for i in range(n_epochs):
    f = model(X_train)
    w = parameters_to_vector(model.parameters())
    reg = 0.5 * prior_prec * w @ w
    loss = - lh.log_likelihood(y_train, f) + reg
    loss.backward()
    optim.step()
    losses.append(loss.item())
    model.zero_grad()
# optionally: plt.plot(losses) to monitor convergence

In [25]:
lap = Laplace(model, float(prior_prec), lh)


def get_pred_for(x, model_type='glm', cov_type='full'):
    #### INFERENCE (Posterior approximation) ####
    lap.infer(train_loader, cov_type=cov_type, dampen_kron=model_type=='bnn')
    if model_type == 'glm':
        #### GLM PREDICTIVE ####
        mu, var = lap.predictive_samples_glm(x, n_samples=1000)
    elif model_type == 'bnn':
        #### BNN PREDICTIVE ####
        samples = lap.predictive_samples_bnn(x, n_samples=1000)
        mu, var = samples.mean(dim=0), samples.var(dim=0)
    else:
        raise ValueError('unsupported model_type.')
    mu = mu.detach().cpu().squeeze().numpy()
    var = var.detach().cpu().squeeze().numpy()
    return mu, var

In [30]:
print(X_train.dtype)
print(X_val.dtype)
print()

torch.float32

In [26]:
# GLM

mu, var = get_pred_for(X_val, 'glm', 'full')
mu, var = get_pred_for(X_val, 'glm', 'kron')
mu, var = get_pred_for(X_val, 'glm', 'diag')

RuntimeError: Found dtype Long but expected Float
Exception raised from compute_types at /Users/distiller/project/conda/conda-bld/pytorch_1595629430416/work/aten/src/ATen/native/TensorIterator.cpp:183 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >) + 169 (0x10f2f9199 in libc10.dylib)
frame #1: at::TensorIterator::compute_types(at::TensorIteratorConfig const&) + 3842 (0x134695312 in libtorch_cpu.dylib)
frame #2: at::TensorIterator::build(at::TensorIteratorConfig&) + 618 (0x13469e51a in libtorch_cpu.dylib)
frame #3: at::TensorIterator::TensorIterator(at::TensorIteratorConfig&) + 223 (0x13469e1ff in libtorch_cpu.dylib)
frame #4: at::native::mse_loss_backward_out(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 410 (0x1344e9f7a in libtorch_cpu.dylib)
frame #5: at::CPUType::mse_loss_backward_out_grad_input(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 9 (0x134907fe9 in libtorch_cpu.dylib)
frame #6: at::mse_loss_backward_out(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 157 (0x1349bb44d in libtorch_cpu.dylib)
frame #7: at::native::mse_loss_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 118 (0x1344e9cf6 in libtorch_cpu.dylib)
frame #8: at::CPUType::mse_loss_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 14 (0x134907ffe in libtorch_cpu.dylib)
frame #9: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&, at::Tensor const&, long long), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, long long> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, long long)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 27 (0x1342b987b in libtorch_cpu.dylib)
frame #10: at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, long long)> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) const + 287 (0x1349ea13f in libtorch_cpu.dylib)
frame #11: at::mse_loss_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 157 (0x1349bb51d in libtorch_cpu.dylib)
frame #12: torch::autograd::VariableType::mse_loss_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 900 (0x1369869f4 in libtorch_cpu.dylib)
frame #13: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&, at::Tensor const&, long long), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, long long> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, long long)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 27 (0x1342b987b in libtorch_cpu.dylib)
frame #14: at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, long long)> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) const + 287 (0x1349ea13f in libtorch_cpu.dylib)
frame #15: at::mse_loss_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, long long) + 157 (0x1349bb51d in libtorch_cpu.dylib)
frame #16: torch::autograd::generated::MseLossBackward::apply(std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> >&&) + 298 (0x136811dca in libtorch_cpu.dylib)
frame #17: torch::autograd::Node::operator()(std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> >&&) + 742 (0x136f6b636 in libtorch_cpu.dylib)
frame #18: torch::autograd::Engine::evaluate_function(std::__1::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::__1::shared_ptr<torch::autograd::ReadyQueue> const&) + 1661 (0x136f6063d in libtorch_cpu.dylib)
frame #19: torch::autograd::Engine::thread_main(std::__1::shared_ptr<torch::autograd::GraphTask> const&) + 764 (0x136f5f66c in libtorch_cpu.dylib)
frame #20: torch::autograd::Engine::execute_with_graph_task(std::__1::shared_ptr<torch::autograd::GraphTask> const&, std::__1::shared_ptr<torch::autograd::Node>) + 1023 (0x136f696cf in libtorch_cpu.dylib)
frame #21: torch::autograd::python::PythonEngine::execute_with_graph_task(std::__1::shared_ptr<torch::autograd::GraphTask> const&, std::__1::shared_ptr<torch::autograd::Node>) + 53 (0x11536b145 in libtorch_python.dylib)
frame #22: torch::autograd::Engine::execute(std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&, std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> > const&, bool, bool, std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&) + 662 (0x136f67d36 in libtorch_cpu.dylib)
frame #23: torch::autograd::python::PythonEngine::execute(std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&, std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> > const&, bool, bool, std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&) + 82 (0x11536af42 in libtorch_python.dylib)
frame #24: THPEngine_run_backward(THPEngine*, _object*, _object*) + 2174 (0x11536baee in libtorch_python.dylib)
frame #25: _PyMethodDef_RawFastCallKeywords + 642 (0x1049b2d42 in python3.7)
frame #26: call_function + 257 (0x104af32b1 in python3.7)
frame #27: _PyEval_EvalFrameDefault + 45890 (0x104af1052 in python3.7)
frame #28: _PyEval_EvalCodeWithName + 418 (0x104ae4a42 in python3.7)
frame #29: _PyFunction_FastCallKeywords + 195 (0x1049b2a73 in python3.7)
frame #30: call_function + 181 (0x104af3265 in python3.7)
frame #31: _PyEval_EvalFrameDefault + 45215 (0x104af0daf in python3.7)
frame #32: _PyEval_EvalCodeWithName + 418 (0x104ae4a42 in python3.7)
frame #33: _PyFunction_FastCallKeywords + 195 (0x1049b2a73 in python3.7)
frame #34: call_function + 181 (0x104af3265 in python3.7)
frame #35: _PyEval_EvalFrameDefault + 45065 (0x104af0d19 in python3.7)
frame #36: _PyEval_EvalCodeWithName + 418 (0x104ae4a42 in python3.7)
frame #37: _PyFunction_FastCallKeywords + 195 (0x1049b2a73 in python3.7)
frame #38: call_function + 181 (0x104af3265 in python3.7)
frame #39: _PyEval_EvalFrameDefault + 45890 (0x104af1052 in python3.7)
frame #40: _PyEval_EvalCodeWithName + 418 (0x104ae4a42 in python3.7)
frame #41: _PyFunction_FastCallKeywords + 195 (0x1049b2a73 in python3.7)
frame #42: call_function + 181 (0x104af3265 in python3.7)
frame #43: _PyEval_EvalFrameDefault + 45705 (0x104af0f99 in python3.7)
frame #44: _PyEval_EvalCodeWithName + 418 (0x104ae4a42 in python3.7)
frame #45: builtin_exec + 347 (0x104adf83b in python3.7)
frame #46: _PyMethodDef_RawFastCallKeywords + 230 (0x1049b2ba6 in python3.7)
frame #47: call_function + 257 (0x104af32b1 in python3.7)
frame #48: _PyEval_EvalFrameDefault + 45705 (0x104af0f99 in python3.7)
frame #49: gen_send_ex + 180 (0x1049cbc04 in python3.7)
frame #50: _PyEval_EvalFrameDefault + 25228 (0x104aebf9c in python3.7)
frame #51: gen_send_ex + 180 (0x1049cbc04 in python3.7)
frame #52: _PyEval_EvalFrameDefault + 25228 (0x104aebf9c in python3.7)
frame #53: gen_send_ex + 180 (0x1049cbc04 in python3.7)
frame #54: _PyMethodDef_RawFastCallKeywords + 131 (0x1049b2b43 in python3.7)
frame #55: _PyMethodDescr_FastCallKeywords + 84 (0x1049bfae4 in python3.7)
frame #56: call_function + 382 (0x104af332e in python3.7)
frame #57: _PyEval_EvalFrameDefault + 45065 (0x104af0d19 in python3.7)
frame #58: function_code_fastcall + 120 (0x1049b2368 in python3.7)
frame #59: call_function + 181 (0x104af3265 in python3.7)
frame #60: _PyEval_EvalFrameDefault + 45705 (0x104af0f99 in python3.7)
frame #61: function_code_fastcall + 120 (0x1049b2368 in python3.7)
frame #62: call_function + 181 (0x104af3265 in python3.7)
frame #63: _PyEval_EvalFrameDefault + 45065 (0x104af0d19 in python3.7)
