-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
136 lines (112 loc) · 4.27 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# import copy
# class A:
# def __init__(self):
# self.x = 1
# def add(A):
# # B = copy.deepcopy(A)
# B = A[:]
# return B
# def do(A):
# A[0].x = 666
# def ad(a):
# a = a+1
# # return a
# a = [A(),A()]
# print(a[1].x)
# b = add(a)
# a[1].x = 2
# print(a[1].x)
# print(b[1].x)
# do(a)
# print(a[0].x)
# k = 1
# ad(k)
# print(k)
# import matplotlib.pyplot as plt
# x = [1,2,3,4,5,6,7]
# y = [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
# alpha = [i/7 for i in y]
# print(alpha)
# plt.scatter(x,y,c = 'red',alpha=alpha)
# plt.show()
"""
Visualize a genetic algorithm searching for the shortest path in the travelling salesman problem.
Visit my tutorial website for more: https://morvanzhou.github.io/tutorials/
"""
import matplotlib.pyplot as plt
import numpy as np
# Hyper-parameters of the run; they are passed to GA(...) and
# TravelSalesPerson(...) further down in this file.
N_CITIES = 20 # number of cities; also the DNA size (genes per individual)
CROSS_RATE = 0.1  # probability that a selected parent goes through crossover
MUTATE_RATE = 0.02  # per-gene probability of a swap mutation
POP_SIZE = 500  # number of individuals (rows) in the population
N_GENERATIONS = 500  # number of iterations of the evolution loop
class GA(object):
    """Permutation-encoded genetic algorithm for a shortest-route search.

    Every individual is one row of ``self.pop`` and holds a permutation of
    ``range(DNA_size)``: the order in which the cities are visited.

    Args:
        DNA_size: Number of genes per individual (one gene per city).
        cross_rate: Probability that a parent goes through crossover.
        mutation_rate: Per-gene probability of a swap mutation.
        pop_size: Number of individuals in the population.
    """

    def __init__(self, DNA_size, cross_rate, mutation_rate, pop_size):
        self.DNA_size = DNA_size
        self.cross_rate = cross_rate
        self.mutate_rate = mutation_rate
        self.pop_size = pop_size
        # One random visiting order per individual, stacked row-wise into an
        # array of shape (pop_size, DNA_size).
        self.pop = np.vstack([np.random.permutation(DNA_size) for _ in range(pop_size)])

    def translateDNA(self, DNA, city_position):
        """Return the coordinates of the cities in visiting order.

        Args:
            DNA: Integer array of city indices, one route per row (a single
                1-D route is accepted as well).
            city_position: Array whose row ``k`` is the ``(x, y)`` position
                of city ``k``.

        Returns:
            Tuple ``(line_x, line_y)`` of float64 arrays with the same shape
            as ``DNA``, holding the x and y coordinates along each route.
        """
        # A single fancy-indexing lookup resolves every gene at once; the
        # result has DNA's shape plus a trailing axis holding (x, y).
        coords = np.asarray(city_position)[np.asarray(DNA)]
        line_x = np.ascontiguousarray(coords[..., 0], dtype=np.float64)
        line_y = np.ascontiguousarray(coords[..., 1], dtype=np.float64)
        return line_x, line_y

    def get_fitness(self, line_x, line_y):
        """Score every route; shorter routes get exponentially higher fitness.

        The distance is the length of the open path through the cities in
        order (consecutive legs only, no leg back to the first city).

        Args:
            line_x: 2-D array of x coordinates, one route per row.
            line_y: 2-D array of y coordinates, one route per row.

        Returns:
            Tuple ``(fitness, total_distance)`` of 1-D float64 arrays with
            one entry per route, where
            ``fitness = exp(2 * DNA_size / total_distance)``.
        """
        total_distance = np.empty((line_x.shape[0],), dtype=np.float64)
        for i, (xs, ys) in enumerate(zip(line_x, line_y)):
            total_distance[i] = np.sum(np.sqrt(np.square(np.diff(xs)) + np.square(np.diff(ys))))
        fitness = np.exp(self.DNA_size * 2 / total_distance)
        return fitness, total_distance

    def select(self, fitness):
        """Draw ``pop_size`` individuals with replacement, fitness-weighted.

        Args:
            fitness: 1-D array with one fitness value per individual.

        Returns:
            A new array of the selected rows of ``self.pop``.
        """
        idx = np.random.choice(
            np.arange(self.pop_size),
            size=self.pop_size,
            replace=True,
            p=fitness / fitness.sum(),
        )
        return self.pop[idx]

    def crossover(self, parent, pop):
        """Recombine ``parent`` in place with a random individual of ``pop``.

        With probability ``cross_rate`` a random subset of the parent's
        cities is kept (in the parent's order) and the remaining cities are
        appended in the order in which the other individual visits them, so
        the result is still a valid permutation.

        Args:
            parent: 1-D route that is overwritten with the child.
            pop: 2-D array of routes to pick the second parent from.

        Returns:
            ``parent`` (the same array object, possibly modified).
        """
        if np.random.rand() < self.cross_rate:
            i_ = np.random.randint(0, self.pop_size, size=1)  # index of the other individual
            # Builtin ``bool`` instead of ``np.bool``: that alias was
            # deprecated in NumPy 1.20 and removed in 1.24, where it raises
            # AttributeError.
            cross_points = np.random.randint(0, 2, self.DNA_size).astype(bool)
            keep_city = parent[~cross_points]  # cities inherited from the parent
            other = pop[i_].ravel()
            # Cities not kept from the parent, in the other individual's order.
            swap_city = other[np.isin(other, keep_city, invert=True)]
            parent[:] = np.concatenate((keep_city, swap_city))
        return parent

    def mutate(self, child):
        """Swap-mutate ``child`` in place.

        Each position is, with probability ``mutate_rate``, exchanged with a
        uniformly chosen position (possibly itself), which keeps the route a
        valid permutation.

        Args:
            child: 1-D route that is modified in place.

        Returns:
            ``child`` (the same array object).
        """
        for point in range(self.DNA_size):
            if np.random.rand() < self.mutate_rate:
                swap_point = np.random.randint(0, self.DNA_size)
                child[point], child[swap_point] = child[swap_point], child[point]
        return child

    def evolve(self, fitness):
        """Advance the population by one generation.

        Runs selection, then crossover and mutation on every selected
        individual, and stores the result in ``self.pop``.

        Args:
            fitness: 1-D array with one fitness value per current individual.
        """
        pop = self.select(fitness)
        # Crossover partners are taken from an untouched snapshot so that
        # children written earlier in this loop are never used as parents.
        pop_copy = pop.copy()
        for parent in pop:  # each row is a view, so edits land in ``pop``
            child = self.crossover(parent, pop_copy)
            child = self.mutate(child)
            parent[:] = child
        self.pop = pop
class TravelSalesPerson(object):
    """Random city map together with a live matplotlib view of one route."""

    def __init__(self, n_cities):
        """Create ``n_cities`` random cities and enable interactive plotting.

        Args:
            n_cities: Number of cities; each gets an ``(x, y)`` position
                drawn by ``np.random.rand``, i.e. inside the unit square.
        """
        shape = (n_cities, 2)
        self.city_position = np.random.rand(*shape)
        plt.ion()

    def plotting(self, lx, ly, total_d):
        """Redraw the current axes with the cities and the given route.

        Args:
            lx: x coordinates of the route, in visiting order.
            ly: y coordinates of the route, in visiting order.
            total_d: Length of the route. Currently unused because the text
                overlay below is commented out.
        """
        cities = self.city_position
        axis_range = (-0.1, 1.1)

        plt.cla()
        # Cities as black dots, the route as a green line.
        plt.scatter(cities[:, 0].T, cities[:, 1].T, s=100, c='k')
        plt.plot(lx.T, ly.T, 'g-')
        # plt.text(-0.05, -0.05, "Total distance=%.2f" % total_d, fontdict={'size': 20, 'color': 'red'})
        plt.xlim(axis_range)
        plt.ylim(axis_range)
        # Short pause so the figure gets a chance to refresh.
        plt.pause(0.01)
# Build the optimizer and the city map, then evolve while plotting the best
# route of every generation.
ga = GA(N_CITIES, CROSS_RATE, MUTATE_RATE, POP_SIZE)
env = TravelSalesPerson(n_cities=N_CITIES)

for generation in range(N_GENERATIONS):
    # Evaluate the current population.
    lx, ly = ga.translateDNA(ga.pop, env.city_position)
    fitness, total_distance = ga.get_fitness(lx, ly)
    best_idx = np.argmax(fitness)

    # Breed the next generation; lx, ly and fitness still describe the
    # population evaluated above.
    ga.evolve(fitness)

    best_fit_text = '| best fit: %.2f' % fitness[best_idx]
    print('Gen:', generation, best_fit_text)
    env.plotting(lx[best_idx], ly[best_idx], total_distance[best_idx])

# Leave interactive mode and keep the final figure open.
plt.ioff()
plt.show()