A notebook to check that the wild bootstrap procedure correctly simulates the null distributions of the KSSD and FSCD test statistics.

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the notebook output.
%matplotlib inline
# Optional inline figure formats (disabled):
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import numpy as np
import torch
import torch.distributions as dists

import kcgof
import kcgof.log as klog
import kcgof.util as util
import kcgof.cdensity as cden
import kcgof.cdata as cdat
import kcgof.cgoftest as cgof
import kcgof.kernel as ker
import kcgof.plot as plot

import scipy.stats as stats

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Global plot styling: font settings shared by all figures in this notebook.
font = dict(
    # family='normal',
    # weight='bold',
    size=20,
)
plt.rc('font', **font)
plt.rc('lines', linewidth=2)

# Font type 42 for both the PDF and PS backends.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

#         'gaussls_h0_d5' from ex1_vary_n.py

In [None]:
# Dimension of X.
dx = 5
# Slope vector with entries 1, 2, ..., dx.
slope_h0_d5 = torch.arange(dx) + 1.0
# p: the conditional model passed to the test below.
p = cden.CDGaussianOLS(slope=slope_h0_d5, c=0, variance=1.0)
# rx: the distribution from which X is sampled.
rx = cden.RXIsotropicGaussian(dx=dx)
# cs: conditional source that generates Y given X. It is constructed with the
# same slope, c and variance as p (presumably so that the data follow p,
# i.e. H0 holds -- confirm).
cs = cdat.CSGaussianOLS(slope=slope_h0_d5, c=0, variance=1.0)

In [None]:
# Draw the dataset (X, Y) used by the test below.
# X is sampled from rx inside a TorchSeedContext with seed 18 (presumably to
# make the draw repeatable); Y is then generated from cs given X, with its
# own seed.
n = 800 # sample size
with util.TorchSeedContext(seed=18):    
    X = rx.sample(n)
Y = cs(X, seed=29)

KSSD test

In [None]:
# Kernels used by the test:
#   k = kernel on X
#   l = kernel on Y

# Distance-based bandwidth candidates computed from (a subsample of) X and Y
# (presumably the median pairwise distance -- confirm in util.pt_meddistance).
# They are computed here but NOT used by the kernels below.
sigx = util.pt_meddistance(X, subsample=600, seed=3)
sigy = util.pt_meddistance(Y, subsample=600, seed=38)

# Fixed-bandwidth kernels. Data-driven alternatives (disabled):
#   k = ker.PTKGauss(sigma2=sigx**2)
#   l = ker.PTKGauss(sigma2=sigy**2)
# Earlier fixed choices (disabled): sigma2=2 for k, sigma2=1.0 for l.
k = ker.PTKGauss(sigma2=2**2)
l = ker.PTKGauss(sigma2=4**2)

In [None]:
# Construct a KSSD test object for model p with kernels k (on X) and l (on Y),
# using n_bootstrap bootstrap draws and level alpha=0.05.
n_bootstrap = 1000
kssdtest = cgof.KSSDTest(p, k, l, alpha=0.05, n_bootstrap=n_bootstrap, seed=9)
# return_simulated_stats=True: later cells read 'sim_stats' from the result,
# alongside 'test_stat' and 'h0_rejected'.
result = kssdtest.perform_test(X, Y, return_simulated_stats=True)
# Display the result.
result

In [None]:
def redraw_for_histogram(kssd, p, rx, cs, n_run, n, seed_x=1700, seed_y=378):
    '''
    Repeatedly draw fresh samples n_run times and compute the test statistic
    on each draw. Intended to be used for checking the bootstrapping
    procedure: the returned statistics can be compared with the bootstrapped
    ones.

    kssd: test object providing compute_stat(X, Y)
    p: the model. Not used by this function; kept so that existing callers
        keep working.
    rx: object providing sample(n), used to draw X
    cs: callable cs(X, seed=...) that generates Y given X
    n_run: number of redraws
    n: sample size to draw each time
    seed_x: base seed for drawing X; run t uses seed_x + t
    seed_y: base seed for drawing Y; run t uses seed_y + t

    Return a numpy array of length n_run containing the statistics.
    '''
    sts = np.zeros(n_run)
    for t in range(n_run):
        # A different seed per run gives a different dataset per run, while
        # the whole sequence stays reproducible.
        with util.TorchSeedContext(seed=seed_x + t):
            X = rx.sample(n)
        Y = cs(X, seed=seed_y + t)
        # compute_stat must return a scalar that can be stored in a float array.
        sts[t] = kssd.compute_stat(X, Y)
    return sts
        

In [None]:
# Number of independent redraws; each redraw uses a fresh sample of size n
# and yields one value of the statistic.
n_redraw = 1000
redraw_stats = redraw_for_histogram(kssdtest, p, rx, cs, n_run=n_redraw, n=n)

In [None]:
# Overlay two histograms of the statistic: the bootstrapped values from the
# test result and the values obtained by redrawing fresh samples.
test_stat = result['test_stat']

plt.figure(figsize=(10, 6))
plt.hist(
    result['sim_stats'],
    bins=20, density=True, alpha=0.5, label='Bootstrapped',
)
plt.hist(
    redraw_stats,
    bins=20, density=True, alpha=0.5, label='Redraw',
)
# The observed statistic (test_stat) is not drawn on the figure; it is only
# reported in the text output below.
plt.xlabel('KSSD statistic')
plt.legend()

# Text summary of the test outcome.
print('H0 rejected?: {}'.format(result['h0_rejected']))
print('Observed stat: {:.3f}'.format(test_stat))
print('n = {}'.format(n))

QQ plot

In [None]:
import statsmodels
import statsmodels.api as sm
import statsmodels.graphics.gofplots

In [None]:
# Two-sample QQ plot of the bootstrapped statistics against the redrawn ones.
# Reference:
# https://www.statsmodels.org/devel/generated/statsmodels.graphics.gofplots.qqplot_2samples.html
boot_stats = result['sim_stats']
# NOTE(review): which of data1/data2 ends up on the x-axis may depend on the
# installed statsmodels version -- confirm that the axis labels below match
# the samples actually plotted on each axis.
fig = statsmodels.graphics.gofplots.qqplot_2samples(
    boot_stats,
    redraw_stats,
    xlabel='Redraw',
    ylabel='Bootstrap',
    line='45',
)
plt.axis('square')
plt.grid()

# Settings used to produce the plot.
print('n = {}'.format(n))
print('n_bootstrap = {}'.format(n_bootstrap))
print('n_redraw = {}'.format(n_redraw))

# Compare the 95% empirical quantiles of the two samples.
quan_boot, quan_redraw = (
    np.quantile(sample, 0.95) for sample in (boot_stats, redraw_stats)
)
print('95% quantile of bootstrapped distribution: {:.4}'.format(quan_boot))
print('95% quantile of distribution from redrawing: {:.4}'.format(quan_redraw))


In [None]:
# Summary statistics of the bootstrapped statistics (scipy.stats.describe).
stats.describe(boot_stats)

In [None]:
# Sanity check of qqplot_2samples on two synthetic samples of different sizes:
# sam1 is standard normal, sam2 is standard normal scaled by 2.
# A dedicated seeded generator keeps this cell (and the quantiles computed
# from sam1/sam2 afterwards) reproducible across runs without touching
# numpy's global random state.
sanity_rng = np.random.RandomState(seed=2020)
sam1 = sanity_rng.randn(500)
sam2 = sanity_rng.randn(700)*2
statsmodels.graphics.gofplots.qqplot_2samples(
    sam1, sam2, line='45',
);
plt.grid()

In [None]:
# 95% empirical quantile of sam1.
np.quantile(sam1, 0.95)

In [None]:
# 95% empirical quantile of sam2.
np.quantile(sam2, 0.95)