This repository has been archived by the owner on May 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 20
/
main.py
70 lines (52 loc) · 1.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Import packages
import os
import torch
import models
from data.dataset import UIRTDataset
from evaluation.evaluator import Evaluator
from experiment.early_stop import EarlyStop
from loggers import FileLogger, CSVLogger
from utils.general import make_log_dir, set_random_seed
from config import load_config
# ---------------------------------------------------------------------------
# Configurations
# ---------------------------------------------------------------------------
# Experiment-wide settings loaded via the project's `config.load_config`.
config = load_config()
exp_config = config.experiment
gpu_id, seed = exp_config.gpu, exp_config.seed
dataset_config = config.dataset

# Number GPUs in PCI bus order and expose only the configured one to this
# process. CUDA reads these variables when it is first initialised, so they
# must be set before the CUDA availability query below (presumably none of
# the project modules imported above touch CUDA at import time — confirm).
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": str(gpu_id),
})

# Fall back to the CPU when no (visible) CUDA device is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def main():
    """Build the configured dataset and model, train it, and print its scores.

    Uses the module-level ``config``, ``exp_config``, ``dataset_config``,
    ``seed`` and ``device`` defined in the configuration section above.

    Raises:
        ValueError: if ``experiment.model_name`` is not an attribute of the
            ``models`` package.
    """
    set_random_seed(seed)

    # Dataset
    dataset = UIRTDataset(**dataset_config)

    # NOTE: early stopping is currently not wired in; `EarlyStop` is imported
    # at the top of the file but unused here.

    # Model class, looked up by name in the `models` package.
    model_name = exp_config.model_name
    try:
        model_base = getattr(models, model_name)
    except AttributeError as err:
        raise ValueError(
            f"Unknown model '{model_name}': the `models` package has no "
            f"attribute with that name"
        ) from err
    hparams = config.hparams

    # Loggers write under a directory created by make_log_dir from
    # <save_dir>/<model_name>.
    log_dir = make_log_dir(os.path.join(exp_config.save_dir, model_name))
    logger = FileLogger(log_dir)
    csv_logger = CSVLogger(log_dir)

    # Save log & dataset config.
    logger.info(config)
    logger.info(dataset)

    # Evaluator over the validation split.
    evaluator = Evaluator(
        dataset.valid_input,
        dataset.valid_target,
        protocol=dataset.protocol,
        ks=config.evaluator.ks,
    )

    model = model_base(dataset, hparams, device)
    try:
        ret = model.fit(dataset, exp_config, evaluator=evaluator,
                        loggers=[logger, csv_logger])
    finally:
        # Persist whatever the CSV logger collected even if training raised,
        # so a failed run still leaves its log on disk.
        csv_logger.save()
    print(ret['scores'])


if __name__ == '__main__':
    main()