Skip to content

Commit

Permalink
Revert "Revert "checkpoint consolidation""
Browse files Browse the repository at this point in the history
This reverts commit 3a9fde9.
  • Loading branch information
shuyingsunshine21 committed Mar 24, 2021
1 parent 3a9fde9 commit 7a369f4
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 39 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/base.py
Expand Up @@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch ends."""
pass

def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
"""Called when at the very end of train epoch."""
pass

def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the training batch begins."""
pass
Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
Expand Up @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module):

self._run_early_stopping_check(trainer)

def on_train_epoch_final_end(self, trainer, pl_module):
from pytorch_lightning.trainer.states import TrainerState
if (
trainer.state != TrainerState.FITTING or trainer.sanity_checking
or not trainer.checkpoint_connector.has_trained
):
return
# if validation is disabled or should skip, we run early stopping
# at end of the training epoch
if (
trainer.disable_validation
or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
):
self._run_early_stopping_check(trainer)

def _run_early_stopping_check(self, trainer):
"""
Checks whether the early stopping condition is met
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/lambda_function.py
Expand Up @@ -53,6 +53,7 @@ def __init__(
on_train_batch_end: Optional[Callable] = None,
on_train_epoch_start: Optional[Callable] = None,
on_train_epoch_end: Optional[Callable] = None,
on_train_epoch_final_end: Optional[Callable] = None,
on_validation_epoch_start: Optional[Callable] = None,
on_validation_epoch_end: Optional[Callable] = None,
on_test_epoch_start: Optional[Callable] = None,
Expand Down Expand Up @@ -155,3 +156,5 @@ def __init__(
self.on_after_backward = on_after_backward
if on_before_zero_grad is not None:
self.on_before_zero_grad = on_before_zero_grad
if on_train_epoch_final_end is not None:
self.on_train_epoch_final_end = on_train_epoch_final_end
31 changes: 31 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Expand Up @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
return
self.save_checkpoint(trainer)

def on_train_epoch_final_end(self, trainer, pl_module):
"""
at the end of each training epoch, checkpoint only when validation is skipped or disabled
"""
print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step))
if (
self._should_skip_saving_checkpoint(trainer)
or not trainer.checkpoint_connector.has_trained
):
return
# if validation is disabled or should skip, we checkpoint at end of the training epoch
if (
trainer.disable_validation
or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
):
self.save_checkpoint(trainer)

def on_train_end(self, trainer, *args, **kwargs) -> None:
"""
checkpoints can be saved at the end of the trianing
"""
trainer.global_step -= 1
if (
not self._should_skip_saving_checkpoint(trainer)
and trainer.checkpoint_connector.has_trained
):
if self.save_last and self.verbose:
rank_zero_info("Saving latest checkpoint...")
self.save_checkpoint(trainer)
trainer.global_step += 1

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
"monitor": self.monitor,
Expand Down
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Expand Up @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]):
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.lightning_module, outputs)

def on_train_epoch_final_end(self) -> None:
"""
Called when at the very end of train epoch.
"""
for callback in self.callbacks:
callback.on_train_epoch_final_end(self, self.lightning_module)

def on_validation_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
Expand Down
Expand Up @@ -100,6 +100,11 @@ def _on_train_epoch_end_log():
"""Called when the epoch ends."""
return {"on_step": [False], "on_epoch": [False, True]}

@staticmethod
def _on_train_epoch_final_end_log():
"""Called when at the very end of train epoch."""
return {"on_step": [False], "on_epoch": [False, True]}

@staticmethod
def _on_validation_epoch_start_log():
"""Called when the epoch begins."""
Expand Down
35 changes: 3 additions & 32 deletions pytorch_lightning/trainer/training_loop.py
Expand Up @@ -121,12 +121,6 @@ def on_train_end(self):
return
self._teardown_already_run = True

# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.trainer.global_step -= 1
self.check_checkpoint_callback(should_update=True, is_last=True)
self.trainer.global_step += 1

# hook
self.trainer.call_hook("on_train_end")

Expand All @@ -145,28 +139,6 @@ def on_train_end(self):
# reset bookkeeping
self.trainer._running_stage = None

def check_checkpoint_callback(self, should_update, is_last=False):
# TODO bake this logic into the ModelCheckpoint callback
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = self.trainer.checkpoint_callbacks

if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.lightning_module

for cb in callbacks:
cb.on_validation_end(self.trainer, model)

def check_early_stopping_callback(self, should_update):
# TODO bake this logic into the EarlyStopping callback
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
model = self.trainer.lightning_module

for cb in callbacks:
cb.on_validation_end(self.trainer, model)

def on_train_epoch_start(self, epoch):

# update training progress in trainer
Expand Down Expand Up @@ -562,15 +534,14 @@ def run_training_epoch(self):
if (val_loop_called and not should_check_val) or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

if should_train_only:
self.check_checkpoint_callback(True)
self.check_early_stopping_callback(True)

if should_check_val:
self.trainer.validating = True
self.trainer.run_evaluation(on_epoch=True)
self.trainer.training = True

if should_train_only:
self.trainer.call_hook('on_train_epoch_final_end')

# increment the global step once
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()
Expand Down
35 changes: 29 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
Expand Up @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int):
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
expected = (
[f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
if period > 0
else []
)
expected.append(final_epoch_ckpt)
assert set(os.listdir(tmpdir)) == set(expected)


Expand All @@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
# check that the correct ckpts were created
final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
expected = (
[f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
if every_n_val_epochs > 0
else []
)
expected.append(final_epoch_ckpt)
assert set(os.listdir(tmpdir)) == set(expected)


Expand All @@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
# check that the correct ckpts were created
final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
expected = (
[f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
if every_n_val_epochs > 0
else []
)
expected.append(final_epoch_ckpt)
assert set(os.listdir(tmpdir)) == set(expected)


Expand Down Expand Up @@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning(
default_root_dir=tmpdir,
callbacks=[ckpt],
max_epochs=max_epochs,
val_check_interval=0.1,
)
with caplog.at_level(logging.INFO):
trainer.fit(model)
assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
if verbose and save_last and not should_validate:
# no validation, hence checkpoint triggered at the end of each training epoch
assert caplog.messages.count('Saving latest checkpoint...') == False
else:
assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/utils.py
Expand Up @@ -76,7 +76,7 @@ def reset_seed(seed=0):
def set_random_master_port():
reset_seed()
port = RANDOM_PORTS.pop()
os.environ['MASTER_PORT'] = str(port)
os.environ['MASTER_PORT'] = "29501"


def init_checkpoint_callback(logger):
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/logging_/test_logger_connector.py
Expand Up @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir):
'on_train_batch_start',
'on_train_end',
'on_train_epoch_end',
'on_train_epoch_final_end',
'on_train_epoch_start',
'on_train_start',
'on_validation_batch_end',
Expand Down

0 comments on commit 7a369f4

Please sign in to comment.