/
main.py
177 lines (119 loc) · 4.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import math
import xml.etree.ElementTree as ET
import pprint
import sys
# ElementTree reports namespaced tags in "{namespace-uri}localname" form, so
# SVG elements must be looked up with the SVG namespace prefixed.
_SVG_NAMESPACE = '{http://www.w3.org/2000/svg}'
CIRCLE_TAG_NAME = _SVG_NAMESPACE + 'circle'
GROUP_TAG_NAME = _SVG_NAMESPACE + 'g'

# Shared pretty-printer with 4-space indentation (not referenced by the code
# shown here; presumably kept for ad-hoc debugging).
pp = pprint.PrettyPrinter(indent=4)
def circle_to_point(circle):
    """Return the centre of an SVG <circle> element as an (x, y) float tuple.

    Raises KeyError if the element has no 'cx' or 'cy' attribute, and
    ValueError if either attribute is not parseable as a float.
    """
    attributes = circle.attrib
    x = float(attributes['cx'])
    y = float(attributes['cy'])
    return x, y
def read_svg_file(svg_file_name):
    """Parse an SVG document and return the resulting ElementTree.

    svg_file_name is handed directly to xml.etree.ElementTree.parse, so it may
    be a filesystem path or an open file object.
    """
    tree = ET.parse(svg_file_name)
    return tree
def get_all_points(tree):
    """Return the (x, y) centre of every SVG circle under `tree`, in document order.

    `tree` may be an ElementTree or a single Element (get_group_by_id passes a
    <g> element); anything with ElementTree's iter(tag) method works. The
    search covers the whole subtree, not just direct children.
    """
    circles = tree.iter(CIRCLE_TAG_NAME)
    return list(map(circle_to_point, circles))
def get_point_by_id(tree, point_id):
    """Return the centres of all circles whose 'id' attribute equals point_id.

    The result is a list in document order, which is empty if no circle
    matches. Circles without an 'id' attribute never match.
    """
    matching_circles = (
        circle
        for circle in tree.iter(CIRCLE_TAG_NAME)
        if 'id' in circle.attrib and circle.attrib['id'] == point_id
    )
    return [circle_to_point(circle) for circle in matching_circles]
def get_group_by_id(tree, group_id):
    """Collect the circle centres inside every SVG <g> whose 'id' equals group_id.

    Each matching group contributes all circles beneath it, including those
    in nested sub-groups, because get_all_points searches the whole subtree.
    If matching groups are themselves nested, inner circles are collected more
    than once. Points are returned as one flat list in document order; it is
    empty when no group matches or the groups contain no circles.
    """
    points = []
    for group in tree.iter(GROUP_TAG_NAME):
        if 'id' in group.attrib and group.attrib['id'] == group_id:
            points.extend(get_all_points(group))
    return points
def distance_squared(point1, point2):
    """Return the squared Euclidean distance between two 2-D points.

    The square root is skipped because squared distances order points the
    same way as true distances. Each point must unpack to exactly two values.
    """
    (ax, ay), (bx, by) = point1, point2
    delta_x = ax - bx
    delta_y = ay - by
    return delta_x * delta_x + delta_y * delta_y
def closest_point(all_points, new_point):
    """Return the point in all_points nearest to new_point (brute-force scan).

    This is a linear-time reference implementation, useful as a baseline for
    the k-d tree search. On ties the earliest point in iteration order wins.
    Returns None when all_points is empty. all_points may be any iterable.
    """
    # Builtin min() replaces the hand-rolled argmin loop. It keeps the first
    # minimum (it only replaces on a strictly smaller key), and default=None
    # covers the empty case, matching the previous behaviour exactly.
    return min(all_points,
               key=lambda candidate: distance_squared(new_point, candidate),
               default=None)
# Number of coordinates per point (x, y); the splitting axis cycles through
# 0 .. k-1 with tree depth.
k = 2


def build_kdtree(points, depth=0):
    """Build a balanced k-d tree from a sequence of points.

    Each node is a dict with keys 'point', 'left' and 'right'; an empty
    subtree is None, and an empty input returns None. At a given depth the
    points are sorted along axis `depth % k`. The median becomes the node
    (the upper median when the count is even). Points before it in sorted
    order form the left subtree and points after it form the right subtree.
    """
    count = len(points)
    if count <= 0:
        return None

    axis = depth % k
    ordered = sorted(points, key=lambda p: p[axis])
    middle = count // 2

    node = {}
    node['point'] = ordered[middle]
    node['left'] = build_kdtree(ordered[:middle], depth + 1)
    node['right'] = build_kdtree(ordered[middle + 1:], depth + 1)
    return node
def kdtree_naive_closest_point(root, point, depth=0, best=None):
    """Greedy k-d tree search that follows a single root-to-leaf path.

    At each node the node's point replaces `best` if it is strictly closer to
    `point`. The walk then descends only into the child on the query's side of
    the splitting plane. The opposite subtree is never examined, so the result
    may not be the true nearest neighbour (kdtree_closest_point backtracks and
    is exact). If root is None, `best` is returned unchanged.
    """
    node = root
    level = depth
    while node is not None:
        candidate = node['point']
        if best is None or distance_squared(point, best) > distance_squared(point, candidate):
            best = candidate
        axis = level % k
        node = node['left'] if point[axis] < candidate[axis] else node['right']
        level += 1
    return best
def closer_distance(pivot, p1, p2):
    """Return whichever of p1 and p2 is nearer to pivot.

    Despite the name, the result is a point, not a distance. None means "no
    candidate": if one argument is None the other is returned, so the result
    is None only when both are None. When the distances are equal, p2 wins.
    """
    if p1 is None or p2 is None:
        return p2 if p1 is None else p1
    return p1 if distance_squared(pivot, p1) < distance_squared(pivot, p2) else p2
def kdtree_closest_point(root, point, depth=0):
    """Return the point in the k-d tree nearest to `point`, or None if the tree is empty.

    The search first recurses into the subtree on the query's side of this
    node's splitting plane and keeps the closer of that result and the node
    itself. It then visits the other subtree only when the best squared
    distance so far exceeds the squared distance from the query to the plane.
    Otherwise nothing on the far side can be strictly closer.
    """
    if root is None:
        return None

    axis = depth % k
    split = root['point']
    if point[axis] < split[axis]:
        near, far = root['left'], root['right']
    else:
        near, far = root['right'], root['left']

    best = closer_distance(point, kdtree_closest_point(near, point, depth + 1), split)

    plane_gap = point[axis] - split[axis]
    if distance_squared(point, best) > plane_gap ** 2:
        best = closer_distance(point, kdtree_closest_point(far, point, depth + 1), best)
    return best
# Regression check. For each SVG fixture, the k-d tree search must find a
# point at least as close to the circle with id 'pivot' as the circle with
# id 'closest'. Exits with status 1 on the first failure.
svg_files = ['./points.svg', './points2.svg']
for svg_file in svg_files:
    print(svg_file)
    svg_tree = read_svg_file(svg_file)
    # Each fixture must contain exactly one circle with each of these ids;
    # the single-element unpacking raises ValueError otherwise.
    [pivot] = get_point_by_id(svg_tree, 'pivot')
    [expected] = get_point_by_id(svg_tree, 'closest')
    points = get_group_by_id(svg_tree, 'points')
    kdtree = build_kdtree(points)
    found = kdtree_closest_point(kdtree, pivot)
    if found is None:
        # A missing or empty 'points' group gives an empty tree. Report it
        # clearly instead of crashing inside distance_squared(pivot, None).
        print(" ----- FAILURE! NO POINTS FOUND IN GROUP 'points' -----")
        sys.exit(1)
    expected_distance = math.sqrt(distance_squared(pivot, expected))
    found_distance = math.sqrt(distance_squared(pivot, found))
    print(" Expected: %s (distance: %f)" % (expected, expected_distance))
    print(" Found: %s (distance: %f)" % (found, found_distance))
    if found_distance > expected_distance:
        print(" ----- FAILURE! FOUND WORSE DISTANCE! -----")
        sys.exit(1)