-
Notifications
You must be signed in to change notification settings - Fork 1
/
astarpathfinding.py
182 lines (159 loc) · 4.8 KB
/
astarpathfinding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/usr/bin/python
from collections import defaultdict
from heapq import *
from math import sqrt
from copy import deepcopy
import pdb
def build_map(desc):
    """Return the rows of the text map ``desc``, one string per line.

    The description is split on newline characters, exactly as ``Map``
    does when it parses a map.
    """
    return desc.split('\n')
def vecadd(a, b):
    """Return the component-wise sum of the 2-D vectors ``a`` and ``b``.

    Both arguments must be 2-element sequences; anything else raises
    ValueError while unpacking.
    """
    # Unpack in the body: tuple parameters were removed in Python 3.
    x1, y1 = a
    x2, y2 = b
    return (x1 + x2, y1 + y2)
class Node(object):
    """A search node: a grid position plus its A* bookkeeping.

    Attributes:
        pos: the cell this node stands for.
        parent: the ``Node`` this one was reached from (``None`` for the start).
        dtarget: heuristic estimate of the remaining cost to the target.
        dsrc: cost of the path from the start to ``pos``; assigning it also
            recomputes ``hvalue``.
        hvalue: ``dsrc + dtarget``, the key nodes are ordered by.
    """

    def __init__(self, pos, parent, dsrc, dtarget):
        self.pos = pos
        self.parent = parent
        # dtarget must be set before dsrc: the dsrc setter reads it.
        self.dtarget = dtarget
        self.dsrc = dsrc

    def set_dsrc(self, value):
        self._dsrc = value
        self.hvalue = value + self.dtarget

    def get_dsrc(self):
        return self._dsrc

    dsrc = property(get_dsrc, set_dsrc)

    def __lt__(self, node):
        # heapq and sorting order objects with ``<``; Python 3 ignores
        # ``__cmp__`` entirely, so this is required there.
        return self.hvalue < node.hvalue

    def __cmp__(self, node):
        # Python 2 fallback for the remaining comparisons. Return -1/0/1
        # rather than the raw difference, which Python 2 would truncate to
        # an int (making e.g. 0.5 and 1 compare equal).
        return (self.hvalue > node.hvalue) - (self.hvalue < node.hvalue)
def findnode(l, pos):
    """Look up a node by position.

    Scans ``l`` in order and returns the first element whose ``pos``
    attribute equals ``pos``; returns ``None`` when nothing matches.
    """
    matches = (candidate for candidate in l if candidate.pos == pos)
    return next(matches, None)
class Map(object):
    """A text grid together with an A* path finder over it.

    Each character of ``desc`` is one cell at ``(column, row)``: ``'1'`` is a
    wall and any other character is walkable. ``'2'`` marks the start cell
    and ``'3'`` the target cell. Cells outside the described text count as
    walls.

    Moves go to the 8 neighbouring cells. Step costs use the usual integer
    approximation of octile distance: ``STRAIGHT_COST`` for an orthogonal
    step and ``DIAGONAL_COST`` (about ``STRAIGHT_COST * sqrt(2)``) for a
    diagonal one. ``calc_dtarget`` uses the same costs, so with the default
    values the heuristic equals the obstacle-free cost and never
    overestimates the real one.
    """

    STRAIGHT_COST = 10
    DIAGONAL_COST = 14

    # (dx, dy) offsets of the 8 neighbouring cells, clockwise from top-left
    # (rows grow downward).
    DIRECTIONS = (
        (-1, -1), (0, -1), (1, -1), (1, 0),
        (1, 1), (0, 1), (-1, 1), (-1, 0),
    )

    def __init__(self, desc):
        self.text_map = []
        # Unknown positions default to 1, i.e. they behave as walls.
        self.map = defaultdict(lambda: 1)
        for y, line in enumerate(desc.split('\n')):
            self.text_map.append(list(line))
            for x, c in enumerate(line):
                self.map[(x, y)] = 1 if c == '1' else 0
                if c == '2':
                    self.start = (x, y)
                elif c == '3':
                    self.target = (x, y)
        # Heuristic cost to the target for every walkable cell.
        self.dtarget = {}
        self.calc_all_dtarget()
        # Separate copy of the grid for marking cells while debugging.
        self.debug_map = deepcopy(self.text_map)

    def print_lines(self, map_):
        """Print a grid (a list of character lists) one row per line."""
        print('\n'.join(''.join(line) for line in map_))

    def calc_dtarget(self, pos):
        """Return the octile-distance estimate from ``pos`` to the target.

        With ``dmax``/``dmin`` the larger/smaller axis distance, the cheapest
        obstacle-free route takes ``dmin`` diagonal steps and ``dmax - dmin``
        straight steps.
        """
        x, y = pos
        tx, ty = self.target
        dx, dy = abs(x - tx), abs(y - ty)
        dmax, dmin = max(dx, dy), min(dx, dy)
        return (dmax - dmin) * self.STRAIGHT_COST + dmin * self.DIAGONAL_COST

    def calc_all_dtarget(self):
        """Fill ``self.dtarget`` with the heuristic for every walkable cell."""
        for pos, value in self.map.items():
            if value == 0:
                self.dtarget[pos] = self.calc_dtarget(pos)

    def children(self, node):
        """Yield ``(childpos, dsrc)`` for each walkable neighbour of ``node``.

        ``dsrc`` is ``node.dsrc`` plus the step cost: ``STRAIGHT_COST`` for an
        orthogonal move, ``DIAGONAL_COST`` for a diagonal one. Only the
        destination cell is checked, so a diagonal move may pass between two
        walls.
        """
        x, y = node.pos
        for dx, dy in self.DIRECTIONS:
            childpos = (x + dx, y + dy)
            # .get() so probing outside the grid does not insert entries
            # into the defaultdict.
            if self.map.get(childpos, 1) == 0:
                step = self.DIAGONAL_COST if dx and dy else self.STRAIGHT_COST
                yield childpos, node.dsrc + step

    def findpath(self):
        """Run A* from ``self.start`` to ``self.target`` (decrease-key form).

        Open nodes live in a binary heap ordered by ``Node`` comparison
        (``hvalue`` = cost so far + heuristic). An index by position finds
        an already-open cell in O(1); when a cheaper route to it appears,
        the node is updated in place and the heap is re-heapified.

        Returns:
            The ``Node`` for the target cell (follow ``.parent`` back to the
            start), or ``None`` if the target is unreachable.
        """
        startnode = Node(self.start, None, 0, 0)
        target = self.target
        openset = [startnode]                      # heap of Node
        open_by_pos = {startnode.pos: startnode}   # position -> open Node
        closeset = set()                           # positions already expanded
        while openset:
            node = heappop(openset)
            del open_by_pos[node.pos]
            if node.pos == target:
                return node
            closeset.add(node.pos)
            for childpos, dsrc in self.children(node):
                if childpos in closeset:
                    continue
                opennode = open_by_pos.get(childpos)
                if opennode is None:
                    child = Node(childpos, node, dsrc, self.dtarget[childpos])
                    open_by_pos[childpos] = child
                    heappush(openset, child)
                elif dsrc < opennode.dsrc:
                    # Cheaper route to a cell that is already open: update
                    # it in place and restore the heap invariant.
                    opennode.dsrc = dsrc
                    opennode.parent = node
                    heapify(openset)
        return None

    def findpath1(self):
        """Run A* from ``self.start`` to ``self.target`` (lazy-deletion form).

        Every improved route simply pushes a new ``Node``, so a cell may sit
        in the heap several times; stale copies are skipped when popped
        because their position is already closed.

        Returns:
            The ``Node`` for the target cell, or ``None`` if unreachable.
        """
        startnode = Node(self.start, None, 0, 0)
        target = self.target
        openset = [startnode]   # heap of Node, may contain duplicates
        closeset = set()        # positions already expanded
        while openset:
            node = heappop(openset)
            pos = node.pos
            if pos in closeset:
                continue        # stale duplicate of an expanded cell
            if pos == target:
                return node
            closeset.add(pos)
            for childpos, dsrc in self.children(node):
                if childpos not in closeset:
                    heappush(openset,
                             Node(childpos, node, dsrc, self.dtarget[childpos]))
        return None
if __name__ == '__main__':
    # Demo map: '1' = wall, '2' = start, '3' = target, other chars walkable.
    graph_desc = '''
0000000000000000000100000
0111111111111111110103010
0100000000000000000111110
0101111111111111111100010
0101000000000000000100010
0101000001000110111100010
0101000001012100010000010
0101000001011100010000010
0101000001000001010000010
0100000001111111010000010
0111111111111111111111110
0000000000000000000000000
'''
    grid = Map(graph_desc)
    # Repeat the search as a rough benchmark; only the last result is kept.
    for _ in range(1000):
        node = grid.findpath()
    # Walk back from the target to the start, marking each cell on the path.
    while node is not None:
        col, row = node.pos
        grid.text_map[row][col] = '*'
        node = node.parent
    grid.print_lines(grid.text_map)