/
main.py
43 lines (32 loc) · 959 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
import seml
import torch
from sacred import Experiment as SacredExperiment
import time
from experiment import Experiment
ex = SacredExperiment()
seml.setup_logger(ex)
@ex.post_run_hook
def collect_stats(_run):
seml.collect_exp_stats(_run)
@ex.config
def config():
overwrite = None
db_collection = None
if db_collection is not None:
ex.observers.append(seml.create_mongodb_observer(
db_collection, overwrite=overwrite))
@ex.automain
def run(_config, conf: dict, hparams: dict):
start = time.time()
experiment = Experiment()
results, dict_to_save = experiment.run(hparams)
end = time.time()
print(f"time={end-start}s")
results['time'] = end-start
save_dir = conf["save_dir"]
run_id = _config['overwrite']
db_collection = _config['db_collection']
if conf["save"]:
torch.save(dict_to_save, f'{save_dir}/{db_collection}_{run_id}')
return results