train.py
#!/usr/bin/env python3
"""Main training and evaluation code."""
import os

import hydra
import lightning as L
import wandb
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.utilities.model_summary.model_summary import ModelSummary
from omegaconf import OmegaConf

from data import MNISTDataModule
from models import LitCNN


def flatten(d):
"""Flatten nested dictionary - assumes no duplicate keys."""
items = []
# TODO: warn on duplicate keys
for k, v in d.items():
if isinstance(v, (int, float, str, bool)):
items.append((k, v))
else:
items.extend(flatten(v).items())
return dict(items)


class HPMetricCallback(Callback):
    """Callback to log hyperparameters and final model performance (i.e., metrics)."""

    def on_train_end(self, trainer, pl_module):
        tb_logger = pl_module.loggers[0]  # TensorBoardLogger is passed first to the Trainer
        hp_params = flatten(pl_module.cfg)
        metrics = {}
        for k, v in trainer.logged_metrics.items():
            metrics['metric/' + k] = v.item()
        tb_logger.log_hyperparams(hp_params, metrics)
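
# With the setup in main() below, this writes the flattened Hydra config together
# with entries like 'metric/val_loss' (assuming LitCNN logs val_loss, which the
# callbacks below monitor) to TensorBoard's HPARAMS tab.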


@hydra.main(version_base=None, config_path='conf', config_name='config')
def main(cfg):
    """Training and model evaluation."""
    print(OmegaConf.to_yaml(cfg))
    cnn = LitCNN(cfg.model)
    mnist_data = MNISTDataModule(cfg.data)
    print(ModelSummary(cnn, max_depth=-1))

    os.makedirs('logs', exist_ok=True)  # make sure the log directory exists before the loggers write to it
    wandb.init(project='mnist', config=flatten(cfg), dir='logs')

    tb_logger = TensorBoardLogger(save_dir="logs/tb", name='', log_graph=True,
                                  default_hp_metric=False)  # don't log hparams without a metric
    wandb_logger = WandbLogger()
    wandb_logger.experiment.config.update(flatten(cfg))

    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.01,
                                        patience=3, verbose=True, mode="min")
    checkpoint_callback = ModelCheckpoint(save_top_k=2, monitor="val_loss",
                                          dirpath='logs/checkpoints')
    hp_metric_callback = HPMetricCallback()

    trainer = L.Trainer(**cfg.trainer, logger=[tb_logger, wandb_logger],
                        callbacks=[early_stop_callback, checkpoint_callback, hp_metric_callback])
    trainer.fit(cnn, mnist_data)

    wandb.finish()
    tb_logger.finalize('success')
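
# Typical invocation (assuming conf/config.yaml defines the model, data and trainer
# groups used above); any field can be overridden from the command line via Hydra, e.g.:
#   ./train.py trainer.max_epochs=5
# (trainer.max_epochs is only an illustrative override key.)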


if __name__ == "__main__":
    main()