In [1]:
!pip install tensorflow==2.0
!pip install tensorflow_probability



In [2]:
# In this tutorial, we will implement a simple linear regression model 
# with PAC Bayesian SGD method
import numpy as np
np.random.seed(41)

In [3]:
# configs
n_sample = 100  # number of samples in training set
dim = 5  # dimension of feature vector for each sample

In [4]:
# Section 1: Data preparation
# generate X: n samples of dimension d ~ N(0, diag(sigma_1^2, sigma_2^2, ..., sigma_d^2))
mean_x = np.zeros(dim, dtype=np.float32)
diag_x = np.zeros([dim, dim], dtype=np.float32)
for i, sigma in enumerate(np.sort(np.random.rand(dim))):
    diag_x[i, i] = sigma ** 2
print(diag_x)

[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]


In [5]:
# generate X
X = np.random.multivariate_normal(mean_x, diag_x, n_sample)
print(X.shape)
print(X[0:2])

(100, 5)
[[-0.07773376  0.01118072  0.06995384  0.10453022 -0.83898637]
 [-0.08650132 -0.00944478 -0.06890632 -0.2147365  -0.58603212]]


In [6]:
# generate y: n samples of dimension 1 ~ X^T . w* + epsilon. with epsilon ~ N(0, I)
# ground truth weight distribution P:
diag_lambda = 5 # ground truth for P diags
diag = np.zeros([dim, dim], dtype=np.float32)
for i in range(len(diag)):
    diag[i, i] = diag_lambda
print(diag)
# MSE weight
w_star = np.random.multivariate_normal([0]*dim, diag, 1)[0]
print(w_star)

[[5. 0. 0. 0. 0.]
 [0. 5. 0. 0. 0.]
 [0. 0. 5. 0. 0.]
 [0. 0. 0. 5. 0.]
 [0. 0. 0. 0. 5.]]
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]


In [7]:
# noise add to y
epsilon = np.random.normal(0, 1, n_sample)

In [8]:
# noisy labels
y = np.dot(X, w_star) + epsilon
print(y[0:10])

[ 1.72517815 -1.86011148 -1.91710548 -1.39267423 -0.28523969 -1.55699121
  2.74346011 -4.57253723 -7.19923035 -4.08182988]


In [23]:
print(X.shape)
print(y.shape)
print(w_star)

# groundtruth parameter for P
print(diag_lambda)
# groundtruth parameter for Q
print(diag_x)

(100, 5)
(100,)
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]
5
[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]


In [9]:
# Section 2: 
# define the model
import math
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
import itertools

# This is a simple implementation using equation just above section 3.1 in paper

In [10]:
class Linear_Simple:
    def __init__(self, n_samples, dimension, delta, learning_rate=0.001, momentum=0.9):
        self.d = dimension
        self.m = n_samples
        self.delta = delta # delta represents belief
        self.optimizer = tf.optimizers.SGD(learning_rate=learning_rate, momentum=momentum)
        self._build_model()
        
        
    def _build_model(self):
        self.weight = tf.Variable(dtype=tf.float32, name="weights", shape=self.d, initial_value=tf.zeros(shape=self.d), trainable=True)
        self.bias = tf.Variable(dtype=tf.float32, name="bias", shape=1, initial_value=tf.zeros(shape=1), trainable=True)
        # define learnable distribution Q
        self.s = tf.Variable(dtype=tf.float32, name='learnable_diag', shape=self.d, initial_value=tf.ones(shape=self.d), trainable=True)
        self.Q = tfd.MultivariateNormalDiag(loc=self.weight, scale_diag=self.s*self.s)
        # define P as fixed distribution
        self.lamda = [3.0]
        self.P = tfd.MultivariateNormalDiag(loc=tf.zeros(self.d), scale_diag=tf.tile(self.lamda, [self.d]))
        self.trainable_variables = [self.weight, self.bias, self.s]
        
    def _sample_phi(self):
        return np.random.multivariate_normal([.0]*self.d, np.identity(self.d, dtype=np.float32), 1)[0]
        
    def compute_loss(self, predictions, labels):
        empirical_loss = tf.nn.l2_loss(predictions-labels)
        KL_divergence = tfd.kl_divergence(distribution_a=self.Q, distribution_b=self.P) # compute KL(Q||P)
        RE_loss = (KL_divergence + tf.math.log(self.m/self.delta)) / (2*self.m - 2)
        loss = empirical_loss + tf.math.sqrt(RE_loss)
        return loss
        
    def predict(self, inputs):
        # apply sgd described in paper: section 3.2
        new_weight = self.weight + tf.multiply(self._sample_phi(), self.s)
        self.scores = tf.reduce_sum(tf.multiply(inputs, new_weight)) + self.bias
        return self.scores
    
    def train(self, dataset, epoch=3, print_step_loss=False):
        its = itertools.tee(dataset, epoch)
        for e in range(epoch):
            _loss = 0
            for x, y in its[e]:
                with tf.GradientTape() as tape:
                    scores = self.predict(x)
                    loss = self.compute_loss(scores, y)
                    _loss += loss
                    if print_step_loss:
                        print(loss.numpy())
                gradients = tape.gradient(loss, self.trainable_variables)
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            print("average loss at epoch %d is: %f" % (e, _loss/self.m))
            self._print_stats()
            
    def _print_stats(self):
        print("diag s: " + str((self.s * self.s).numpy()))
        

In [11]:
model = Linear_Simple(n_samples=n_sample, dimension=dim, delta=0.025)
model.train(zip(X, y), epoch=100, print_step_loss=False)

average loss at epoch 0 is: 3.740866
diag s: [0.99481803 1.0270623  0.96434236 1.0287461  0.32720926]
average loss at epoch 1 is: 3.234305
diag s: [1.0229533  1.0139452  0.8806895  0.8765629  0.14445315]
average loss at epoch 2 is: 2.917866
diag s: [1.0001037  1.0469124  0.81301045 0.7482896  0.14153744]
average loss at epoch 3 is: 2.634566
diag s: [0.99541485 1.0400496  0.8058162  0.69964737 0.01819591]
average loss at epoch 4 is: 2.426518
diag s: [0.9915569 1.0181191 0.7514363 0.8257075 0.0139191]
average loss at epoch 5 is: 2.391124
diag s: [0.9942878  1.0256128  0.7362773  0.6148264  0.01033538]
average loss at epoch 6 is: 2.170149
diag s: [0.9812073  1.0012332  0.7577909  0.61120325 0.01170056]
average loss at epoch 7 is: 2.057967
diag s: [0.97988075 0.99544257 0.7636441  0.53251815 0.01911303]
average loss at epoch 8 is: 1.962724
diag s: [0.9819547  0.98891675 0.7389337  0.4842021  0.0069253 ]
average loss at epoch 9 is: 1.825257
diag s: [0.98564476 0.98099786 0.74202126 0.470739

average loss at epoch 77 is: 0.884558
diag s: [7.2449881e-01 7.6319820e-01 2.4901277e-01 6.8923654e-03 3.8762728e-04]
average loss at epoch 78 is: 0.905282
diag s: [7.1288460e-01 7.6202625e-01 2.3925944e-01 8.3799008e-03 1.0823606e-05]
average loss at epoch 79 is: 0.908053
diag s: [0.69839317 0.74737465 0.23390509 0.01390431 0.01213733]
average loss at epoch 80 is: 0.894784
diag s: [0.68348575 0.7316585  0.23313184 0.01008356 0.03857528]
average loss at epoch 81 is: 0.901653
diag s: [0.6922533  0.723738   0.239395   0.00522361 0.0289592 ]
average loss at epoch 82 is: 0.900916
diag s: [0.6813152  0.7329053  0.24039073 0.00298153 0.00352013]
average loss at epoch 83 is: 0.890880
diag s: [0.6697032  0.73917454 0.23671174 0.00117891 0.0318952 ]
average loss at epoch 84 is: 0.915721
diag s: [0.65111697 0.7221709  0.21864381 0.00447042 0.02369864]
average loss at epoch 85 is: 0.909448
diag s: [0.6452284  0.72083414 0.21047889 0.01075998 0.00112018]
average loss at epoch 86 is: 0.894349
diag 

In [19]:
print("groundtruth data generation diag: ")
print(diag_x)
print("learned s: ")
print((model.s **2) .numpy())
print("ground_truth weights: ")
print(w_star)
print("learned weights: ")
print(model.weight.numpy())
print("ground truth Lambda for P: ")
print(diag_lambda)
print("learned Lambda for P: ")
print(model.lamda.numpy())

groundtruth data generation diag: 
[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]
learned s: 
[0.6603121  0.775492   0.30869764 0.11568683 0.02556862]
ground_truth weights: 
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]
learned weights: 
[ 0.00251557  0.00429379 -0.01429341  0.19932954 -0.25943637]
ground truth Lambda for P: 
5
learned Lambda for P: 
[0.11213199]


# Below is the implementation of eq(4) in original paper. where lambda is also a learnable parameter

In [20]:
class Linear_Paper(Linear_Simple):
    def __init__(self, n_samples, dimension, b, c, delta, learning_rate=0.001, momentum=0.9):
        self.b = b
        self.c = c
        super().__init__(n_samples, dimension, delta, learning_rate, momentum)
        
    def _build_model(self):
        self.weight = tf.Variable(dtype=tf.float32, name="weights", shape=self.d, initial_value=tf.zeros(shape=self.d), trainable=True)
        self.bias = tf.Variable(dtype=tf.float32, name="bias", shape=1, initial_value=tf.zeros(shape=1), trainable=True)
        # define learnable distribution Q
        self.s = tf.Variable(dtype=tf.float32, name='learnable_diag', shape=self.d, initial_value=tf.ones(shape=self.d), trainable=True)
        self.Q = tfd.MultivariateNormalDiag(loc=self.weight, scale_diag=self.s*self.s)
        # define prior distribution P    
        self.lamda = tf.Variable(dtype=tf.float32, name="p_lambda", shape=1, initial_value=[0.01], trainable=True) # diagnals of prior distribution P
        self.P = tfd.MultivariateNormalDiag(loc=tf.zeros(self.d), scale_diag=tf.tile(self.lamda, [self.d]))
        self.trainable_variables = [self.weight, self.bias, self.s, self.lamda]
                
    def compute_loss(self, predictions, labels):
        empirical_loss = tf.nn.l2_loss(predictions-labels)
        KL_divergence = tfd.kl_divergence(distribution_a=self.Q, distribution_b=self.P) # compute KL(Q||P)
        RE_loss = (KL_divergence + \
                   2*tf.math.log(tf.clip_by_value(self.b * tf.math.log(self.c/self.lamda), 1e-5, 1e30)) + \
                   tf.math.log(math.pi ** 2 * self.m/(6*self.delta))) / (self.m-1)
        loss = empirical_loss + tf.math.sqrt(tf.clip_by_value(RE_loss, 0, 1e30)/2)
        return loss
            
    def _print_stats(self):
        super()._print_stats()
        print("lambda: %f" % self.lamda.numpy())


In [21]:
model = Linear_Paper(n_samples=n_sample, dimension=dim, b=100, c=0.1, delta=0.025, learning_rate=0.001)
model.train(zip(X, y), epoch=100)

average loss at epoch 0 is: 14.684502
diag s: [1.0091044  0.9613919  0.9781083  1.1547221  0.36836565]
lambda: 0.024212
average loss at epoch 1 is: 14.818404
diag s: [1.0228788  0.9312968  0.8993562  1.0948617  0.18609247]
lambda: 0.036660
average loss at epoch 2 is: 14.886151
diag s: [1.0102277  0.9305457  0.85511667 0.8747031  0.03690664]
lambda: 0.048871
average loss at epoch 3 is: 14.641424
diag s: [0.98379284 0.9136573  0.8367183  0.94107395 0.00554069]
lambda: 0.062324
average loss at epoch 4 is: 14.532815
diag s: [1.0071161  0.8837219  0.87317795 1.191427   0.0121667 ]
lambda: 0.080074
average loss at epoch 5 is: 14.786790
diag s: [1.0150102  0.87750614 0.8720314  0.98344296 0.00385936]
lambda: 0.187454
average loss at epoch 6 is: 14.778208
diag s: [1.0208317  0.84528416 0.8301078  0.795103   0.00484626]
lambda: 0.188014
average loss at epoch 7 is: 14.811518
diag s: [0.9848045  0.8472985  0.75783265 0.5014059  0.09085104]
lambda: 0.188014
average loss at epoch 8 is: 14.786369
di

average loss at epoch 67 is: 14.731799
diag s: [0.67695194 0.69067705 0.32989895 0.19620535 0.02762379]
lambda: 0.188014
average loss at epoch 68 is: 14.782635
diag s: [0.66901714 0.6617494  0.3204606  0.07292069 0.03477235]
lambda: 0.188014
average loss at epoch 69 is: 14.705016
diag s: [0.6653107  0.6533075  0.28954533 0.09839866 0.03473667]
lambda: 0.188014
average loss at epoch 70 is: 14.710004
diag s: [0.6797719  0.6355406  0.2682696  0.14842682 0.00438321]
lambda: 0.188014
average loss at epoch 71 is: 14.675573
diag s: [0.67735237 0.65159756 0.23825891 0.15025994 0.04352284]
lambda: 0.188014
average loss at epoch 72 is: 14.623749
diag s: [0.6515304  0.6413449  0.20173495 0.28359172 0.06813668]
lambda: 0.188014
average loss at epoch 73 is: 14.741642
diag s: [0.6446236  0.63482434 0.1944143  0.25716364 0.01024378]
lambda: 0.188014
average loss at epoch 74 is: 14.656399
diag s: [6.5702826e-01 6.4490384e-01 1.8049215e-01 3.4789693e-01 4.3491463e-04]
lambda: 0.188014
average loss at e

In [22]:
print("groundtruth data generation diag: ")
print(diag_x)
print("learned s: ")
print((model.s **2) .numpy())
print("ground_truth weights: ")
print(w_star)
print("learned weights: ")
print(model.weight.numpy())
print("ground truth Lambda for P: ")
print(diag_lambda)
print("learned Lambda for P: ")
print(model.lamda.numpy())

groundtruth data generation diag: 
[[0.0018896  0.         0.         0.         0.        ]
 [0.         0.00212482 0.         0.         0.        ]
 [0.         0.         0.01355448 0.         0.        ]
 [0.         0.         0.         0.06296267 0.        ]
 [0.         0.         0.         0.         0.45808023]]
learned s: 
[0.5043025  0.57687074 0.01902749 0.03450254 0.00063787]
ground_truth weights: 
[-1.42438729  1.77600611 -1.39529334  8.73640413 -1.69140474]
learned weights: 
[ 0.00248535  0.00475126 -0.01315061  0.19954143 -0.25769135]
ground truth Lambda for P: 
5
learned Lambda for P: 
[0.18801434]
