-
Notifications
You must be signed in to change notification settings - Fork 5.4k
/
trainable.py
82 lines (62 loc) · 2.09 KB
/
trainable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# flake8: noqa
# fmt: off
# __example_objective_start__
def objective(x, a, b):
    """Return the example objective value ``a * sqrt(x) + b``.

    Args:
        x: Point at which the objective is evaluated.
        a: Multiplier applied to the square-root term.
        b: Constant offset added to the result.

    Returns:
        The value of ``a * (x ** 0.5) + b``.
    """
    return a * (x ** 0.5) + b
# __example_objective_end__
# fmt: on
# __function_api_report_intermediate_metrics_start__
from ray import train, tune
def trainable(config: dict):
    """Evaluate the objective for 20 values of ``x``, reporting each score.

    Args:
        config: Hyperparameter dict; must contain the keys "a" and "b".
    """
    intermediate_score = 0
    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])
        # Report inside the loop so every intermediate score is sent.
        train.report({"score": intermediate_score})  # This sends the score to Tune.
# Run the function trainable with one fixed set of hyperparameters.
tuner = tune.Tuner(
    trainable,
    param_space={"a": 2, "b": 4},
)
results = tuner.fit()
# __function_api_report_intermediate_metrics_end__
# __function_api_report_final_metrics_start__
from ray import train, tune
def trainable(config: dict):
    """Evaluate the objective for 20 values of ``x`` and report the last score.

    Args:
        config: Hyperparameter dict; must contain the keys "a" and "b".
    """
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    # Report once, after the loop, so only the final score is sent.
    train.report({"score": final_score})  # This sends the score to Tune.
# Run the final-score trainable with one fixed set of hyperparameters.
tuner = tune.Tuner(
    trainable,
    param_space={"a": 2, "b": 4},
)
results = tuner.fit()
# __function_api_report_final_metrics_end__
# fmt: off
# __function_api_return_final_metrics_start__
def trainable(config: dict):
    """Evaluate the objective for 20 values of ``x`` and return the last score.

    Args:
        config: Hyperparameter dict; must contain the keys "a" and "b".

    Returns:
        A dict ``{"score": final_score}`` holding the score of the last
        evaluated ``x``.
    """
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    # Return after the loop so only the final score is handed back.
    return {"score": final_score}  # This sends the score to Tune.
# __function_api_return_final_metrics_end__
# fmt: on
# __class_api_example_start__
from ray import train, tune
class Trainable(tune.Trainable):
    """Class-API trainable that scores one more value of ``x`` on each step."""

    def setup(self, config: dict):
        """Store the hyperparameters and start ``x`` at zero.

        Args:
            config: A dict of hyperparameters; must contain "a" and "b".
        """
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]
        self.b = config["b"]

    def step(self):  # This is called iteratively.
        """Score the current ``x``, advance ``x`` by one, and return the score.

        Returns:
            A dict ``{"score": score}`` for the ``x`` evaluated in this step.
        """
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}
# Run the class-based trainable with one fixed set of hyperparameters.
tuner = tune.Tuner(
    Trainable,
    run_config=train.RunConfig(
        # Train for 20 steps
        stop={"training_iteration": 20},
        checkpoint_config=train.CheckpointConfig(
            # We haven't implemented checkpointing yet. See below!
            checkpoint_at_end=False,
        ),
    ),
    param_space={"a": 2, "b": 4},
)
results = tuner.fit()
# __class_api_example_end__