main.py
import socketserver
import sys
from argparse import ArgumentParser
from typing import List

from omegaconf import OmegaConf
from torch.cuda import device_count
from torch.multiprocessing import spawn

from config import Config, parse_configs
from learner import train, train_distributed


def _get_free_port():
    # Bind to port 0 so the OS picks an unused port, then report it.
    with socketserver.TCPServer(('localhost', 0), None) as s:
        return s.server_address[1]


def main(argv: List[str]):
    parser = ArgumentParser(description="Train a DiffWave model.")
    parser.add_argument("--config", type=str, required=True,
                        help="Configuration file for the model.")
    # The last element of argv is reserved for "::"-separated CLI overrides.
    args = parser.parse_args(argv[1:-1])

    # First create the base config, then merge in any CLI overrides.
    cfg = OmegaConf.load(args.config)
    cli_cfg = OmegaConf.from_cli(argv[-1].split("::")) if argv[-1] != "" else None
    cfg: Config = Config(**parse_configs(cfg, cli_cfg))

    # Set up training.
    world_size = device_count()
    if world_size != cfg.distributed.world_size:
        raise ValueError(
            "Requested world size is not the same as the number of visible GPUs.")

    if cfg.distributed.distributed:
        if world_size < 2:
            raise ValueError(
                f"Distributed training cannot be run on a machine with {world_size} device(s).")
        if cfg.data.batch_size % world_size != 0:
            raise ValueError(
                f"Batch size {cfg.data.batch_size} is not evenly divisible by # GPUs = {world_size}.")
        # Split the global batch size evenly across the spawned worker processes.
        cfg.data.batch_size = cfg.data.batch_size // world_size
        port = _get_free_port()
        spawn(train_distributed, args=(world_size, port, cfg), nprocs=world_size, join=True)
    else:
        train(cfg)


if __name__ == "__main__":
    argv = sys.argv
    # Pad argv with an empty override string when no overrides were supplied.
    if len(sys.argv) == 3:
        argv = argv + [""]
    main(argv)
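
Based on the argv handling above, the script appears to be invoked as python main.py --config <path> optionally followed by a single final argument holding "::"-separated dotlist overrides, e.g. python main.py --config config.yaml "data.batch_size=16::distributed.distributed=true". The sketch below is only an illustration of what the override-merging step is assumed to do with OmegaConf; parse_configs is defined in config.py and its exact behavior is not shown here, and the config path and override keys are hypothetical.

from omegaconf import OmegaConf

# Assumed behavior: merge a base YAML config with "::"-separated dotlist
# overrides, roughly what main() hands to parse_configs.
base = OmegaConf.load("config.yaml")  # hypothetical config path
overrides = "data.batch_size=16::distributed.distributed=true"  # illustrative keys
cli = OmegaConf.from_cli(overrides.split("::"))
merged = OmegaConf.merge(base, cli)  # later values win on conflict
print(OmegaConf.to_yaml(merged))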