/
random_search.py
99 lines (78 loc) · 3.24 KB
/
random_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import datetime
import random
from typing import Union, Generator
from photonai.optimization.base_optimizer import PhotonSlaveOptimizer
from photonai.photonlogger.logger import logger
class RandomSearchOptimizer(PhotonSlaveOptimizer):
    """Random search optimizer.

    Searches for the best configuration by randomly
    testing hyperparameter combinations without any grid.
    """

    def __init__(self, limit_in_minutes: Union[float, None] = 60, n_configurations: Union[int, None] = None):
        """
        Initialize the object.
        One of limit_in_minutes or n_configurations must differ from None.

        Parameters:
            limit_in_minutes:
                Total time limit in minutes. None, 0 or a negative
                value disables the time limit.

            n_configurations:
                Maximum number of configurations to be calculated.
                None, 0 or a negative value disables the count limit.

        Raises:
            ValueError: if neither stopping criterion is active.
        """
        self.pipeline_elements = None
        self.parameter_iterable = None
        self.ask = self.next_config_generator()

        # Falsy or non-positive values disable the respective criterion.
        if not limit_in_minutes or limit_in_minutes <= 0:
            self.limit_in_minutes = None
        else:
            self.limit_in_minutes = limit_in_minutes
        self.start_time = None
        self.end_time = None

        if not n_configurations or n_configurations <= 0:
            self.n_configurations = None
        else:
            self.n_configurations = n_configurations
        # Number of configurations handed out since the last prepare();
        # compared against n_configurations. (Name kept for compatibility.)
        self.k_configutration = 0

        if self.n_configurations is None and self.limit_in_minutes is None:
            msg = "No stopping criteria for RandomSearchOptimizer."
            logger.error(msg)
            raise ValueError(msg)

    def prepare(self, pipeline_elements: list, maximize_metric: bool) -> None:
        """
        Initialize the grid-free random hyperparameter search.

        Resets the time budget and the configuration counter, so that a
        repeated call to prepare() starts with fresh stopping criteria
        instead of inheriting the exhausted state of a previous run.

        Parameters:
            pipeline_elements:
                List of all PipelineElements to create the hyperparameter space.

            maximize_metric:
                Boolean to distinguish between score and error.
                Not used: random sampling does not depend on the metric direction.
        """
        self.start_time = None
        self.end_time = None
        # Without this reset, a second prepare() would yield only one config
        # once n_configurations had been reached in an earlier run.
        self.k_configutration = 0
        self.pipeline_elements = pipeline_elements
        self.ask = self.next_config_generator()

    def next_config_generator(self) -> Generator:
        """
        Generator for new configs - ask method.

        Yields random configurations until a stopping criterion is met:
        the time budget (measured from the first request for a config)
        is exhausted, or n_configurations configs have been yielded.
        At least one configuration is always yielded.

        Returns:
            Yield the next config.
        """
        if self.limit_in_minutes:
            # Start the clock when the first config is requested, so the
            # evaluation time of that first config counts toward the budget.
            if self.start_time is None:
                self.start_time = datetime.datetime.now()
            self.end_time = self.start_time + datetime.timedelta(minutes=self.limit_in_minutes)

        while True:
            # Any value sent into the generator is ignored.
            _ = (yield self._generate_config())
            self.k_configutration += 1
            if self.limit_in_minutes and datetime.datetime.now() >= self.end_time:
                return
            if self.n_configurations and self.k_configutration >= self.n_configurations:
                return

    def _generate_config(self) -> dict:
        """
        Draw one random configuration from the hyperparameter space.

        A hyperparameter given as a list is sampled with random.choice;
        any other value must provide a get_random_value() method
        (presumably a photonai hyperparameter object — not visible here).

        Returns:
            Dict mapping each hyperparameter key to its sampled value.
        """
        config = {}
        for p_element in self.pipeline_elements:
            for h_key, h_value in p_element.hyperparameters.items():
                if isinstance(h_value, list):
                    config[h_key] = random.choice(h_value)
                else:
                    config[h_key] = h_value.get_random_value()
        return config