forked from chemprop/chemprop
-
Notifications
You must be signed in to change notification settings - Fork 3
/
hyperparameter_optimization.py
104 lines (82 loc) · 3.86 KB
/
hyperparameter_optimization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from argparse import ArgumentParser, Namespace
from copy import deepcopy
import json
import os
from typing import Dict, Union
from hyperopt import fmin, hp, tpe
import numpy as np
from chemprop.models import build_model
from chemprop.nn_utils import param_count
from chemprop.parsing import add_train_args, modify_train_args
from chemprop.train import cross_validate
from model_comparison import create_logger, create_train_logger
# Search space for hyperopt's TPE sampler. hp.quniform samples uniformly in
# [low, high] and rounds to a multiple of q, but always returns a float —
# even for integer-valued parameters (hence the int cast via INT_KEYS below).
SPACE = {
    'hidden_size': hp.quniform('hidden_size', low=300, high=2400, q=100),
    'depth': hp.quniform('depth', low=2, high=6, q=1),
    'dropout': hp.quniform('dropout', low=0.0, high=0.4, q=0.05),
    'ffn_num_layers': hp.quniform('ffn_num_layers', low=1, high=3, q=1)
}
# Keys whose sampled values must be cast from float back to int before use.
INT_KEYS = ['hidden_size', 'depth', 'ffn_num_layers']
# Single shared logger passed to every cross_validate() call across trials.
TRAIN_LOGGER = create_train_logger()
def grid_search(args: Namespace):
    """Run Bayesian hyperparameter optimization with hyperopt's TPE algorithm.

    Despite the name, this performs TPE-guided search over ``SPACE`` (not an
    exhaustive grid). Each trial cross-validates one hyperparameter setting,
    all trials are logged, and the best setting is written to
    ``args.config_save_path`` as JSON.

    :param args: Training arguments (as produced by chemprop's parser), plus
                 ``num_iters``, ``config_save_path``, and optional ``log_path``.
    :raises ValueError: if a trial scores NaN on a non-classification dataset.
    """
    # Logger for this optimization run (separate from the training logger)
    logger = create_logger(name='hyperparameter_optimization', save_path=args.log_path)

    # Accumulates one record per completed trial (closed over by `objective`)
    results = []

    def objective(hyperparams: Dict[str, Union[int, float]]) -> float:
        """Evaluate one hyperparameter setting; return a score to MINIMIZE."""
        # hyperopt's quniform yields floats; cast integer-valued params back to int
        for key in INT_KEYS:
            hyperparams[key] = int(hyperparams[key])

        # Give each trial its own args copy so trials don't contaminate each other
        hyper_args = deepcopy(args)
        if args.save_dir is not None:
            # Bug fix: the original conditional had two identical branches.
            # Float-valued hyperparameters (e.g. dropout) are formatted to two
            # decimals so folder names stay short and stable.
            folder_name = '_'.join(
                f'{key}_{value}' if key in INT_KEYS else f'{key}_{value:.2f}'
                for key, value in hyperparams.items()
            )
            hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name)
        for key, value in hyperparams.items():
            setattr(hyper_args, key, value)

        # Record hyperparameters
        logger.info(hyperparams)

        # Cross validate
        mean_score, std_score = cross_validate(hyper_args, TRAIN_LOGGER)

        # Record results (model is built only to count parameters)
        temp_model = build_model(hyper_args)
        num_params = param_count(temp_model)
        logger.info(f'num params: {num_params:,}')
        logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')

        results.append({
            'mean_score': mean_score,
            'std_score': std_score,
            'hyperparams': hyperparams,
            'num_params': num_params
        })

        # Deal with nan: treat as worst score for classification, fail otherwise
        if np.isnan(mean_score):
            if hyper_args.dataset_type == 'classification':
                mean_score = 0
            else:
                raise ValueError('Can\'t handle nan score for non-classification dataset.')

        # hyperopt always minimizes, so negate when a higher score is better
        return (1 if hyper_args.minimize_score else -1) * mean_score

    fmin(objective, SPACE, algo=tpe.suggest, max_evals=args.num_iters)

    # Report best result, skipping any NaN trials kept in `results`
    results = [result for result in results if not np.isnan(result['mean_score'])]
    best_result = min(results, key=lambda result: (1 if args.minimize_score else -1) * result['mean_score'])
    logger.info('best')
    logger.info(best_result['hyperparams'])
    logger.info(f'num params: {best_result["num_params"]:,}')
    logger.info(f'{best_result["mean_score"]} +/- {best_result["std_score"]} {args.metric}')

    # Save best hyperparameter settings as JSON config file
    with open(args.config_save_path, 'w') as f:
        json.dump(best_result['hyperparams'], f, indent=4, sort_keys=True)
if __name__ == '__main__':
    # Start from chemprop's standard training arguments, then add the
    # optimization-specific flags on top.
    arg_parser = ArgumentParser()
    add_train_args(arg_parser)

    arg_parser.add_argument('--num_iters', type=int, default=20,
                            help='Number of hyperparameter choices to try')
    arg_parser.add_argument('--config_save_path', type=str, required=True,
                            help='Path to .json file where best hyperparameter settings will be written')
    arg_parser.add_argument('--log_path', type=str,
                            help='(Optional) Path to .log file where all results of the hyperparameter optimization will be written')

    parsed_args = arg_parser.parse_args()
    # Apply chemprop's post-parse normalization before running the search
    modify_train_args(parsed_args)

    grid_search(parsed_args)