-
Notifications
You must be signed in to change notification settings - Fork 0
/
Processor.py
70 lines (57 loc) · 2.21 KB
/
Processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os, json
from datetime import datetime
from utils.Trainers.SimpleTrainer import SimpleTrainer
from utils.Trainers.OrdinalRegressorTrainer import OrdinalRegressorTrainer
class Processor:
    """
    Summary: Entry point that sets up a per-run output directory and
    dispatches training to the trainer selected by the config.
    """

    def __init__(self, args: dict) -> None:
        """
        Summary: Initialization.
        Parameters:
            args: dict - config of training. See ./main.py.
                  The dict is stored by reference and mutated in place:
                  'datetime_start' is added and 'path_to_save' is rewritten
                  to point at a newly created per-run directory.
        Raises:
            FileExistsError - if the per-run directory already exists.
        """
        self.args = args
        # Must happen before the trainers are built: they receive self.args
        # and therefore see the per-run 'path_to_save' created here.
        self._set_datetime_to_config()
        # NOTE(review): both trainers are constructed eagerly although run()
        # uses only one of them; consider lazy construction if either is costly.
        self.pytorch_simple_trainer = SimpleTrainer(self.args)
        self.pytorch_ordinal_reg_trainer = OrdinalRegressorTrainer(self.args)

    def run(self) -> None:
        """
        Summary: Train with the trainer selected by args['ordinal_reg'],
        then save the config into the run directory.
        Raises:
            ValueError - if args['mode'] is not 'train'.
        """
        mode = self.args['mode']
        # Validate the mode before looking at 'ordinal_reg' so an
        # ordinal-regression config cannot train under an unsupported mode.
        if mode != 'train':
            raise ValueError('Unknown mode: {}'.format(mode))
        print("BEGIN TRAINING")
        if self.args['ordinal_reg']:
            self._train_ord()
        else:
            self._train()
        self._save_config()

    def _train(self) -> None:
        """
        Summary: Train the model with the simple (non-ordinal) trainer.
        """
        self.pytorch_simple_trainer.run()

    def _train_ord(self) -> None:
        """
        Summary: Train the model with the ordinal-regression trainer.
        """
        self.pytorch_ordinal_reg_trainer.run()

    def _save_config(self) -> None:
        """
        Summary: Write self.args, as it stands after training, as indented
        JSON to <path_to_save>/config.json.
        Raises:
            TypeError - if args holds a value json cannot serialize.
        """
        config_path = os.path.join(self.args['path_to_save'], 'config.json')
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.args, f, indent=4)

    def _set_datetime_to_config(self) -> None:
        """
        Summary: Record the start time in args['datetime_start'] and create a
        fresh directory '<path_to_save>/<DDMMYYYY_HHMMSS>-<model_name>-<mode>',
        storing its path back into args['path_to_save'].
        Raises:
            FileExistsError - if that directory already exists (e.g. two runs
            with the same model_name and mode started within the same second).
        """
        self.args['datetime_start'] = datetime.now().strftime("%d%m%Y_%H%M%S")
        folder_name = '{}-{}-{}'.format(self.args['datetime_start'],
                                        self.args['model_name'],
                                        self.args['mode'])
        self.args['path_to_save'] = os.path.join(self.args['path_to_save'], folder_name)
        # exist_ok=False: refuse to reuse an existing directory so an earlier
        # run's results are never overwritten.
        os.makedirs(self.args['path_to_save'], exist_ok=False)