A notebook to plot the power criterion function of the FSCD test.

In [None]:
# Notebook setup: reload edited modules (e.g. kcgof) automatically before each
# cell runs, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'



In [None]:
import numpy as np
import torch
import torch.distributions as dists

import kcgof
import kcgof.log as klog
import kcgof.util as util
import kcgof.cdensity as cden
import kcgof.cdata as cdat
import kcgof.cgoftest as cgof
import kcgof.kernel as ker
import kcgof.plot as plot

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib styling used by every figure in this notebook.
# Only the font size is overridden; family and weight keep matplotlib defaults.
font = dict(size=20)
plt.rc('font', **font)
plt.rc('lines', linewidth=2)

# Font type 42 embeds TrueType fonts (rather than Type 3) in PDF/PS exports.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

## Regression problem

In [None]:
dx = 1  # dimension of the conditioning variable X

# NOTE(review): `slope` and `c` are not used by any cell of this notebook;
# they are kept only so that nothing relying on these names breaks.
slope = torch.tensor([0.5])
noise_variance = 1.0
c = 1.0

# Conditional probability model p(y|x): an additive-noise regression
# (class name suggests y = f(x) + noise) with a quadratic regression function f.
f = lambda x: 1.0*x + 0.5*x**2 - 1
# torch's Normal is parameterized by the standard deviation, hence the sqrt;
# previously `noise_variance` was declared but silently ignored.
gaussian_noise = dists.Normal(0, noise_variance**0.5)
p = cden.CDAdditiveNoiseRegression(f=f, noise=gaussian_noise, dx=dx)

In [None]:
# Generate toy data from a conditional data source (CondSource).
# Its regression function fr has a smaller quadratic coefficient (0.2) than the
# model's f (0.5), so the model p does not match the data-generating process.
fr = lambda x: 1.0*x + 0.2*x**2 - 1
# Use `dx` rather than a hard-coded 1 so the source agrees with X's shape below.
cs = cdat.CSAdditiveNoiseRegression(f=fr, noise=gaussian_noise, dx=dx)

n = 300  # sample size
# Marginal distribution of X (also used later as the density of X in the plot).
px = dists.Normal(1, 1)
with util.TorchSeedContext(seed=17):
    X = px.sample((n, dx))
# Sample Y conditioned on X (the source is seeded separately).
Y = cs(X, seed=28)

Set up the kernels, compute the power criterion of the FSCD test on a grid, and plot the data, the model, and the power criterion.

In [None]:
# Gaussian kernels (PTKGauss), both with sigma2 = 1:
#   k -- kernel on the conditioning variable X
#   l -- kernel on the response Y
k, l = ker.PTKGauss(sigma2=1), ker.PTKGauss(sigma2=1.0)

In [None]:
# Evaluation grid for the plots: it extends `ep` beyond the observed range of X
# (and, for the model plot, of Y). Endpoints are taken as plain floats for both
# axes, consistently, instead of mixing tensor and float endpoints.
ep = 0.7
n_grid_x = 100  # number of grid points along x
n_grid_y = 200  # number of grid points along y
domX = torch.linspace(torch.min(X).item() - ep, torch.max(X).item() + ep, n_grid_x)
domY = torch.linspace(torch.min(Y).item() - ep, torch.max(Y).item() + ep, n_grid_y)

fscd_pc = cgof.FSCDPowerCriterion(p, k, l, X, Y)

# Evaluate the witness and the power criterion at every x grid point.
# NOTE(review): the witness receives shape (n_grid_x, 1) while the power
# criterion receives (n_grid_x, 1, 1) -- presumably a batch of candidate
# location sets, each containing a single location; confirm against
# FSCDPowerCriterion. `wit_values` is currently not plotted.
wit_values = fscd_pc.eval_witness(at=domX.unsqueeze(1))
pow_cri_values = fscd_pc.eval_power_criterion(at=domX.unsqueeze(1).unsqueeze(1))

In [None]:
# Plot the data (X, Y) and the model p over the (domX, domY) grid.
# The lambda's parameter is named `x` so it does not shadow the data array X.
# NOTE(review): height_ratios presumably sizes two stacked panels; the second
# one (axes[1]) receives the power criterion in the next cell.
fig, axes = plot.plot_2d_cond_model(
    p,
    lambda x: torch.exp(px.log_prob(x)),  # density of X under px
    X,
    Y,
    domX=domX,
    domY=domY,
    cmap='pink_r',
    levels=50,
    height_ratios=[2, 1],
)

Add the power criterion as the second subplot and mark its maximizer $v$.

In [None]:
# Plot the power criterion on the second axes and mark where it peaks.
npdomX = domX.detach().numpy()
# Convert once to a detached numpy array: pow_cri_values may carry autograd
# history, and handing a grad-requiring tensor to matplotlib (e.g. as `ymax`
# of vlines) fails because Tensor.numpy() refuses tensors that require grad.
np_pow_cri = np.ravel(pow_cri_values.detach().numpy())

ax_pow = axes[1]
ax_pow.plot(npdomX, np_pow_cri, 'g-', label='Power Cri.')
ax_pow.legend(loc='lower left', ncol=1)

# Mark the highest point with a dashed vertical line and label it v.
imax = int(np.argmax(np_pow_cri))
ax_pow.vlines(x=npdomX[imax], ymin=0, ymax=np_pow_cri[imax],
              linestyles='dashed', color='g')
# Trailing `;` suppresses the Annotation repr in the cell output.
ax_pow.annotate('$v$', (npdomX[imax], -0.02), xytext=(npdomX[imax] - 0.1, -0.13));


In [None]:
# Resize to 7 x 5 inches, tighten the layout, export to PDF, and display.
fig.set_size_inches(7, 5)
fig.tight_layout()
fig.savefig('lin_gauss_ls_powcri.pdf', bbox_inches='tight')
fig