-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
statistics_hsic.py
156 lines (133 loc) · 5.78 KB
/
statistics_hsic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# This software is distributed under BSD 3-clause license (see LICENSE file).
#
# Authors: Heiko Strathmann
#
from numpy import *
from pylab import *
from scipy import *
from shogun import RealFeatures
from shogun import DataGenerator
from shogun import GaussianKernel
from shogun import HSIC
from shogun import PERMUTATION, HSIC_GAMMA
from shogun import EuclideanDistance
from shogun import Statistics, Math
# for nice plotting that fits into our shogun tutorial
import latex_plot_inits
def _median_kernel_width(features, subset):
    """Return the squared median pairwise Euclidean distance on a subset.

    The index subset is added to ``features`` only while the distance matrix
    is computed and is removed again before the median is taken.
    """
    features.add_subset(subset)
    dist = EuclideanDistance(features, features)
    distances = dist.get_distance_matrix()
    features.remove_subset()
    return median(distances) ** 2


def _reduce_ticks(nbins):
    """Install a ``MaxNLocator(nbins)`` on both axes of the current subplot."""
    axes = gca()
    axes.xaxis.set_major_locator(MaxNLocator(nbins=nbins))  # fewer x-ticks
    axes.yaxis.set_major_locator(MaxNLocator(nbins=nbins))  # fewer y-ticks


def _upper_threshold(sorted_samples, alpha):
    """Return the element at position floor(n * (1 - alpha)) of sorted samples.

    ``sorted_samples`` must already be in ascending order.
    """
    # floor() yields a float; convert explicitly because numpy arrays only
    # accept integer indices
    return sorted_samples[int(floor(len(sorted_samples) * (1 - alpha)))]


def hsic_graphical():
    """Plot data, alternative and null distributions of an HSIC test.

    Generates data with ``DataGenerator.generate_sym_mix_gauss``, builds
    Gaussian kernels whose widths come from the median pairwise distance,
    samples the HSIC statistic on freshly generated data (alternative
    distribution), samples the null distribution by permutation, fits a
    gamma null approximation, and draws everything into a 2x2 figure.

    Takes no arguments and returns nothing; the figure is left for the
    caller to display.
    """
    # parameters, change to get different results
    m = 250
    difference = 3
    # setting the angle lower makes a harder test
    angle = pi / 30
    # number of samples taken from null and alternative distribution
    num_null_samples = 500

    # use data generator class to produce example data
    data = DataGenerator.generate_sym_mix_gauss(m, difference, angle)

    # create shogun feature representation
    features_x = RealFeatures(array([data[0]]))
    features_y = RealFeatures(array([data[1]]))

    # Compute the median data distance in order to use it for the Gaussian
    # kernel width: 0.5*median_distance normally (factor two in the Gaussian
    # kernel). However, shogun's kernel width is parametrized differently,
    # therefore 0.5*2*median_distance^2.
    # Use a random subset of at most 200 elements for that; the median is
    # stable.
    all_indices = arange(features_x.get_num_vectors(), dtype=int32)
    subset = random.permutation(all_indices)[:200]
    sigma_x = _median_kernel_width(features_x, subset)
    sigma_y = _median_kernel_width(features_y, subset)
    # single-argument print calls behave the same on Python 2 and Python 3
    print("median distance for Gaussian kernel on x: %s" % sigma_x)
    print("median distance for Gaussian kernel on y: %s" % sigma_y)
    # NOTE(review): the first argument (10) is presumably the kernel cache
    # size -- confirm against the shogun GaussianKernel constructor
    kernel_x = GaussianKernel(10, sigma_x)
    kernel_y = GaussianKernel(10, sigma_y)

    # Create the hsic instance. Note that this is a convenience constructor
    # which copies feature data: features_x and features_y are not the ones
    # used inside hsic. This is only for user-friendliness and is usually ok.
    # Below, the alternative distribution is sampled, which means that a new
    # hsic instance has to be created in each iteration (slow). Normally,
    # however, the alternative distribution is not sampled.
    hsic = HSIC(kernel_x, kernel_y, features_x, features_y)

    # sample alternative distribution
    alt_samples = zeros(num_null_samples)
    for i in range(num_null_samples):
        data = DataGenerator.generate_sym_mix_gauss(m, difference, angle)
        features_x.set_feature_matrix(array([data[0]]))
        features_y.set_feature_matrix(array([data[1]]))

        # re-create the hsic instance every time since feature objects are
        # copied due to usage of the convenience constructor
        hsic = HSIC(kernel_x, kernel_y, features_x, features_y)
        alt_samples[i] = hsic.compute_statistic()

    # sample from null distribution
    # permutation, biased statistic
    hsic.set_null_approximation_method(PERMUTATION)
    hsic.set_num_null_samples(num_null_samples)
    null_samples_boot = hsic.sample_null()

    # fit gamma distribution, biased statistic
    hsic.set_null_approximation_method(HSIC_GAMMA)
    gamma_params = hsic.fit_null_gamma()
    # sample gamma with the fitted parameters
    # NOTE(review): `gamma` comes from one of the star imports, presumably
    # numpy.random.gamma(shape, scale) via pylab -- confirm
    null_samples_gamma = array([gamma(gamma_params[0], gamma_params[1])
                                for _ in range(num_null_samples)])

    # plot
    figure()

    # plot data x and y (these are the data of the last loop iteration)
    subplot(2, 2, 1)
    _reduce_ticks(4)
    grid(True)
    plot(data[0], data[1], 'o')
    # raw string: '\p' is not a valid escape in a normal string literal
    title(r'Data, rotation=$\pi$/' + str(1 / angle * pi) + '\nm=' + str(m))
    xlabel('$x$')
    ylabel('$y$')

    # compute threshold for test level
    alpha = 0.05
    null_samples_boot.sort()
    null_samples_gamma.sort()
    thresh_boot = _upper_threshold(null_samples_boot, alpha)
    thresh_gamma = _upper_threshold(null_samples_gamma, alpha)

    # type I error: fraction of null samples lying above the threshold,
    # i.e. samples for which the test would reject
    type_one_error_boot = (sum(null_samples_boot > thresh_boot)
                           / float(num_null_samples))
    # NOTE(review): the gamma samples are compared against the permutation
    # threshold, exactly as before this change; confirm whether thresh_gamma
    # was intended here
    type_one_error_gamma = (sum(null_samples_gamma > thresh_boot)
                            / float(num_null_samples))

    # plot alternative distribution with threshold
    subplot(2, 2, 2)
    _reduce_ticks(3)
    grid(True)
    # `density` replaces the `normed` keyword of hist
    hist(alt_samples, 20, density=True)
    axvline(thresh_boot, 0, 1, linewidth=2, color='red')
    # type II error: fraction of alternative samples below the threshold
    type_two_error = sum(alt_samples < thresh_boot) / float(num_null_samples)
    title('Alternative Dist.\n' + 'Type II error is ' + str(type_two_error))

    # compute a common range for all null distribution histograms
    hist_range = [
        min([null_samples_boot.min(), null_samples_gamma.min()]),
        max([null_samples_boot.max(), null_samples_gamma.max()]),
    ]

    # plot null distribution with threshold
    subplot(2, 2, 3)
    _reduce_ticks(3)
    grid(True)
    hist(null_samples_boot, 20, range=hist_range, density=True)
    axvline(thresh_boot, 0, 1, linewidth=2, color='red')
    title('Sampled Null Dist.\n' + 'Type I error is '
          + str(type_one_error_boot))

    # plot null distribution gamma
    subplot(2, 2, 4)
    _reduce_ticks(3)
    grid(True)
    hist(null_samples_gamma, 20, range=hist_range, density=True)
    axvline(thresh_gamma, 0, 1, linewidth=2, color='red')
    title('Null Dist. Gamma\nType I error is ' + str(type_one_error_gamma))

    # pull plots a bit apart
    subplots_adjust(hspace=0.5, wspace=0.5)
if __name__ == '__main__':
    # script entry point: build the HSIC figure, then display it
    hsic_graphical()
    show()