In [None]:
!pip install tensorflow==2.0
!pip install tensorflow_probability

# In this tutorial, we will implement a simple linear regression model with PAC Bayesian SGD method

In [1]:
import numpy as np
np.random.seed(41)

In [2]:
# configs
n_sample = 100  # number of samples in training set
dim = 5  # dimension of feature vector for each sample

# Section 1: Data preparation
## generate $X$: $n$ samples of dimension $d$ ~ $\mathcal{N}(0, diag[\sigma_1^2, \sigma_2^2, ..., \sigma_d^2$])

In [3]:
mean_x = np.zeros(dim, dtype=np.float32)
diag_x = np.zeros([dim, dim], dtype=np.float32)
for i, sigma in enumerate(np.sort(np.random.rand(dim))):
    diag_x[i, i] = sigma ** 2
print(diag_x)

[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]


In [4]:
# generate X
X = np.random.multivariate_normal(mean_x, diag_x, n_sample)
print(X.shape)
print(X[0:2])

(100, 5)
[[-0.07773376  0.01118072  0.06995384  0.10453022 -0.83898637]
 [-0.08650132 -0.00944478 -0.06890632 -0.2147365  -0.58603212]]


## generate $w^*$ ~ $P_{true}$, where $P_{true} := \mathcal{N}(0, \lambda I)$

In [5]:
# ground truth weight distribution P:
diag_lambda = 5 # ground truth for P diags
diag = np.zeros([dim, dim], dtype=np.float32)
for i in range(len(diag)):
    diag[i, i] = diag_lambda
print(diag)
# MSE weight
w_star = np.random.multivariate_normal([0]*dim, diag, 1)[0]
print(w_star)

[[5. 0. 0. 0. 0.]
 [0. 5. 0. 0. 0.]
 [0. 0. 5. 0. 0.]
 [0. 0. 0. 5. 0.]
 [0. 0. 0. 0. 5.]]
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]


## generate y: $n$ labels ~ $X^T \cdot w^* + \epsilon$. with $\epsilon$  ~ $\mathcal{N}(0, I)$

In [6]:
# noise add to y
epsilon = np.random.normal(0, 1, n_sample)

In [7]:
# noisy labels
y = np.dot(X, w_star) + epsilon
print(y[0:10])

[ 1.72517815 -1.86011148 -1.91710548 -1.39267423 -0.28523969 -1.55699121
  2.74346011 -4.57253723 -7.19923035 -4.08182988]


In [8]:
print(X.shape)
print(y.shape)
print(w_star)

# groundtruth parameter for P
print(diag_lambda)
# groundtruth parameter for Q
print(diag_x)

(100, 5)
(100,)
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]
5
[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]


# Section 2. Define the model

In [9]:

import math
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
import itertools

## Below is a simple implementation using equation just above section 3.1 in the $\href{https://arxiv.org/abs/1703.11008}{paper}$ used in the lecture 

In [18]:
class Linear_Simple:
    def __init__(self, n_samples, dimension, delta, learning_rate=0.001, momentum=0.9):
        self.d = dimension
        self.m = n_samples
        self.delta = delta # delta represents belief
        self.optimizer = tf.optimizers.SGD(learning_rate=learning_rate, momentum=momentum)
        self._build_model()
        
        
    def _build_model(self):
        self.weight = tf.Variable(dtype=tf.float32, name="weights", shape=self.d, initial_value=tf.zeros(shape=self.d), trainable=True)
        self.bias = tf.Variable(dtype=tf.float32, name="bias", shape=1, initial_value=tf.zeros(shape=1), trainable=True)
        # define learnable distribution Q
        self.s = tf.Variable(dtype=tf.float32, name='learnable_diag', shape=self.d, initial_value=tf.ones(shape=self.d), trainable=True)
        self.Q = tfd.MultivariateNormalDiag(loc=self.weight, scale_diag=self.s*self.s)
        # define P as fixed distribution
        self.lamda = [3.0]
        self.P = tfd.MultivariateNormalDiag(loc=tf.zeros(self.d), scale_diag=tf.tile(self.lamda, [self.d]))
        self.trainable_variables = [self.weight, self.bias, self.s]
        
    def _sample_si(self):
        return np.random.multivariate_normal([.0]*self.d, np.identity(self.d, dtype=np.float32), 1)[0]
        
    def compute_loss(self, predictions, labels):
        empirical_loss = tf.nn.l2_loss(predictions-labels)
        KL_divergence = tfd.kl_divergence(distribution_a=self.Q, distribution_b=self.P) # compute KL(Q||P)
        RE_loss = (KL_divergence + tf.math.log(self.m/self.delta)) / (2*self.m - 2)
        loss = empirical_loss + tf.math.sqrt(RE_loss)
        return empirical_loss, loss
        
    def predict(self, inputs):
        # apply sgd described in paper: section 3.2
        new_weight = self.weight + tf.multiply(self._sample_si(), self.s)
        self.scores = tf.reduce_sum(tf.multiply(inputs, new_weight)) + self.bias
        return self.scores
    
    def train(self, dataset, epoch=3, print_step_loss=False):
        its = itertools.tee(dataset, epoch)
        for e in range(epoch):
            _loss = 0
            _empirical_loss = 0
            for x, y in its[e]:
                with tf.GradientTape() as tape:
                    scores = self.predict(x)
                    empi_loss, loss = self.compute_loss(scores, y)
                    _loss += loss
                    _empirical_loss += empi_loss
                    if print_step_loss:
                        print(loss.numpy())
                gradients = tape.gradient(loss, self.trainable_variables)
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            print("epoch %d ----total loss: %f ---- empirical loss: %f ----" % (e, _loss/self.m, _empirical_loss/self.m))
            self._print_stats()
            
    def _print_stats(self):
        print("diag s: " + str((self.s * self.s).numpy()))
        

In [19]:
model = Linear_Simple(n_samples=n_sample, dimension=dim, delta=0.025)
model.train(zip(X, y), epoch=100, print_step_loss=False)

epoch 0 ----total loss: 3.655353 ---- empirical loss: 3.413587 ----
diag s: [0.99623805 0.9535061  0.91742533 0.9207604  0.6110957 ]
epoch 1 ----total loss: 3.276746 ---- empirical loss: 3.034504 ----
diag s: [1.0243871 0.9325476 0.87739   1.0357976 0.1197383]
epoch 2 ----total loss: 2.882580 ---- empirical loss: 2.639571 ----
diag s: [0.99349666 0.9014705  0.8696092  0.7952672  0.25136375]
epoch 3 ----total loss: 2.609730 ---- empirical loss: 2.365815 ----
diag s: [0.9644141  0.9125296  0.8947794  0.81660473 0.08865781]
epoch 4 ----total loss: 2.438571 ---- empirical loss: 2.193716 ----
diag s: [0.9592499  0.8971777  0.8792443  0.82848495 0.06263303]
epoch 5 ----total loss: 2.160573 ---- empirical loss: 1.914661 ----
diag s: [0.958031   0.9159454  0.8719184  0.9949781  0.00781012]
epoch 6 ----total loss: 2.104427 ---- empirical loss: 1.857488 ----
diag s: [0.94934094 0.9320908  0.85064715 1.0609933  0.09336017]
epoch 7 ----total loss: 1.986336 ---- empirical loss: 1.738313 ----
diag s

epoch 60 ----total loss: 0.906198 ---- empirical loss: 0.618124 ----
diag s: [7.7021623e-01 7.5025135e-01 2.5929669e-01 5.2441974e-05 1.0213447e-02]
epoch 61 ----total loss: 0.911700 ---- empirical loss: 0.623407 ----
diag s: [7.5742531e-01 7.5197506e-01 2.5082108e-01 5.4338252e-06 1.6061610e-02]
epoch 62 ----total loss: 0.926239 ---- empirical loss: 0.637741 ----
diag s: [7.3811185e-01 7.5102127e-01 2.3607901e-01 6.5578548e-05 7.4039580e-04]
epoch 63 ----total loss: 0.901651 ---- empirical loss: 0.612938 ----
diag s: [7.3107642e-01 7.6906186e-01 2.2763495e-01 4.1164723e-04 2.7640630e-04]
epoch 64 ----total loss: 0.908737 ---- empirical loss: 0.619844 ----
diag s: [7.2637564e-01 7.6782155e-01 2.1154721e-01 5.8265770e-04 3.6008372e-05]
epoch 65 ----total loss: 0.911898 ---- empirical loss: 0.622842 ----
diag s: [0.7209343  0.76959854 0.1909479  0.00142279 0.00216507]
epoch 66 ----total loss: 0.906938 ---- empirical loss: 0.617722 ----
diag s: [7.2246236e-01 7.6516193e-01 1.7722863e-01 1

In [20]:
print("groundtruth data generation diag: ")
print(diag_x)
print("learned s: ")
print((model.s **2) .numpy())
print("ground_truth weights: ")
print(w_star)
print("learned weights: ")
print(model.weight.numpy())
print("ground truth Lambda for P: ")
print(diag_lambda)
print("pre-fixed Lambda for P: ")
print(model.lamda[0])

groundtruth data generation diag: 
[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]
learned s: 
[0.6234379  0.7558212  0.10575011 0.00294373 0.00657509]
ground_truth weights: 
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]
learned weights: 
[-0.00957734  0.40996817 -0.7989632   9.566933   -1.80343   ]
ground truth Lambda for P: 
5
pre-fixed Lambda for P: 
3.0


## Below is the implementation of eq(4) in original paper. where lambda is also a learnable parameter

In [22]:
class Linear_Paper(Linear_Simple):
    def __init__(self, n_samples, dimension, b, c, delta, learning_rate=0.001, momentum=0.9):
        self.b = b
        self.c = c
        super().__init__(n_samples, dimension, delta, learning_rate, momentum)
        
    def _build_model(self):
        super()._build_model()
        # re-define prior distribution P    
        self.lamda = tf.Variable(dtype=tf.float32, name="p_lambda", shape=1, initial_value=[0.01], trainable=True) # diagnals of prior distribution P
        self.P = tfd.MultivariateNormalDiag(loc=tf.zeros(self.d), scale_diag=tf.tile(self.lamda, [self.d]))
        self.trainable_variables.append(self.lamda)
                
    def compute_loss(self, predictions, labels):
        empirical_loss = tf.nn.l2_loss(predictions-labels)
        KL_divergence = tfd.kl_divergence(distribution_a=self.Q, distribution_b=self.P) # compute KL(Q||P)
        RE_loss = (KL_divergence + \
                   2*tf.math.log(tf.clip_by_value(self.b * tf.math.log(self.c/self.lamda), 1e-5, 1e30)) + \
                   tf.math.log(math.pi ** 2 * self.m/(6*self.delta))) / (self.m-1)
        loss = empirical_loss + tf.math.sqrt(tf.clip_by_value(RE_loss, 0, 1e30)/2)
        return empirical_loss, loss
            
    def _print_stats(self):
        super()._print_stats()
        print("lambda: %f" % self.lamda.numpy())


In [23]:
model = Linear_Paper(n_samples=n_sample, dimension=dim, b=100, c=0.1, delta=0.025, learning_rate=0.001)
model.train(zip(X, y), epoch=100)

epoch 0 ----total loss: 14.462093 ---- empirical loss: 3.173203 ----
diag s: [1.0243852  0.96995956 1.0423746  1.0582445  1.0367427 ]
lambda: 0.024215
epoch 1 ----total loss: 15.120686 ---- empirical loss: 3.761306 ----
diag s: [1.0448482  0.97087973 1.0614083  0.7690095  0.42701343]
lambda: 0.036674
epoch 2 ----total loss: 14.661912 ---- empirical loss: 3.278697 ----
diag s: [1.0522382  1.0089128  1.0471274  0.55457675 0.24691749]
lambda: 0.048890
epoch 3 ----total loss: 14.879770 ---- empirical loss: 3.510389 ----
diag s: [1.0590322  0.9898903  1.0322747  0.39413747 0.02116182]
lambda: 0.062361
epoch 4 ----total loss: 14.709911 ---- empirical loss: 3.334402 ----
diag s: [1.0211291  0.96767324 0.9804744  0.44459367 0.00161812]
lambda: 0.080127
epoch 5 ----total loss: 14.814976 ---- empirical loss: 3.443497 ----
diag s: [1.0082163e+00 9.5659113e-01 8.5268956e-01 3.4474167e-01 1.7074196e-04]
lambda: 0.111891
epoch 6 ----total loss: 14.703222 ---- empirical loss: 3.338418 ----
diag s: [1

epoch 53 ----total loss: 14.689903 ---- empirical loss: 3.323797 ----
diag s: [0.79361385 0.7716567  0.27124175 0.34702963 0.09714367]
lambda: 0.111965
epoch 54 ----total loss: 14.658734 ---- empirical loss: 3.301307 ----
diag s: [0.7838603  0.77097636 0.27328548 0.33627555 0.18235673]
lambda: 0.111965
epoch 55 ----total loss: 14.718010 ---- empirical loss: 3.349350 ----
diag s: [0.7560097  0.78009814 0.27347857 0.38088134 0.00342036]
lambda: 0.111965
epoch 56 ----total loss: 14.644016 ---- empirical loss: 3.282423 ----
diag s: [0.73358405 0.75605416 0.3025157  0.30455548 0.14970613]
lambda: 0.111965
epoch 57 ----total loss: 14.750640 ---- empirical loss: 3.389775 ----
diag s: [0.7467659  0.73872507 0.29301426 0.28415388 0.06428477]
lambda: 0.111965
epoch 58 ----total loss: 14.609859 ---- empirical loss: 3.245651 ----
diag s: [0.8003574  0.7510757  0.30130693 0.40499204 0.00636972]
lambda: 0.111965
epoch 59 ----total loss: 14.726833 ---- empirical loss: 3.356665 ----
diag s: [0.8034949

In [24]:
print("groundtruth data generation diag: ")
print(diag_x)
print("learned s: ")
print((model.s **2) .numpy())
print("ground_truth weights: ")
print(w_star)
print("learned weights: ")
print(model.weight.numpy())
print("ground truth Lambda for P: ")
print(diag_lambda)
print("learned Lambda for P: ")
print(model.lamda.numpy())

groundtruth data generation diag: 
[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]
learned s: 
[0.47962904 0.48603502 0.06361027 0.3860183  0.00095246]
ground_truth weights: 
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]
learned weights: 
[ 0.00241566  0.00484082 -0.01352639  0.20059681 -0.25651398]
ground truth Lambda for P: 
5
learned Lambda for P: 
[0.11196508]
