run.py (forked from AntixK/PyTorch-VAE)
import os
import yaml
import argparse
import numpy as np
from pathlib import Path
from models import *
from experiment import VAEXperiment
import torch.backends.cudnn as cudnn
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from dataset import (
    VAEDataset,
    RoveVAEDataset,
    BIOSCANVAEDataset,
    BIOSCANTreeVAEDataset,
    BIOSCANPMLDataset,
    SimDataModule,
)
from pytorch_lightning.strategies import DDPStrategy

parser = argparse.ArgumentParser(description="Generic runner for VAE models")
parser.add_argument(
    "--config",
    "-c",
    dest="filename",
    metavar="FILE",
    help="path to the config file",
    default="configs/vae.yaml",
)
args = parser.parse_args()

# Load the experiment configuration from the YAML file given on the command line.
with open(args.filename, "r") as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)
        raise  # fail here instead of with a NameError on `config` further down

tb_logger = TensorBoardLogger(
    save_dir=config["logging_params"]["save_dir"],
    name=config["model_params"]["name"],
)

# For reproducibility
seed_everything(config["exp_params"]["manual_seed"], True)

# Instantiate the requested VAE model, looked up by name in `vae_models` (from models).
model = vae_models[config["model_params"]["name"]](**config["model_params"])

# Either resume the Lightning module from a checkpoint or build it from scratch.
if config["model_params"].get("checkpoint", None) is not None:
    experiment = VAEXperiment.load_from_checkpoint(
        config["model_params"]["checkpoint"],
        vae_model=model,
        params=config["exp_params"],
    )
else:
    experiment = VAEXperiment(model, config["exp_params"])

# Data module; pin host memory only when training on GPU.
data = SimDataModule(
    **config["data_params"],
    pin_memory=config["trainer_params"]["accelerator"] == "gpu",
)
data.setup()

runner = Trainer(
    logger=tb_logger,
    callbacks=[
        LearningRateMonitor(),
        ModelCheckpoint(
            save_top_k=2,
            dirpath=os.path.join(tb_logger.log_dir, "checkpoints"),
            monitor="val_loss",
            save_last=True,
        ),
    ],
    strategy="ddp_find_unused_parameters_true",
    use_distributed_sampler=False,  # the datamodule provides its own samplers
    **config["trainer_params"],
)

# Directories the experiment writes inputs, samples, reconstructions and latent codes into.
Path(f"{tb_logger.log_dir}/Inputs").mkdir(exist_ok=True, parents=True)
Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)
Path(f"{tb_logger.log_dir}/Latent").mkdir(exist_ok=True, parents=True)

# Save a copy of the resolved config alongside the logs for this run.
config_fp = f"{tb_logger.log_dir}/config.yaml"
with open(config_fp, "w") as outfile:
    yaml.dump(config, outfile, default_flow_style=False)

print(f"======= Training {config['model_params']['name']} =======")
runner.fit(experiment, datamodule=data)
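
# ---------------------------------------------------------------------------
# Illustration only (not part of the original script): the config file passed
# via --config must expose at least the top-level sections read above:
# model_params, exp_params, data_params, trainer_params and logging_params.
# The field names inside each section below are assumptions -- the accepted
# keys depend on the chosen model class, VAEXperiment and SimDataModule in
# this fork. A minimal configs/vae.yaml might look like:
#
#   model_params:
#     name: "VanillaVAE"                  # key into the vae_models registry
#     # checkpoint: "path/to/last.ckpt"   # optional; resume VAEXperiment from it
#     latent_dim: 128
#   data_params:
#     train_batch_size: 64
#     val_batch_size: 64
#     num_workers: 4
#   exp_params:
#     LR: 0.005
#     manual_seed: 1265
#   trainer_params:
#     accelerator: "gpu"
#     devices: 1
#     max_epochs: 100
#   logging_params:
#     save_dir: "logs/"
#
# Run with, e.g.:  python run.py --config configs/vae.yaml
# ---------------------------------------------------------------------------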