-
Notifications
You must be signed in to change notification settings - Fork 149
/
test_bandwidth_estimation.py
executable file
·117 lines (79 loc) · 3.2 KB
/
test_bandwidth_estimation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#! /usr/bin/env python
# Copyright 2013 Tom SF Haines
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import numpy
import numpy.random
import cv
from utils.cvarray import *
from utils.prog_bar import ProgBar
from ms import MeanShift
# Sample-count unit: the circle component draws 8x and the line component 4x this many points.
base_samples = 64
# Multiplier on the circle radius and on the line's x extent.
scale, size = 1.0, 5.0  # size: rendered region is the square [-size, size] x [-size, size].
# Toggles for which components of the synthetic distribution are generated.
circle = line = True
# Sample from a circle + noise model to create some data...
# Radius is scale*3.0 perturbed by beta(5, 5) - 0.5 noise, i.e. within +/-0.5 of scale*3.0.
if circle:
  samples = 8 * base_samples
  theta = 2.0 * numpy.pi * numpy.random.random(samples)
  radius = scale*3.0 + (numpy.random.beta(5.0, 5.0, samples)-0.5)
  x = radius * numpy.cos(theta)
  y = radius * numpy.sin(theta)
  data1 = numpy.column_stack((x, y))  # (samples, 2) array of (x, y) rows.

# More data - from a line...
# x spans scale * [-4.5, 4.5] (beta(3, 3) shaped); y is N(0, 0.2) noise around the axis.
if line:
  samples = 4 * base_samples
  x = scale * (numpy.random.beta(3.0, 3.0, samples)*9.0 - 4.5)
  y = numpy.random.normal(scale=0.2, size=samples)
  data2 = numpy.column_stack((x, y))

# Munge it all together...
if circle and line:
  data = numpy.concatenate((data1, data2), axis=0)
elif circle:
  data = data1
elif line:
  data = data2
else:
  # Previously this fell through to data2, which is undefined here, giving an obscure NameError.
  raise ValueError('At least one of circle or line must be enabled to generate data')
numpy.random.shuffle(data)  # Shuffles rows in place, so the two components are interleaved.
# Visualise the samples...
# Each sample is drawn as a single white pixel in a dim x dim image covering [-size, size]^2.
dim = 512
image = numpy.zeros((dim, dim, 3), dtype=numpy.float32)
# Map each (x, y) sample to its nearest pixel; floor(v + 0.5) rounds correctly for negative v too.
pixel = numpy.floor((data + size) / (2.0*size) * dim + 0.5).astype(int)
# Drop samples outside the image explicitly. Catching IndexError is not enough here:
# a negative index would silently wrap around to the opposite edge instead of raising.
inside = numpy.all((pixel >= 0) & (pixel < dim), axis=1)
image[pixel[inside, 1], pixel[inside, 0], :] = 255.0  # Row index is y, column index is x.
image = array2cv(image)
cv.SaveImage('bandwidth_samples.png', image)
# Set up the mean shift object over the combined, shuffled samples in data...
ms = MeanShift()
# NOTE(review): 'df' presumably describes how each column of data is interpreted
# (e.g. data vs feature dimensions) - confirm against the MeanShift.set_data docs.
ms.set_data(data, 'df')
ms.set_kernel('gaussian')
# NOTE(review): presumably selects a kd-tree index for spatial/neighbour queries - confirm.
ms.set_spatial('kd_tree')
# Wrapper around ms.scale_loo_nll that reports progress on a console progress bar...
def scale_loo_nll():
  """Run ms.scale_loo_nll, forwarding its progress callbacks to a fresh ProgBar."""
  bar = ProgBar()
  ms.scale_loo_nll(callback=bar.callback)
  # Release the bar explicitly once the optimisation has finished.
  del bar
# Iterate and try out a bunch of different algorithms...
for name, alg in [('human_picked', lambda: ms.set_scale(numpy.array([5.0, 5.0]))), ('Silverman',ms.scale_silverman), ('Scott', ms.scale_scott), ('loo_nll', scale_loo_nll)]:
# Calculate and print out the scales...
print '<', name, '>'
alg()
print 'Scale:', ms.get_scale()
print 'loo nll for this scale =', ms.loo_nll()
mean, sd = ms.stats()
print 'mean = (%f, %f); sd = (%f, %f)'%(mean[0], mean[1], sd[0], sd[1])
# Render out a normalised probability map...
image = numpy.zeros((dim, dim, 3), dtype=numpy.float32)
p = ProgBar()
for row in xrange(dim):
p.callback(row, dim)
sam = numpy.append(numpy.linspace(-size, size, dim).reshape((-1,1)), ((row / (dim-1.0) - 0.5) * 2.0 * size) * numpy.ones(dim).reshape((-1,1)), axis=1)
image[row, :, :] = ms.probs(sam).reshape((-1,1))
del p
print 'Largest sampled probability =', image.max()
image *= 255.0 / image.max()
image = array2cv(image)
cv.SaveImage('bandwidth_%s.png'%name, image)
print