In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pylab as plt
import tensorflow_probability as tfp
import sympy

# Piecewise constant fits with increasing number of intervals

In [None]:
# Generate the true function on a dense grid and a small noisy training sample.
# A fixed seed makes "Restart & Run All" reproduce the same sample every time.
SEED = 0
rng = np.random.default_rng(SEED)


def true_fn(x):
    """Ground-truth quadratic f(x) = x**2 + 0.5*x - 0.5 that the fits try to recover."""
    return x**2 + .5*x - .5


# Dense grid on [0, 1] used for plotting and for posterior prediction.
x_all = np.linspace(0, 1, 100)
y_true = true_fn(x_all)

# 15 observations drawn only from the interior [0.2, 0.8], plus Gaussian noise.
# NOISE_STD matches the likelihood std (sig_l = .05) used in `inference`.
NOISE_STD = .05
x_data = rng.uniform(.2, .8, 15)
y_data = true_fn(x_data) + rng.normal(0, NOISE_STD, len(x_data))

def plot_data():
    """Draw the true curve over `x_all` and the noisy observations as red dots.

    Draws on the current axes and adds no legend, so callers can add more
    curves first and then call `plt.legend()` once.
    """
    ax = plt.gca()
    ax.plot(x_all, y_true, label='True function')
    ax.plot(x_data, y_data, 'ro', label='Data points')
    ax.set_xlabel('x')
    ax.set_ylabel('y')


# Show the raw problem before fitting anything.
plot_data()
plt.legend()

def makePWC(nknots):
    """Build a piecewise-constant (indicator) feature map on [0, 1].

    `nknots` equally spaced knots split [0, 1] into `nknots - 1` intervals.
    Column j of the design matrix is 1 where knots[j] <= x < knots[j+1] and 0
    otherwise. The last interval is closed on the right, so x == 1 is assigned
    to it rather than getting an all-zero row. Points outside [0, 1] still
    get an all-zero row.

    Args:
        nknots: number of knots (must be >= 2).

    Returns:
        A function makeX(x) mapping an array of shape S to a tf.float64 tensor
        of shape S + (nknots - 1,). The knot locations are exposed as
        `makeX.knots`.

    Raises:
        ValueError: if nknots < 2 (there would be no intervals).
    """
    if nknots < 2:
        raise ValueError(f"nknots must be >= 2, got {nknots}")
    knots = np.linspace(0, 1, nknots)
    lower = knots[:-1]
    # Nudge the final right edge up by one ulp so the half-open test
    # `x < upper` also accepts x == knots[-1] (e.g. the last point of x_all).
    upper = knots[1:].copy()
    upper[-1] = np.nextafter(upper[-1], np.inf)

    def makeX(x):
        # Broadcast each x against every interval: (..., 1) vs (nknots - 1,).
        xe = np.asarray(x)[..., None]
        return tf.cast((xe >= lower) & (xe < upper), tf.float64)

    makeX.knots = knots
    return makeX

In [None]:
# Bayesian inference
def inference(makeX, sig_p, sig_l=.5e-1):
    """Bayesian linear regression on the features from `makeX`, then plot the fit.

    Model: y = X w + eps, with prior w ~ N(0, sig_p) and noise
    eps ~ N(0, sig_l**2 I). The weight posterior is N(mu, Omega), and the
    posterior predictive on `x_all` has mean X_all mu and marginal variance
    sig_l**2 + diag(X_all Omega X_all^T).

    Reads the notebook globals `x_data`, `y_data` and `x_all`. Draws the data,
    the predictive mean and a +/- 2 sd band on the current axes.

    Args:
        makeX: feature map returning a tensor with a `.numpy()` method, of shape
            (len(x), k) for a 1-D input x.
        sig_p: (k, k) prior covariance of the weights.
        sig_l: observation-noise standard deviation (default 0.05).

    Raises:
        ValueError: if sig_p is not (k, k) for the k features produced by makeX.
    """
    X = makeX(x_data).numpy()
    X_all = makeX(x_all).numpy()

    k = X.shape[1]
    if np.shape(sig_p) != (k, k):
        raise ValueError(f"sig_p must have shape {(k, k)}, got {np.shape(sig_p)}")

    # Small jitter keeps near-singular priors (e.g. smooth RBF kernels) invertible.
    sig_prior_inv = np.linalg.inv(sig_p + 1e-6 * np.eye(k))
    # Posterior covariance and mean of the weights.
    Omega = np.linalg.inv(X.T @ X / sig_l**2 + sig_prior_inv)
    mu = Omega @ (X.T @ y_data / sig_l**2)

    nu = X_all @ mu
    # Only marginal predictive variances are plotted, so compute
    # diag(X_all Omega X_all^T) directly instead of the full n x n covariance.
    var_y = sig_l**2 + np.einsum('ij,jk,ik->i', X_all, Omega, X_all)
    sig_y_diag = np.sqrt(var_y)

    plot_data()
    plt.plot(x_all, nu, label='Posterior predictive mean')
    # +/- 2 predictive standard deviations (about 95% for a Gaussian).
    plt.fill_between(x_all, nu - 2. * sig_y_diag, nu + 2. * sig_y_diag,
                     color='C4', alpha=.3, label=r'95% confidence')
    plt.legend()

In [None]:
# With only a few intervals the piecewise-constant fit is a reasonable approximation.
nknots = 6
n_intervals = nknots - 1
prior_sig = np.identity(n_intervals)
inference(makePWC(nknots), prior_sig)

In [None]:
# With many intervals and an independent prior per interval, the fit overfits.
nknots = 20
n_intervals = nknots - 1
prior_sig = np.identity(n_intervals)
inference(makePWC(nknots), prior_sig)

In [None]:
# Let's encode a prior assumption that the function is smooth. The coefficients
# of the piecewise-constant model are the function evaluations on the intervals,
# so we make the coefficients of nearby intervals correlated using a
# squared-exponential (RBF) covariance over the interval midpoints.
nknots = 20
makeX = makePWC(nknots)

# Interval midpoints, one per coefficient.
x = (makeX.knots[1:] + makeX.knots[0:-1]) / 2
# Correlation length of the prior. Coefficients whose midpoints are much farther
# apart than this are nearly independent. 0.1 gives exp(-50 * dx**2).
length_scale = 0.1
sig_p = np.exp(-0.5 * ((x[:, None] - x[None]) / length_scale) ** 2)

In [None]:
inference(makeX=makeX, sig_p=sig_p)