-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_selection.py
136 lines (110 loc) · 4.38 KB
/
model_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gc
from typing import Tuple, Callable
import optuna
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from hyperparams import HyperparamsSpace, Hyperparams
from metrics import Metric
from models import GNN
from training import Trainer
from utils import mkdir_if_not_exists
def define_objective(
    dataset: Dataset,
    split: Tuple[list, list],
    hyperparams_space: HyperparamsSpace,
    evaluation_metric: Metric,
    task: str,
    pruning: bool = True,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> Callable:
    """Define the objective function to be optimized by optuna.

    Args:
        dataset: Dataset to be used for training.
        split: Tuple of (train_indices, val_indices) into ``dataset``.
        hyperparams_space: Hyperparameters space to sample trial configurations from.
        evaluation_metric: Metric to be used for model evaluation.
        task: Task to be performed ("classification" or "regression").
        pruning: Whether to enable optuna trial pruning.
        device: Device to use. (default: "cuda" if available, else "cpu")

    Returns:
        A callable mapping an ``optuna.Trial`` to the trial's best validation score.
    """

    def objective(trial: optuna.Trial) -> float:
        """Train one model with trial-sampled hyperparameters; return its best score."""
        # Reclaim memory left over from previous trials before building a new model.
        gc.collect()
        torch.cuda.empty_cache()

        hyperparams = hyperparams_space.pick(trial)
        # FIX: the newline was misplaced before the colon ("Hyperparameters\n: ...").
        print(f"Hyperparameters:\n{hyperparams}")

        train_idx, val_idx = split
        train_loader = DataLoader(dataset[list(train_idx)], hyperparams.batch_size, shuffle=True)
        val_loader = DataLoader(dataset[list(val_idx)], hyperparams.batch_size, shuffle=False)

        # Regression predicts a single scalar; classification predicts one logit per class.
        out_channels = 1 if task == "regression" else dataset.num_classes
        model = GNN(in_channels=dataset.num_features, out_channels=out_channels, hyperparams=hyperparams)
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=hyperparams.lr,
            weight_decay=hyperparams.weight_decay,
        )
        criterion = torch.nn.MSELoss() if task == "regression" else torch.nn.CrossEntropyLoss()

        writer = SummaryWriter(f"runs/{trial.study.study_name}/trial{trial.number:04d}")
        trainer = Trainer(model, optimizer, criterion, evaluation_metric, writer=writer, device=device)
        trainer.set_early_stopping(patience=hyperparams.patience, min_epochs=hyperparams.min_epochs)
        if pruning:
            trainer.set_optuna_trial_pruning(trial)
        try:
            trainer.train(train_loader, val_loader, epochs=hyperparams.max_epochs)
        finally:
            # FIX: close the writer even when the trial is pruned (pruning raises out
            # of train()); otherwise event-file handles leak across trials.
            writer.close()
        return trainer.get_best_metric_score()

    return objective
def select_hyperparams(
    dataset: Dataset,
    split: Tuple[list, list],
    study_name: str,
    hyperparams_space: HyperparamsSpace,
    metric: Metric,
    task: str,
    pruning: bool = True,
    n_trials: int = 10,
    n_jobs: int = 1,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> Hyperparams:
    """Select the best hyperparameters for the given dataset and split.

    Args:
        dataset: Dataset to be used for training.
        split: Tuple of (train_indices, val_indices) into ``dataset``.
        study_name: Name of the study; also names the ``runs/<study_name>`` directory.
        hyperparams_space: Hyperparameters space.
        metric: Metric to be used for model evaluation.
        task: Task to be performed ("classification" or "regression").
        pruning: Whether to use pruning or not.
        n_trials: Number of trials to be performed.
        n_jobs: Number of parallel jobs. (default: 1)
        device: Device to use. (default: "cuda" if available, else "cpu")

    Returns:
        Hyperparams combining the best sampled values with the non-sampled
        entries of ``hyperparams_space``.
    """
    study = optuna.create_study(
        study_name=study_name,  # FIX: was a redundant f-string wrapper
        direction=metric.direction(),
        # Fixed seed keeps the TPE sampling sequence reproducible across runs.
        sampler=optuna.samplers.TPESampler(seed=0),
        # NOTE(review): load_if_exists has no effect without a `storage` backend —
        # confirm whether persistent/resumable studies were intended here.
        load_if_exists=True,
    )

    mkdir_if_not_exists(f"runs/{study_name}")
    hyperparams_space.save(f"runs/{study_name}/hyperparams_space.yml")

    objective = define_objective(
        dataset,
        split,
        hyperparams_space=hyperparams_space,
        evaluation_metric=metric,
        task=task,
        pruning=pruning,
        device=device,
    )
    study.optimize(
        objective,
        n_trials=n_trials,
        n_jobs=n_jobs,
    )

    # Sampled params take precedence; everything else falls through from the space.
    best_hyperparams = Hyperparams(
        **study.best_params,
        **{k: v for k, v in hyperparams_space.__dict__.items() if k not in study.best_params},
    )
    best_hyperparams.save(f"runs/{study_name}/best_hyperparams.yml")
    return best_hyperparams