fix: address comments from #136 (#137)
* fix: address comments from #136

* fix: return trainer, regex replace

* fix: move comments inside `ejs` conditions

* fix: replace with \n

* fix: add sampler argument

* fix: drop timer, show more options for saving when checked, separate model file for segmentation

- drop timer handlers
- show the checkpoint-saving options only when the user opts in to saving checkpoints
- put a separate model file in the segmentation template
- simpler conditions for setup_handlers

* fix: put model.py in templates.json

* fix: add missing }
ydcjeff committed May 27, 2021
1 parent 3bb7d50 commit 9682276
Showing 24 changed files with 193 additions and 278 deletions.
3 changes: 0 additions & 3 deletions __tests__/text-classification.spec.js
@@ -76,9 +76,6 @@ test('text classification all', async () => {
await page.check('#terminate_on_nan-checkbox')
expect(await page.isChecked('#terminate_on_nan-checkbox')).toBeTruthy()

await page.check('#timer-checkbox')
expect(await page.isChecked('#timer-checkbox')).toBeTruthy()

await page.fill('#patience-input-number', '2')
expect(await page.$eval('#patience-input-number', (e) => e.value)).toBe('2')

3 changes: 0 additions & 3 deletions __tests__/vision-classification.spec.js
@@ -76,9 +76,6 @@ test('vision classification all', async () => {
await page.check('#terminate_on_nan-checkbox')
expect(await page.isChecked('#terminate_on_nan-checkbox')).toBeTruthy()

await page.check('#timer-checkbox')
expect(await page.isChecked('#timer-checkbox')).toBeTruthy()

await page.fill('#patience-input-number', '2')
expect(await page.$eval('#patience-input-number', (e) => e.value)).toBe('2')

3 changes: 0 additions & 3 deletions __tests__/vision-dcgan.spec.js
@@ -76,9 +76,6 @@ test('vision dcgan all', async () => {
await page.check('#terminate_on_nan-checkbox')
expect(await page.isChecked('#terminate_on_nan-checkbox')).toBeTruthy()

await page.check('#timer-checkbox')
expect(await page.isChecked('#timer-checkbox')).toBeTruthy()

await page.fill('#patience-input-number', '2')
expect(await page.$eval('#patience-input-number', (e) => e.value)).toBe('2')

3 changes: 0 additions & 3 deletions __tests__/vision-segmentation.spec.js
@@ -76,9 +76,6 @@ test('vision segmentation all', async () => {
await page.check('#terminate_on_nan-checkbox')
expect(await page.isChecked('#terminate_on_nan-checkbox')).toBeTruthy()

await page.check('#timer-checkbox')
expect(await page.isChecked('#timer-checkbox')).toBeTruthy()

await page.fill('#patience-input-number', '2')
expect(await page.$eval('#patience-input-number', (e) => e.value)).toBe('2')

8 changes: 5 additions & 3 deletions src/components/TabHandlers.vue
@@ -11,16 +11,19 @@
:saveKey="save_evaluation.name"
/>
<FormInput
v-if="store.config.save_training"
:label="filename_prefix.description"
:saveKey="filename_prefix.name"
:type="filename_prefix.type"
/>
<FormInput
v-if="store.config.save_training"
:label="save_every_iters.description"
:saveKey="save_every_iters.name"
:type="save_every_iters.type"
/>
<FormInput
v-if="store.config.save_training || store.config.save_evaluation"
:label="n_saved.description"
:saveKey="n_saved.name"
:type="n_saved.type"
@@ -30,8 +33,6 @@
:label="terminate_on_nan.description"
:saveKey="terminate_on_nan.name"
/>
<h2>Events Timer</h2>
<FormCheckbox :label="timer.description" :saveKey="timer.name" />
<h2>Early Stopping</h2>
<FormInput
:label="patience.description"
@@ -51,11 +52,12 @@
import { handlers } from '../metadata/metadata.json'
import FormInput from './FormInput.vue'
import FormCheckbox from './FormCheckbox.vue'
import { store } from '../store.js'
export default {
components: { FormInput, FormCheckbox },
setup() {
return { ...handlers }
return { ...handlers, store }
}
}
</script>
2 changes: 1 addition & 1 deletion src/metadata/metadata.json
@@ -47,7 +47,7 @@
"save_training": {
"name": "save_training",
"type": "checkbox",
"description": "Save the training state (models, optimizers, trainers, ...) by every save_every_iters."
"description": "Save the training state (models, optimizers, trainers, ...)."
},
"save_evaluation": {
"name": "save_evaluation",
3 changes: 2 additions & 1 deletion src/store.js
@@ -69,7 +69,8 @@ export function genCode() {
}
store.code[file] = ejs
.render(currentFiles[file], store.config)
.replaceAll(/(\n\n\n\n)+/gi, '\n')
.replaceAll(/\s{4}\n/gi, '\n')
.replaceAll(/(\n{3,})/gi, '\n\n')
}
if (isDev) {
store.code[__DEV_CONFIG_FILE__] = JSON.stringify(store.config, null, 2)
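For illustration, here is a standalone Python sketch (hypothetical; the commit itself uses JavaScript's String.replaceAll) of what the new two-step post-processing does to rendered template code: first drop indentation-only lines that EJS conditionals leave behind, then collapse runs of three or more newlines into a single blank line.

```python
import re

def normalize(rendered: str) -> str:
    # Mirror of .replaceAll(/\s{4}\n/gi, '\n'): remove a 4-whitespace run
    # directly before a newline (indentation-only lines left by EJS tags).
    rendered = re.sub(r"\s{4}\n", "\n", rendered)
    # Mirror of .replaceAll(/(\n{3,})/gi, '\n\n'): collapse 3+ consecutive
    # newlines into exactly one blank line.
    return re.sub(r"\n{3,}", "\n\n", rendered)

print(repr(normalize("def f():\n    \n\n\n\n    return 1\n")))
# -> 'def f():\n\n    return 1\n'
```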
30 changes: 12 additions & 18 deletions src/templates/template-common/main.py
@@ -1,13 +1,7 @@
ckpt_handler_train, ckpt_handler_eval, timer = setup_handlers(
ckpt_handler_train, ckpt_handler_eval = setup_handlers(
trainer, evaluator, config, to_save_train, to_save_eval
)

#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#
if timer is not None:
logger.info("Time per batch: %.4f seconds", timer.value())
timer.reset()
#::: } :::#

#::: if (it.logger) { :::#
if rank == 0:
from ignite.contrib.handlers.wandb_logger import WandBLogger
@@ -20,20 +14,20 @@
exp_logger.close()
#::: } :::#

#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#
if ckpt_handler_train is not None:
logger.info(
"Last training checkpoint name - %s",
ckpt_handler_train.last_checkpoint,
)
#::: if (it.save_training || it.save_evaluation) { :::#
# show last checkpoint names
logger.info(
"Last training checkpoint name - %s",
ckpt_handler_train.last_checkpoint,
)

if ckpt_handler_eval is not None:
logger.info(
"Last evaluation checkpoint name - %s",
ckpt_handler_eval.last_checkpoint,
)
logger.info(
"Last evaluation checkpoint name - %s",
ckpt_handler_eval.last_checkpoint,
)
#::: } :::#


# main entrypoint
def main():
config = setup_parser().parse_args()
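The checkpoint-name logging above relies on ignite's Checkpoint.last_checkpoint attribute. A tiny self-contained sketch of that attribute in action (the no-op engine, toy model, and temp directory are stand-ins for illustration, not part of the template):

```python
import tempfile

import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = torch.nn.Linear(2, 2)
trainer = Engine(lambda engine, batch: None)  # no-op training step

saver = DiskSaver(tempfile.mkdtemp(), require_empty=False)
ckpt_handler_train = Checkpoint({"model": model}, saver, n_saved=2)
trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler_train)

trainer.run([0, 1], max_epochs=2)
# Filename of the most recently saved checkpoint, e.g. "model_2.pt"
print("Last training checkpoint name -", ckpt_handler_train.last_checkpoint)
```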
22 changes: 5 additions & 17 deletions src/templates/template-common/utils.py
@@ -132,9 +132,7 @@ def setup_logging(config: Any) -> Logger:
return logger


#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#


#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.limit_sec) { :::#
def setup_handlers(
trainer: Engine,
evaluator: Engine,
@@ -144,7 +142,7 @@ def setup_handlers(
):
"""Setup Ignite handlers."""

ckpt_handler_train = ckpt_handler_eval = timer = None
ckpt_handler_train = ckpt_handler_eval = None
#::: if (it.save_training || it.save_evaluation) { :::#
# checkpointing
saver = DiskSaver(config.output_dir / "checkpoints", require_empty=False)
@@ -191,25 +189,15 @@ def setup_handlers(
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
#::: } :::#

#::: if (it.timer) { :::#
# timer
timer = Timer(average=True)
timer.attach(
trainer,
start=Events.EPOCH_STARTED,
resume=Events.ITERATION_STARTED,
pause=Events.ITERATION_COMPLETED,
step=Events.ITERATION_COMPLETED,
)
#::: } :::#

#::: if (it.limit_sec) { :::#
# time limit
trainer.add_event_handler(
Events.ITERATION_COMPLETED, TimeLimit(config.limit_sec)
)
#::: } :::#
return ckpt_handler_train, ckpt_handler_eval, timer
#::: if (it.save_training || it.save_evaluation) { :::#
return ckpt_handler_train, ckpt_handler_eval
#::: } :::#


#::: } :::#
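To make the simplified contract concrete, here is a hedged sketch of roughly what this function renders to when only checkpoint saving is enabled: the timer slot is gone, and the pair of Checkpoint handlers is the sole return value (the return statement itself is emitted only inside the save_training/save_evaluation guard). This is an illustration of the shape, not the exact rendered file; the config fields mirror the handler options in metadata.json.

```python
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver


def setup_handlers(trainer, evaluator, config, to_save_train, to_save_eval):
    """Sketch of the rendered output with checkpoint saving enabled."""
    ckpt_handler_train = ckpt_handler_eval = None  # no more `timer` slot

    # checkpointing: one DiskSaver shared by both handlers
    saver = DiskSaver(config.output_dir / "checkpoints", require_empty=False)
    ckpt_handler_train = Checkpoint(
        to_save_train,
        saver,
        filename_prefix=config.filename_prefix,
        n_saved=config.n_saved,
    )
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=config.save_every_iters),
        ckpt_handler_train,
    )
    ckpt_handler_eval = Checkpoint(to_save_eval, saver, n_saved=config.n_saved)
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler_eval)

    # emitted only inside the save_training/save_evaluation EJS guard
    return ckpt_handler_train, ckpt_handler_eval
```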
55 changes: 19 additions & 36 deletions src/templates/template-text-classification/main.py
@@ -12,7 +12,6 @@
from model import TransformerModel
from torch import nn, optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data.distributed import DistributedSampler
from trainers import setup_evaluator, setup_trainer
from utils import *

@@ -72,7 +71,9 @@ def run(local_rank: int, config: Any):
}

# trainer and evaluator
trainer = setup_trainer(config, model, optimizer, loss_fn, device)
trainer = setup_trainer(
config, model, optimizer, loss_fn, device, dataloader_train.sampler
)
evaluator = setup_evaluator(config, model, metrics, device)

# setup engines logger with python logging
@@ -82,14 +83,6 @@
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
trainer.logger = evaluator.logger = logger

# set epoch for distributed sampler
@trainer.on(Events.EPOCH_STARTED)
def set_epoch():
if idist.get_world_size() > 1 and isinstance(
dataloader_train.sampler, DistributedSampler
):
dataloader_train.sampler.set_epoch(trainer.state.epoch - 1)

if isinstance(lr_scheduler, _LRScheduler):
trainer.add_event_handler(
Events.ITERATION_COMPLETED,
@@ -101,8 +94,7 @@ def set_epoch():
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

# setup ignite handlers
#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#

#::: if (it.save_training || it.save_evaluation) { :::#
#::: if (it.save_training) { :::#
to_save_train = {
"model": model,
@@ -113,20 +105,20 @@ def set_epoch():
#::: } else { :::#
to_save_train = None
#::: } :::#

#::: if (it.save_evaluation) { :::#
to_save_eval = {"model": model}
#::: } else { :::#
to_save_eval = None
#::: } :::#

ckpt_handler_train, ckpt_handler_eval, timer = setup_handlers(
ckpt_handler_train, ckpt_handler_eval = setup_handlers(
trainer, evaluator, config, to_save_train, to_save_eval
)
#::: } else if (it.patience || it.terminate_on_nan || it.limit_sec) { :::#
setup_handlers(trainer, evaluator, config)
#::: } :::#

# experiment tracking
#::: if (it.logger) { :::#
# experiment tracking
if rank == 0:
exp_logger = setup_exp_logging(config, trainer, optimizer, evaluator)
#::: } :::#
@@ -147,13 +139,6 @@ def set_epoch():
# for evaluation stats
@trainer.on(Events.EPOCH_COMPLETED(every=1))
def _():
# show timer
#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#
if timer is not None:
logger.info("Time per batch: %.4f seconds", timer.value())
timer.reset()
#::: } :::#

evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
log_metrics(evaluator, "eval")

@@ -169,8 +154,8 @@ def _():
epoch_length=config.train_epoch_length,
)

# close logger
#::: if (it.logger) { :::#
# close logger
if rank == 0:
from ignite.contrib.handlers.wandb_logger import WandBLogger

@@ -182,19 +167,17 @@ def _():
exp_logger.close()
#::: } :::#

# show the last checkpoint filename
#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#
if ckpt_handler_train is not None:
logger.info(
"Last training checkpoint name - %s",
ckpt_handler_train.last_checkpoint,
)
#::: if (it.save_training || it.save_evaluation) { :::#
# show last checkpoint names
logger.info(
"Last training checkpoint name - %s",
ckpt_handler_train.last_checkpoint,
)

if ckpt_handler_eval is not None:
logger.info(
"Last evaluation checkpoint name - %s",
ckpt_handler_eval.last_checkpoint,
)
logger.info(
"Last evaluation checkpoint name - %s",
ckpt_handler_eval.last_checkpoint,
)
#::: } :::#


19 changes: 16 additions & 3 deletions src/templates/template-text-classification/trainers.py
@@ -1,11 +1,13 @@
from typing import Any, Dict, Union

import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine
from ignite.engine import DeterministicEngine, Engine, Events
from ignite.metrics.metric import Metric
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.optimizer import Optimizer
from torch.utils.data import DistributedSampler, Sampler


def setup_trainer(
@@ -14,6 +16,7 @@ def setup_trainer(
optimizer: Optimizer,
loss_fn: nn.Module,
device: Union[str, torch.device],
train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:

scaler = GradScaler(enabled=config.use_amp)
@@ -50,11 +53,21 @@ def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
return metric

#::: if(it.deterministic) { :::#
return DeterministicEngine(train_function)
trainer = DeterministicEngine(train_function)
#::: } else { :::#
return Engine(train_function)
trainer = Engine(train_function)
#::: } :::#

# set epoch for distributed sampler
@trainer.on(Events.EPOCH_STARTED)
def set_epoch():
if idist.get_world_size() > 1 and isinstance(
train_sampler, DistributedSampler
):
train_sampler.set_epoch(trainer.state.epoch - 1)

return trainer


def setup_evaluator(
config: Any,
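The sampler bookkeeping that moved into setup_trainer matters because DistributedSampler reshuffles only when its epoch changes. A toy single-process illustration (not part of the template; num_replicas and rank are passed explicitly so no process group is required):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
epoch0 = list(sampler)
sampler.set_epoch(1)  # what the new EPOCH_STARTED hook does each epoch
epoch1 = list(sampler)

# Different orderings: without set_epoch, every epoch would replay epoch0.
print(epoch0)
print(epoch1)
```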
