# Simple Linear Regression

In [None]:
import math
import time
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm

In [None]:
# Presumably a switch for saving figures; nothing in the visible code reads it.
SAVEFIG = False

# Compact preset: small (3.5 x 3 inch) figures with a serif font.
plt.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
    "figure.figsize": (3.5, 3),
    "font.size": 15,
    "figure.titlesize": 27,
    "axes.labelsize": 27,
    "xtick.labelsize": 20,
    "ytick.labelsize": 20,
    "legend.fontsize": 13,
    "lines.linewidth": 2,
})

# Alternative preset for large (8 x 8 inch) figures; every size is roughly
# doubled.  Uncomment to use it instead of the compact preset above.
# plt.rcParams.update({
#     "font.family": "serif",
#     "mathtext.fontset": "dejavuserif",
#     "figure.figsize": (8, 8),
#     "font.size": 30,
#     "figure.titlesize": 53,
#     "axes.labelsize": 53,
#     "xtick.labelsize": 40,
#     "ytick.labelsize": 40,
#     "legend.fontsize": 28,
#     "lines.linewidth": 4,
# })


In [None]:
def simple_linear_regression(xs, ws):
    """Evaluate the per-sample linear model ``y = w0 * x + w1``.

    Args:
        xs: Inputs of any shape ``S`` (tensor or array-like).
        ws: Weights broadcastable against shape ``S + [2]``;
            ``ws[..., 0]`` is the slope and ``ws[..., 1]`` the intercept.

    Returns:
        Tensor ``ws[..., 0] * xs + ws[..., 1]`` with the trailing
        weight axis summed out.
    """
    xs = tf.expand_dims(xs, axis=-1)
    # Append a constant-one feature so the intercept is handled by the same
    # multiply-and-sum as the slope.  ``ones_like`` follows the dtype and
    # shape of ``xs`` instead of assuming float32 and a static shape.
    features = tf.concat([xs, tf.ones_like(xs)], axis=-1)
    return tf.reduce_sum(ws * features, axis=-1)


def temporal_smoothing(ys, l, wsize=5):
    """Smooth predictions along time with an exponentially decaying window.

    Args:
        ys: Rank-2 tensor ``[batch, steps]`` of per-step predictions.
        l: Decay rate forwarded to ``exp_decay``.
        wsize: Number of consecutive steps, ending at the current one,
            that are averaged.

    Returns:
        Tensor shaped like ``ys``.  Step ``t >= wsize - 1`` holds the
        weighted mean of ``ys[:, t - wsize + 1 : t + 1]`` using the weights
        from ``exp_decay`` (last weight applies to step ``t``).  The first
        ``wsize - 1`` steps have no full window and are passed through.
    """
    weight = exp_decay(l, wsize)

    # Leading steps without a complete window are copied unchanged.
    head = ys[:, :wsize - 1]

    # Slice ``i`` holds, for every output step, the element ``wsize - 1 - i``
    # steps in the past.  The window must END at the current step, so the
    # last slice (i = wsize - 1) runs to the end of the sequence; ``or None``
    # turns its end index 0 into an open-ended slice.
    windows = tf.stack(
        [ys[:, i:(i - wsize + 1) or None] for i in range(wsize)],
        axis=-1,
    )
    smoothed = tf.tensordot(windows, weight, axes=[[2], [0]])

    return tf.concat([head, smoothed], axis=-1)


def exp_decay(l, wsize=5):
    """Return ``wsize`` normalized weights that grow exponentially with index.

    Entry ``i`` is proportional to ``exp((i - wsize) * l)`` and the entries
    sum to one, so for positive ``l`` the last entry is the largest.
    """
    offsets = tf.range(wsize, dtype=tf.float32) - wsize
    unnormalized = tf.math.exp(offsets * l)
    return unnormalized / tf.reduce_sum(unnormalized)


In [None]:
# Fix TensorFlow's global seed so the sampled data below is repeatable.
tf.random.set_seed(1)

v = 0.1        # shift of the input mean per step
x_var = 0.1    # multiplier applied to the standard-normal input noise

start = -10.0
batch_size = 20
steps = 101

# Inputs drift linearly from `start` and reach 0 at the last step
# (-10 + 100 * 0.1), with independent noise per batch element and step.
t = tf.range(steps, dtype=tf.float32)
trend = start + t * v
xs = trend + tf.random.normal([batch_size, steps]) * x_var
# xs = tf.concat([xs, tf.constant([[0.4]] * batch_size)], axis=1)  # Noise

# One (slope, intercept) pair is drawn for every input element.
w_mean = [1.0, 0.0]
w_var = [0.02, 0.2]
ws_true = w_mean + tf.random.normal([batch_size, steps, 2]) * w_var
ys_true = simple_linear_regression(xs, ws_true)

# Show the last five values of the first batch element.
for label, series in (('xs', xs), ('ys_true', ys_true)):
    print(label + ': ...,', ', '.join(str(val.numpy()) for val in series[0][-5:]))

In [None]:
# xs.shape: [batch_size, steps]
# ws.shape: [n_ff, batch_size, steps, 2]
# ys.shape: [n_ff, batch_size, steps]
# ss.shape: [n_ff, batch_size, steps]

# DNN: a single deterministic weight sample (zero spread) with weight 1.
w_mean = [1.0, 0.0]
w_var = [0.0, 0.0]
w_dnn = [w_mean, w_var]
ws_dnns = tf.stack([w_mean + tf.random.normal(list(xs.shape) + [2]) * w_var for _ in range(1)], axis=0)
ys_dnn = tf.stack([simple_linear_regression(xs, ws_dnn) for ws_dnn in ws_dnns], axis=0)
ss_dnn = tf.ones([1, batch_size, steps])

# BNN: n_ff independent weight samples per input, each weighted 1 / n_ff.
n_ff = 30
w_mean = [1.0, 0.0]
w_var = [0.02, 0.2]
w_bnn = [w_mean, w_var]
ws_bnns = tf.stack([w_mean + tf.random.normal(list(xs.shape) + [2]) * w_var for _ in range(n_ff)], axis=0)
ys_bnn = tf.stack([simple_linear_regression(xs, ws_bnn) for ws_bnn in ws_bnns], axis=0)
ss_bnn = tf.ones([n_ff, batch_size, steps]) / n_ff

# VQ-DNN: temporally smoothed DNN predictions; the 5 window weights are
# broadcast over batch and steps.
# NOTE(review): smoothing uses l=1.0 but the weights use 0.7 - confirm intended.
ys_vqdnn = temporal_smoothing(simple_linear_regression(xs, ws_dnns[0]), l=1.0)
ss_vqdnn = tf.reshape(exp_decay(0.7, 5), [5, 1, 1])
ss_vqdnn = tf.broadcast_to(ss_vqdnn, [5, batch_size, steps])

# VQ-BNN: same construction applied to the first BNN weight sample.
# The weights get their own name here instead of re-assigning `ss_vqdnn`;
# `ss_vqdnn` above stays defined with the same values for existing users.
ys_vqbnn = temporal_smoothing(simple_linear_regression(xs, ws_bnns[0]), l=1.0)
ss_vqbnn = tf.reshape(exp_decay(0.7, 5), [5, 1, 1])
ss_vqbnn = tf.broadcast_to(ss_vqbnn, [5, batch_size, steps])



In [None]:
def xw0plot(fig, inner_grid, xs, ws, ss, color, weight_scale=500):
    """Draw the weighted (x, w_0) samples together with both marginals.

    Args:
        fig: Matplotlib figure that receives the three new axes.
        inner_grid: Indexable grid of subplot specs; cells 2, 0 and 3 hold
            the joint, x-marginal and w-marginal panels (a 2x2 grid such
            as ``cond_grid`` returns).
        xs: Input samples, one per point.
        ws: First-weight samples paired element-wise with ``xs``.
        ss: Per-sample weights; they scale the marker areas and weight the
            Gaussian kernels of the density estimates.
        color: Matplotlib color for samples and estimated densities.
        weight_scale: Factor turning a sample weight into a marker area.

    Reads the module-level ``w_bnn`` (``[means, spreads]``) for the band.
    """
    x_min, x_max = -1.0, 1.0
    w_min, w_max = 0.9, 1.1
    dotted = (0, (1, 1))
    shade = dict(color='gainsboro', alpha=0.5)

    # Joint panel p(x, w)
    joint = fig.add_subplot(inner_grid[2])
    joint.set_xlim(x_min, x_max)
    joint.set_ylim(w_min, w_max)
    joint.set_xlabel("$x$")
    joint.set_ylabel("$w_{0}$")
    joint.set_yticks([0.9, 1.0, 1.1])

    # Grey guides: horizontal band of mean +/- spread of w_0, vertical band
    # of x in [-0.2, 0.2], and dotted lines through both centres.
    band_lo = w_bnn[0][0] - w_bnn[1][0]
    band_hi = w_bnn[0][0] + w_bnn[1][0]
    x_span = np.linspace(x_min, x_max, 10)
    joint.fill_between(x_span, [band_lo] * 10, [band_hi] * 10, **shade)
    joint.fill_between(np.linspace(-0.2, 0.2, 10), [w_min] * 10, [w_max] * 10, **shade)
    joint.plot(x_span, [(band_lo + band_hi) / 2] * 10, linestyle=dotted, color='black')
    joint.plot([0] * 10, np.linspace(0.9, 1.1, 10), linestyle=dotted, color='black')
    marker_areas = tf.stack([s * weight_scale for s in ss])
    joint.scatter(xs, ws, s=marker_areas, facecolors='none', edgecolors=color, marker='o')

    # Top marginal p(x): reference N(0, 0.2) against a weighted kernel
    # mixture (kernel scale 0.05) centred on the samples.
    top = fig.add_subplot(inner_grid[0])
    x_grid = np.linspace(x_min, x_max, 100)
    x_kernels = [s * norm.pdf(x_grid, x, 0.05) for s, x in zip(ss, xs)]
    top.plot(x_grid, norm.pdf(x_grid, 0.0, 0.2), color='black', alpha=0.5)
    top.plot(x_grid, tf.reduce_sum(x_kernels, axis=0), color=color)
    top.set_xticks([])
    top.set_yticks([])
    top.set_xlim(x_min, x_max)
    # Axis height equals the peak of a single kernel.
    top.set_ylim(bottom=0, top=norm.pdf(0, 0, 0.05))

    # Right marginal p(w), drawn sideways: reference N(1, 0.02) against a
    # weighted kernel mixture (kernel scale 0.005).
    side = fig.add_subplot(inner_grid[3])
    w_grid = np.linspace(w_min, w_max, 100)
    w_kernels = [s * norm.pdf(w_grid, w, 0.005) for s, w in zip(ss, ws)]
    side.plot(norm.pdf(w_grid, 1.0, 0.02), w_grid, color='black', alpha=0.5)
    side.plot(tf.reduce_sum(w_kernels, axis=0), w_grid, color=color)
    side.set_xticks([])
    side.set_yticks([])
    side.set_ylim(w_min, w_max)
    side.set_xlim(left=0, right=norm.pdf(0, 0, 0.005))
    
    
def xw1plot(fig, inner_grid, xs, ws, ss, color, weight_scale=500):
    """Draw the weighted (x, w_1) samples together with both marginals.

    Args:
        fig: Matplotlib figure that receives the three new axes.
        inner_grid: Indexable grid of subplot specs; cells 2, 0 and 3 hold
            the joint, x-marginal and w-marginal panels (a 2x2 grid such
            as ``cond_grid`` returns).
        xs: Input samples, one per point.
        ws: Second-weight samples paired element-wise with ``xs``.
        ss: Per-sample weights; they scale the marker areas and weight the
            Gaussian kernels of the density estimates.
        color: Matplotlib color for samples and estimated densities.
        weight_scale: Factor turning a sample weight into a marker area.

    Reads the module-level ``w_bnn`` (``[means, spreads]``) for the band.
    """
    x_lo, x_hi = -1.0, 1.0
    w_lo, w_hi = -1.0, 1.0
    guide_style = dict(linestyle=(0, (1, 1)), color='black')
    band_style = dict(color='gainsboro', alpha=0.5)

    # --- joint panel p(x, w) ---
    ax_joint = fig.add_subplot(inner_grid[2])
    ax_joint.set_xlim(x_lo, x_hi)
    ax_joint.set_ylim(w_lo, w_hi)
    ax_joint.set_xlabel("$x$")
    ax_joint.set_ylabel("$w_{1}$")
    ax_joint.set_yticks([-1.0, 0.0, 1.0])

    # Horizontal band: mean +/- spread of w_1; vertical band: x in [-0.2, 0.2].
    w_band = (w_bnn[0][1] - w_bnn[1][1], w_bnn[0][1] + w_bnn[1][1])
    across = np.linspace(x_lo, x_hi, 10)
    ax_joint.fill_between(across, [w_band[0]] * 10, [w_band[1]] * 10, **band_style)
    ax_joint.fill_between(np.linspace(-0.2, 0.2, 10), [w_lo] * 10, [w_hi] * 10, **band_style)
    ax_joint.plot(across, [(w_band[0] + w_band[1]) / 2] * 10, **guide_style)
    ax_joint.plot([0] * 10, np.linspace(-1.0, 1.0, 10), **guide_style)
    ax_joint.scatter(
        xs,
        ws,
        s=tf.stack([s * weight_scale for s in ss]),
        facecolors='none',
        edgecolors=color,
        marker='o',
    )

    # --- top marginal p(x) ---
    ax_top = fig.add_subplot(inner_grid[0])
    grid_x = np.linspace(x_lo, x_hi, 100)
    reference_x = norm.pdf(grid_x, 0.0, 0.2)
    estimate_x = tf.reduce_sum([s * norm.pdf(grid_x, x, 0.05) for s, x in zip(ss, xs)], axis=0)
    ax_top.plot(grid_x, reference_x, color='black', alpha=0.5)
    ax_top.plot(grid_x, estimate_x, color=color)
    ax_top.set_xticks([])
    ax_top.set_yticks([])
    ax_top.set_xlim(x_lo, x_hi)
    # Axis height equals the peak of one kernel (scale 0.05).
    ax_top.set_ylim(bottom=0, top=norm.pdf(0, 0, 0.05))

    # --- right marginal p(w), density on the horizontal axis ---
    ax_side = fig.add_subplot(inner_grid[3])
    grid_w = np.linspace(w_lo, w_hi, 100)
    reference_w = norm.pdf(grid_w, 0.0, 0.2)
    estimate_w = tf.reduce_sum([s * norm.pdf(grid_w, w, 0.05) for s, w in zip(ss, ws)], axis=0)
    ax_side.plot(reference_w, grid_w, color='black', alpha=0.5)
    ax_side.plot(estimate_w, grid_w, color=color)
    ax_side.set_xticks([])
    ax_side.set_yticks([])
    ax_side.set_ylim(w_lo, w_hi)
    ax_side.set_xlim(left=0, right=norm.pdf(0, 0, 0.05))
    

    
def xyplot(fig, inner_grid, xs, ys, ss, color, weight_scale=500):
    """Draw the weighted (x, y) samples together with both marginals.

    Args:
        fig: Matplotlib figure that receives the three new axes.
        inner_grid: Indexable grid of subplot specs; cells 2, 0 and 3 hold
            the joint, x-marginal and y-marginal panels (a 2x2 grid such
            as ``cond_grid`` returns).
        xs: Input samples, one per point.
        ys: Output samples paired element-wise with ``xs``.
        ss: Per-sample weights; they scale the marker areas and weight the
            Gaussian kernels of the density estimates.
        color: Matplotlib color for samples and estimated densities.
        weight_scale: Factor turning a sample weight into a marker area.

    Reads the module-level ``w_bnn`` (``[means, spreads]``) for the band.
    """
    x_lim = -1.0, 1.0
    y_lim = -1.0, 1.0

    # p(x, y)
    ax = fig.add_subplot(inner_grid[2])

    ax.set_xlim(x_lim[0], x_lim[1])
    ax.set_ylim(y_lim[0], y_lim[1])
    ax.set_xlabel("$x$")
    # Raw string: "\q" is not a valid Python escape sequence, so the plain
    # literal triggered an invalid-escape warning.  The text is unchanged.
    ax.set_ylabel(r"$y_{\quad}$")
    ax.set_yticks([-1.0, 0.0, 1.0])

    # Horizontal band taken from the second weight's mean +/- spread
    # (presumably because y reduces to w_1 at x = 0 - confirm); vertical
    # band marks x in [-0.2, 0.2]; dotted lines run through both centres.
    w_lim = w_bnn[0][1] - w_bnn[1][1], w_bnn[0][1] + w_bnn[1][1]
    ax.fill_between(np.linspace(x_lim[0], x_lim[1], 10), [w_lim[0]] * 10, [w_lim[1]] * 10, color='gainsboro', alpha=0.5)
    ax.fill_between(np.linspace(-0.2, 0.2, 10), [y_lim[0]] * 10, [y_lim[1]] * 10, color='gainsboro', alpha=0.5)
    ax.plot(np.linspace(x_lim[0], x_lim[1], 10), [(w_lim[0] + w_lim[1]) / 2] * 10, linestyle=(0, (1, 1)), color='black')
    ax.plot([0] * 10, np.linspace(-1.0, 1.0, 10), linestyle=(0, (1, 1)), color='black')
    ax.scatter(xs, ys, s=tf.stack([s * weight_scale for s in ss]), facecolors='none', edgecolors=color, marker='o')

    # p(x): reference N(0, 0.2) against a weighted kernel mixture
    # (kernel scale 0.05) centred on the samples.
    ax = fig.add_subplot(inner_grid[0])
    dom = np.linspace(x_lim[0], x_lim[1], 100)
    px_true = norm.pdf(dom, 0.0, 0.2)
    px = tf.reduce_sum([s * norm.pdf(dom, x, 0.05) for s, x in zip(ss, xs)], axis=0)
    ax.plot(dom, px_true, color='black', alpha=0.5)
    ax.plot(dom, px, color=color)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(x_lim[0], x_lim[1])
    # Axis height equals the peak of a single kernel.
    ax.set_ylim(bottom=0, top=norm.pdf(0, 0, 0.05))

    # p(y), drawn sideways so it lines up with the joint panel's y axis.
    ax = fig.add_subplot(inner_grid[3])
    dom = np.linspace(y_lim[0], y_lim[1], 100)
    py_true = norm.pdf(dom, 0.0, 0.2)
    py = tf.reduce_sum([s * norm.pdf(dom, y, 0.05) for s, y in zip(ss, ys)], axis=0)
    ax.plot(py_true, dom, color='black', alpha=0.5)
    ax.plot(py, dom, color=color)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_ylim(y_lim[0], y_lim[1])
    ax.set_xlim(left=0, right=norm.pdf(0, 0, 0.05))
    

def cond_grid(outer_grid, i):
    """Split cell ``i`` of ``outer_grid`` into a tightly spaced 2x2 sub-grid.

    The first column is three times as wide as the second and the second
    row three times as tall as the first.
    """
    layout = dict(
        wspace=0.03,
        hspace=0.03,
        width_ratios=[3, 1],
        height_ratios=[1, 3],
    )
    cell = outer_grid[i]
    return cell.subgridspec(2, 2, **layout)


fig = plt.figure(figsize=(14, 10), constrained_layout=False)

# Outer 3 x 4 grid: rows are (x, w_0), (x, w_1), (x, y); columns are
# DNN, BNN, VQ-DNN, VQ-BNN.  Each cell is subdivided by cond_grid.
outer_grid = fig.add_gridspec(3, 4, wspace=0.5, hspace=0.3)

# Inputs per method, all taken from batch element 0: DNN and BNN use the
# final step only (repeated per BNN sample), the VQ variants the last five.
xs_dnn = tf.stack([xs[0, -1]])
xs_bnn = tf.stack([xs[0, -1]] * n_ff)
xs_vqdnn = xs[0, -5:]
xs_vqbnn = xs[0, -5:]

# One entry per column:
# (x samples, w_0 samples, w_1 samples, y samples, sample weights, color)
panels = [
    # DNN
    (xs_dnn, ws_dnns[:, 0, -1, 0], ws_dnns[:, 0, -1, 1], ys_dnn[:, 0, -1], ss_dnn[:, 0, -1], 'tab:blue'),
    # BNN
    (xs_bnn, ws_bnns[:, 0, -1, 0], ws_bnns[:, 0, -1, 1], ys_bnn[:, 0, -1], ss_bnn[:, 0, -1], 'tab:green'),
    # VQ-DNN
    (xs_vqdnn, ws_dnns[0, 0, -5:, 0], ws_dnns[0, 0, -5:, 1], ys_dnn[0, 0, -5:], ss_vqdnn[:, 0, -1], 'tab:purple'),
    # VQ-BNN (uses the `ss_vqdnn` weights, as VQ-DNN does)
    (xs_vqbnn, ws_bnns[0, 0, -5:, 0], ws_bnns[0, 0, -5:, 1], ys_bnn[0, 0, -5:], ss_vqdnn[:, 0, -1], 'tab:red'),
]

for col, (x_pts, w0_pts, w1_pts, y_pts, weights, color) in enumerate(panels):
    # Grid cells are numbered row by row, so +4 moves one row down.
    xw0plot(fig, cond_grid(outer_grid, col), x_pts, w0_pts, ss=weights, color=color)
    xw1plot(fig, cond_grid(outer_grid, col + 4), x_pts, w1_pts, ss=weights, color=color)
    xyplot(fig, cond_grid(outer_grid, col + 8), x_pts, y_pts, ss=weights, color=color)


