-
Notifications
You must be signed in to change notification settings - Fork 0
/
SA.py
66 lines (54 loc) · 2.14 KB
/
SA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Simulated Annealing
# from math import exp
from numpy import exp
from random import sample, random
from .utils.base import Base
from .utils.path import Path
class SA(Base):
    """
    Simulated annealing search for a short closed tour over 2D points.

    Simulated annealing is a probabilistic metaheuristic for approximating
    the global optimum of a function over a large search space. Starting from
    the tour ``0, 1, ..., n-1, 0``, every iteration swaps two interior
    positions of the current tour and accepts the candidate with probability
    ``min(1, exp(-(candidate_length - current_length) / temperature))``.
    The temperature is multiplied by ``g`` after every iteration, and the
    shortest tour seen so far is what ``run`` returns.

    -----
    `iter: int` THE NUMBER OF ITERATIONS
        The maximum number of iterations of the algorithm.

    -----
    `t: float` INITIAL TEMPERATURE
        The temperature each search starts from; it decreases as the search
        progresses. The stored value is never modified by ``run``.

    -----
    `g: float` CHANGE COEFFICIENT
        The factor the temperature is multiplied by after each iteration.
    """

    def __init__(self, iter: int, t: float, g: float) -> None:
        """Store the hyperparameters for the algorithm."""
        self.iter = iter
        self.t = t
        self.g = g

    def __is_acceptable(
        self, prb_leng: float, tmp_leng: float, t: "float | None" = None
    ) -> bool:
        """
        Decide whether the search moves from the current tour to the candidate.

        Args:
            prb_leng: Length of the candidate (probe) tour.
            tmp_leng: Length of the current tour.
            t: Temperature to evaluate the criterion at. Defaults to
                ``self.t`` when omitted.

        Returns:
            ``True`` if the candidate should replace the current tour.
        """
        if t is None:
            t = self.t

        delta = prb_leng - tmp_leng
        if delta <= 0:
            # A non-worsening move has probability 1; skipping exp() avoids
            # overflowing it with a large positive argument.
            prob = 1.0
        elif t <= 0:
            # The temperature has cooled to zero (or was given non-positive):
            # worsening moves are never taken, and we avoid dividing by zero.
            prob = 0.0
        else:
            prob = exp(-delta / t)

        # Exactly one random() draw per call, whichever branch was taken, so
        # the random stream of a seeded run does not depend on the branch.
        return bool(prob > random())

    def run(self, points: list[tuple[int]], name: str = None) -> Path:
        """
        Run the algorithm for the given 2D points.

        Args:
            points: The points to visit; the tour starts and ends at index 0.
            name: Optional label forwarded unchanged to the resulting ``Path``.

        Returns:
            A ``Path`` built from the shortest tour found (``indx``), its
            length (``leng``) and ``name``.
        """
        count = len(points)
        dm = SA._distance_matrix(points)

        # Initial tour: visit the points in input order and return to the start.
        cur_indx = list(range(count)) + [0]
        cur_leng = SA._calculate_dist(dm, cur_indx)
        best_indx, best_leng = cur_indx, cur_leng

        # Positions 0 and -1 are pinned to the start point, so a swap needs at
        # least two interior positions; with fewer than three points the
        # initial tour is the only one and sample() would raise ValueError.
        if count < 3:
            return Path(indx=best_indx, leng=best_leng, name=name)

        # Cool a local copy so self.t keeps meaning "initial temperature" and
        # repeated run() calls on the same instance all start equally hot.
        temperature = self.t
        interior = range(1, count)

        for _ in range(self.iter):
            i, j = sample(interior, 2)

            # Candidate tour: the current tour with two interior stops swapped.
            cand_indx = cur_indx.copy()
            cand_indx[i], cand_indx[j] = cand_indx[j], cand_indx[i]
            cand_leng = SA._calculate_dist(dm, cand_indx)

            if self.__is_acceptable(cand_leng, cur_leng, temperature):
                cur_indx, cur_leng = cand_indx, cand_leng
                # The best tour can only improve right after an accepted move.
                if cur_leng < best_leng:
                    best_indx, best_leng = cur_indx, cur_leng

            temperature *= self.g

        return Path(indx=best_indx, leng=best_leng, name=name)
# Placeholder entry-point guard: the body is a no-op, so no script behavior
# is defined for this module.
if __name__ == "__main__":
    pass