In [1]:
import sys
sys.path.insert(0, '..')

In [2]:
%load_ext autoreload
%autoreload 2
import vardl
import torch
import torch.nn as nn
import sklearn.datasets
import numpy as np

In [3]:
layer = vardl.layers.BayesianLinear(in_features=10, out_features=1, local_reparameterization=True, nmc_test=1, nmc_train=1)

In [4]:
layer

BayesianLinear(
  in_features=10, out_features=1, bias=False, local_repr=factorized
  (prior_W): MatrixGaussianDistribution()
  (q_posterior_W): MatrixGaussianDistribution()
)

In [5]:
arch = nn.Sequential(layer)

In [6]:
arch

Sequential(
  (0): BayesianLinear(
    in_features=10, out_features=1, bias=False, local_repr=factorized
    (prior_W): MatrixGaussianDistribution()
    (q_posterior_W): MatrixGaussianDistribution()
  )
)

In [7]:
model = vardl.models.RegrBayesianNet(architecure=arch, 
                                     dtype=torch.float32)

In [8]:
model

RegrBayesianNet(
  (architecture): Sequential(
    (0): BayesianLinear(
      in_features=10, out_features=1, bias=False, local_repr=factorized
      (prior_W): MatrixGaussianDistribution()
      (q_posterior_W): MatrixGaussianDistribution()
    )
  )
  (likelihood): Gaussian()
)

In [9]:
for name, par in model.named_parameters():
    print(name, par.requires_grad)

architecture.0.prior_W._mean False
architecture.0.prior_W._logvars False
architecture.0.q_posterior_W._mean True
architecture.0.q_posterior_W._logvars True
likelihood.log_noise_var True


In [10]:
X, Y, W = sklearn.datasets.make_regression(n_samples=100000, 
                                 n_features=10, #100
                                 n_informative=5, 
                                 n_targets=1, bias=0,
                                 effective_rank=None,
                                 noise=np.exp(0),
                                 shuffle=False, coef=True, 
                                 random_state=0)

X = torch.from_numpy(X).float()
Y = torch.from_numpy(Y.reshape(-1, 1)).float()

In [11]:
from torch.utils.data import DataLoader, TensorDataset
dataset = TensorDataset(X, Y)

dataloader = DataLoader(dataset, batch_size=256, shuffle=True, 
                              drop_last=False, num_workers=0)

In [12]:
trainer = vardl.trainer.TrainerRegressor(model=model, 
                                         train_dataloader=dataloader, 
                                         test_dataloader=dataloader, 
                                         optimizer='Adam', 
                                         optimizer_config={'lr':0.1}, 
                                         device='cpu', 
                                         seed=0)

In [13]:
for _ in range(1000):
    trainer.model.likelihood.log_noise_var.requires_grad = False
    trainer.train_batch(*next(iter(dataloader)))

[1m[34mTrain[0m || iter=  100   loss=6376268288  error=131.37  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  200   loss=4611614208  error=111.72  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  300   loss=3406280704  error=96.02  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  400   loss=2329985280  error=79.41  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  500   loss=1752567680  error=68.87  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  600   loss= 949006272  error=50.68  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  700   loss= 615964544  error=40.83  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  800   loss= 444020640  error=34.67  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter=  900   loss= 298856448  error=28.44  log_theta_noise_var=-2.00
[1m[34mTrain[0m || iter= 1000   loss= 178509152  error=21.98  log_theta_noise_var=-2.00


In [33]:
model.architecture[0].q_posterior_W.mean

Parameter containing:
tensor([[16.9489],
        [16.5895],
        [10.0688],
        [18.3010],
        [16.0238],
        [ 0.0727],
        [-0.2205],
        [ 0.0074],
        [ 0.1137],
        [ 0.0054]], requires_grad=True)

In [14]:
W

array([81.17220954, 64.40445972, 23.64951166, 81.73135271, 62.05559643,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ])