# A notebook to demonstrate `cgoftest.KSSDTest` and `cgoftest.FSCDTest`

In [None]:
# Automatically reload edited modules (e.g. kcgof) before running each cell,
# and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Optional: switch inline figures to a vector format.
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'



In [None]:
import numpy as np
import torch
import torch.distributions as dists

import kcgof
import kcgof.log as klog
import kcgof.util as util
import kcgof.cdensity as cden
import kcgof.cdata as cdat
import kcgof.cgoftest as cgof
import kcgof.kernel as ker
import kcgof.plot as plot

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Notebook-wide matplotlib style: large text and thick lines.
# Further font keys (e.g. 'family', 'weight') can be added to this dict.
font = {'size': 20}
plt.rc('font', **font)
plt.rc('lines', linewidth=2)

# Embed Type 42 (TrueType) fonts in saved PDF/PS files instead of Type 3.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

## Ordinary least squares with Gaussian noise

$$p(y \mid x) = \mathcal{N}\left(y;\ \mathrm{slope}\cdot x + c,\ \sigma^2\right),$$

where $\sigma^2$ is the noise variance.

### KSSD Test

In [None]:
# Parameters of the model p(y|x) = N(slope * x + c, noise_variance).
dx = 1  # dimension of the covariate X
slope = torch.tensor([0.5])
c = 1.0
noise_variance = 1.0

# Conditional density model to be tested against the data.
p = cden.CDGaussianOLS(slope, c=c, variance=noise_variance)

In [None]:
# Toy data from a conditional source (CondSource).
# H0 holds when the source parameters equal those of the model p above.
# Here slope and c are both shifted by 0.5, so this configuration is an H1 case.
# Alternative H1 with only the intercept shifted:
# cs = cdat.CSGaussianOLS(slope, c=c+0.8, variance=noise_variance)
cs = cdat.CSGaussianOLS(slope + 0.5, c=c + 0.5, variance=noise_variance)

n = 400  # sample size

# Draw X ~ N(0, 1) under a fixed seed; px is reused later to plot the density of X.
with util.TorchSeedContext(seed=17):
    px = dists.Normal(0, 1)
    X = px.sample((n, dx))

# Draw Y conditioned on X from the source.
Y = cs(X, seed=28)

Choose the kernels, then plot the data together with the model's conditional density.

In [None]:
# Gaussian kernels with initial bandwidths sigma2 (both are tuned later).
# k = kernel on X
# l = kernel on Y
# Both use float literals. The sigma2 parameters are later optimized by gradient
# steps and clamped in place, so they should have a floating-point dtype
# (assumes PTKGauss stores sigma2 as a tensor built from this value; TODO confirm).
k = ker.PTKGauss(sigma2=1.0)
l = ker.PTKGauss(sigma2=1.0)

In [None]:
# Evaluation grids that cover the observed data, with a margin ep on each side.
ep = 0.7
n_grid_x = 100  # grid resolution along X
n_grid_y = 200  # grid resolution along Y
# Both grids use Python-float endpoints via .item(). The original applied this
# only to Y; older PyTorch versions of torch.linspace reject tensor endpoints.
domX = torch.linspace(torch.min(X).item() - ep, torch.max(X).item() + ep, n_grid_x)
domY = torch.linspace(torch.min(Y).item() - ep, torch.max(Y).item() + ep, n_grid_y)

# FSCD power criterion built on the full sample (X, Y).
fscd_pc = cgof.FSCDPowerCriterion(p, k, l, X, Y)

# The witness is evaluated at each grid point (domX reshaped to n_grid_x x 1).
# The power criterion uses each grid point as its own set of test locations
# (n_grid_x x 1 x 1), presumably a batch of J=1 location sets. TODO confirm.
wit_values = fscd_pc.eval_witness(at=domX.unsqueeze(1))
pow_cri_values = fscd_pc.eval_power_criterion(at=domX.unsqueeze(1).unsqueeze(1))

In [None]:
# Plot the model's conditional density on the (domX, domY) grid together with
# the data. The lower panel (axes[1]) is used below for the power criterion.
fig, axes = plot.plot_2d_cond_model(
    p,
    lambda x: torch.exp(px.log_prob(x)),  # density of X (px generated X above)
    X,
    Y,
    domX=domX,
    domY=domY,
    height_ratios=[2, 1],
    cmap='pink_r',
    levels=50,
)

Add the FSCD power criterion, evaluated at each candidate test location on the grid, as the lower subplot.

In [None]:
# Overlay the FSCD power criterion on the lower panel and mark its maximizer,
# i.e. the best single test location v on the grid.
npdomX = domX.detach().numpy()
# Detach once. The criterion values may carry autograd history, and matplotlib
# cannot convert such a tensor (e.g. as the ymax of vlines) to numpy.
np_pow_cri = pow_cri_values.detach().numpy().ravel()

ax_pow = axes[1]
ax_pow.plot(npdomX, np_pow_cri, 'g-', label='Power Cri.')
ax_pow.legend(loc='lower left', ncol=1)

# Mark the highest point with a dashed vertical line labelled v.
imax = int(np.argmax(np_pow_cri))
v_best = npdomX[imax]
ax_pow.vlines(x=v_best, ymin=0, ymax=np_pow_cri[imax],
              linestyles='dashed', color='g')
# Place the 'v' label just below the axis, slightly left of the line.
ax_pow.annotate('$v$', (v_best, -0.02), xytext=(v_best - 0.1, -0.13))


In [None]:
# Resize the combined figure to 7 x 5 inches, tighten the spacing,
# save it as a PDF, and display it.
fig.set_size_inches(7, 5)
fig.tight_layout()
fig.savefig('lin_gauss_ls_powcri.pdf', bbox_inches='tight')
fig

Run the KSSD test on the full sample $(X, Y)$.

In [None]:
# KSSD test at significance level 0.05, using 400 bootstrap replicates
# and a fixed seed for reproducibility.
kssdtest = cgof.KSSDTest(
    p, k, l,
    alpha=0.05,
    n_bootstrap=400,
    seed=9,
)

In [None]:
# Run the test on the full sample and keep the bootstrap statistics,
# which the histogram below uses.
result = kssdtest.perform_test(
    X,
    Y,
    return_simulated_stats=True,
)
result

In [None]:
# Compare the observed KSSD statistic with its bootstrap distribution.
test_stat = float(result['test_stat'])
fig_kssd, ax_kssd = plt.subplots(figsize=(10, 6))
ax_kssd.hist(result['sim_stats'], density=True, label='Bootstrapped')
# Mark the observed statistic with a full-height line. Unlike a fixed-height
# stem, it stays visible whatever the density scale. Previously the statistic
# was not drawn at all.
ax_kssd.axvline(test_stat, color='r', label='Observed')
ax_kssd.set_xlabel('KSSD statistic')
ax_kssd.set_ylabel('density')
ax_kssd.legend()

print('H0 rejected?: {}'.format(result['h0_rejected']))
print('Observed stat: {:.3f}'.format(test_stat))

### FSCD Test

The Finite Set Conditional Discrepancy (FSCD) test compares the model and the data at a finite set $V$ of test locations.

In [None]:
# Test locations V as a J x dx tensor: here a single location (J = 1) at x = 1.0.
V = torch.tensor([[1.0]])
# FSCD test at significance level 0.05 with 400 bootstrap replicates (seeded).
fscdtest = cgof.FSCDTest(
    p, k, l, V,
    alpha=0.05,
    n_bootstrap=400,
    seed=10,
)

In [None]:
# Run the FSCD test on the full sample and keep the bootstrap statistics
# for the histogram below.
fscd_result = fscdtest.perform_test(
    X,
    Y,
    return_simulated_stats=True,
)

In [None]:
# Compare the observed FSCD statistic with its bootstrap distribution.
test_stat = float(fscd_result['test_stat'])
fig_fscd, ax_fscd = plt.subplots(figsize=(10, 6))
ax_fscd.hist(fscd_result['sim_stats'], density=True, label='Bootstrapped')
# Mark the observed statistic with a full-height line. Unlike a fixed-height
# stem, it stays visible whatever the density scale. Previously the statistic
# was not drawn at all.
ax_fscd.axvline(test_stat, color='r', label='Observed')
ax_fscd.set_xlabel('FSCD statistic')
ax_fscd.set_ylabel('density')
ax_fscd.legend()

print('H0 rejected?: {}'.format(fscd_result['h0_rejected']))
print('Observed stat: {:.3f}'.format(test_stat))

## Optimized KSSD test

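Tune the kernels $k$ and $l$ on a training split by maximizing the KSSD power criterion (a proxy for test power), then run the test with the tuned kernels on the held-out test split.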

In [None]:
# Split the data into training and test sets (50/50). Parameters are tuned on
# the training half; the held-out half `te` is used for the final test below.
# NOTE(review): check whether split_tr_te shuffles, and whether it is seeded.
tr, te = cdat.CondData(X, Y).split_tr_te(tr_proportion=0.5)
Xtr, Ytr = tr.xy()

In [None]:
# Tune the kernel bandwidths k.sigma2 and l.sigma2 by maximizing the KSSD
# power criterion on the training split.
Ytr.requires_grad = False
kssd_pc = cgof.KSSDPowerCriterion(p, k, l, Xtr, Ytr)

max_iter = 200  # number of optimization iterations
lr = 1e-3       # learning rate
reg = 1e-3      # regularization


def kssd_con_f(params):
    """Clamp the kernel bandwidths into [1e-2, 10] in place.

    Passed as ``constraint_f`` to ``optimize_params``. ``params[0]`` and
    ``params[1]`` are the bandwidths of k and l. The name differs from the
    FSCD constraint function so the two cells do not shadow each other.
    """
    ksigma2 = params[0]
    lsigma2 = params[1]
    ksigma2.data.clamp_(min=1e-2, max=10)
    lsigma2.data.clamp_(min=1e-2, max=10)


# NOTE(review): k.sigma2 and l.sigma2 appear to be modified in place (they are
# clamped in place above and displayed after this cell), so every later cell
# that uses k and l sees the tuned values.
objs = kssd_pc.optimize_params(
    [k.sigma2, l.sigma2], constraint_f=kssd_con_f,
    lr=lr, reg=reg, max_iter=max_iter)

In [None]:
# Power criterion recorded at each optimization iteration of the KSSD tuning.
np_objs = objs.detach().numpy()

plt.figure(figsize=(8, 5))
# Derive the x-axis from the returned values instead of the global max_iter.
# The FSCD cell below reassigns max_iter, and the lengths would not match if
# the optimizer returned a different number of values.
plt.plot(np.arange(len(np_objs)), np_objs, 'b-')
plt.xlabel('iteration')
plt.ylabel('Power criterion')
plt.title('KSSD kernel tuning');

In [None]:
# Bandwidth of kernel k (on X) after the KSSD tuning above.
k.sigma2

In [None]:
# Bandwidth of kernel l (on Y) after the KSSD tuning above.
l.sigma2

Run the KSSD test with the tuned kernels on the held-out test set.

In [None]:
# Build a fresh KSSD test that uses the tuned kernels k and l,
# then run it on the held-out test split only.
kssdtest = cgof.KSSDTest(
    p, k, l,
    alpha=0.05,
    n_bootstrap=400,
    seed=9,
)
Xte, Yte = te.xy()
kssdtest.perform_test(Xte, Yte)

## Optimized FSCD test


The FSCD test takes two kernels $k$ and $l$ and a set $V$ of test locations as input. We tune all of them on the training split by maximizing the FSCD power criterion (a proxy for test power).

In [None]:
# Tune the kernel bandwidths and the test locations V by maximizing the FSCD
# power criterion on the training split. This rebinds fscd_pc, replacing the
# full-sample criterion used for plotting earlier.
# NOTE(review): k and l are the same objects tuned by the KSSD cell, so this
# optimization presumably starts from the KSSD-tuned bandwidths. Recreate the
# kernels here if an independent starting point is intended.
# NOTE(review): V is the same tensor passed to fscdtest earlier and is clamped
# in place below, so that test object also sees the updated locations.
Ytr.requires_grad = False
fscd_pc = cgof.FSCDPowerCriterion(p, k, l, Xtr, Ytr)

max_iter = 200  # number of optimization iterations
lr = 1e-2       # learning rate
reg = 1e-3      # regularization


def fscd_con_f(params, V):
    """Clamp the bandwidths into [1e-2, 10] and V into [-5, 5], in place.

    Passed as ``constraint_f`` to ``optimize_params``. ``params[0]`` and
    ``params[1]`` are the bandwidths of k and l; ``V`` holds the test
    locations. The name differs from the KSSD constraint function so the two
    cells do not shadow each other.
    """
    ksigma2 = params[0]
    lsigma2 = params[1]
    ksigma2.data.clamp_(min=1e-2, max=10)
    lsigma2.data.clamp_(min=1e-2, max=10)
    V.data.clamp_(min=-5, max=5)


objs = fscd_pc.optimize_params(
    [k.sigma2, l.sigma2], V,
    constraint_f=fscd_con_f,
    lr=lr, reg=reg, max_iter=max_iter)

In [None]:
# Power criterion recorded at each optimization iteration of the FSCD tuning.
np_objs = objs.detach().numpy()

plt.figure(figsize=(8, 5))
# Derive the x-axis from the returned values rather than the global max_iter,
# so the plot stays correct even if the two do not match.
plt.plot(np.arange(len(np_objs)), np_objs, 'b-')
plt.xlabel('iteration')
plt.ylabel('Power criterion')
plt.title('FSCD kernel and location tuning');

In [None]:
# Check optimized V: the test locations after FSCD tuning
# (the constraint function clamps them into [-5, 5]).
V

In [None]:
# Bandwidth of kernel k (on X) after the FSCD tuning above.
k.sigma2

In [None]:
# Bandwidth of kernel l (on Y) after the FSCD tuning above.
l.sigma2