Add "only_best_status" option to Persist (#233)
* add only_best_status option

* fix test case

* only save checkpoints whose names end with _1

* improve doc

* fix typo
TsumiNa committed Mar 9, 2021
1 parent d504a1f commit 7b48997
Showing 2 changed files with 92 additions and 12 deletions.
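Before the file diffs, a minimal usage sketch of the new flag. The constructor arguments mirror the tests below; the model names here are illustrative, not part of the commit:

```python
from xenonpy.model.training.extension import Persist

# Default behavior: every checkpoint is saved under its own id,
# e.g. checkpoints/cp_1.pth.s and checkpoints/cp_2.pth.s.
persist_all = Persist('my_model', increment=False, only_best_states=False)

# New in this commit: only checkpoints whose id ends with '_1' (the best
# state per criterion) are kept, stored under the id's leading part,
# e.g. checkpoints/cp.pth.s; all other checkpoints are skipped.
persist_best = Persist('my_model_best', increment=False, only_best_states=True)
```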
83 changes: 76 additions & 7 deletions tests/models/test_extension.py
@@ -15,6 +15,7 @@
import os

from xenonpy.model import SequentialLinear
from xenonpy.model.training import Trainer
from xenonpy.model.training.base import BaseExtension, BaseRunner
from xenonpy.model.training.extension import TensorConverter, Validator, Persist
from xenonpy.model.utils import regression_metrics, classification_metrics
@@ -32,9 +33,30 @@ def data():

    yield

-    rmtree(str(Path('.').resolve() / 'test_model'))
-    rmtree(str(Path('.').resolve() / 'test_model@1'))
-    rmtree(str(Path('.').resolve() / Path(os.getcwd()).name))
+    try:
+        rmtree(str(Path('.').resolve() / 'test_model'))
+    except:
+        pass
+    try:
+        rmtree(str(Path('.').resolve() / 'test_model@1'))
+    except:
+        pass
+    try:
+        rmtree(str(Path('.').resolve() / 'test_model_1'))
+    except:
+        pass
+    try:
+        rmtree(str(Path('.').resolve() / 'test_model_2'))
+    except:
+        pass
+    try:
+        rmtree(str(Path('.').resolve() / 'test_model_3'))
+    except:
+        pass
+    try:
+        rmtree(str(Path('.').resolve() / Path(os.getcwd()).name))
+    except:
+        pass

    print('test over')
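As an aside, a behavior-equivalent teardown can be written more compactly with ``shutil.rmtree``'s ``ignore_errors`` flag; a sketch, not part of the commit:

```python
import os
from pathlib import Path
from shutil import rmtree

# ignore_errors=True swallows the same failures the bare `except` blocks
# above do, e.g. directories a test run never created.
for name in ('test_model', 'test_model@1', 'test_model_1',
             'test_model_2', 'test_model_3', Path(os.getcwd()).name):
    rmtree(str(Path('.').resolve() / name), ignore_errors=True)
```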

@@ -293,8 +315,8 @@ def predict(self, x_, y_):
    val.step_forward(trainer=_Trainer(), step_info=step_info) # noqa
    assert step_info['val_mae'] == regression_metrics(y, x)['mae']
    assert set(step_info.keys()) == {
-        'i_epoch', 'val_mae', 'val_mse', 'val_rmse', 'val_r2', 'val_pearsonr', 'val_spearmanr', 'val_p_value',
-        'val_max_ae', 'train_loss'
+        'i_epoch', 'val_mae', 'val_mse', 'val_rmse', 'val_r2', 'val_pearsonr', 'val_spearmanr',
+        'val_p_value', 'val_max_ae', 'train_loss'
    }


@@ -325,8 +347,8 @@ def predict(self, x_, y_): # noqa
    val.step_forward(trainer=_Trainer(), step_info=step_info) # noqa
    assert step_info['val_f1'] == classification_metrics(y, x)['f1']
    assert set(step_info.keys()) == {
-        'i_epoch', 'val_accuracy', 'val_f1', 'val_precision', 'val_recall', 'val_macro_f1', 'val_macro_precision',
-        'val_macro_recall', 'train_loss'
+        'i_epoch', 'val_accuracy', 'val_f1', 'val_precision', 'val_recall', 'val_macro_f1',
+        'val_macro_precision', 'val_macro_recall', 'train_loss'
    }


@@ -368,5 +390,52 @@ def predict(self, x_, y_): # noqa
    assert (Path('.').resolve() / 'test_model@1' / 'model_structure.pkl.z').exists()


def test_persist_save_checkpoints(data):

    class _Trainer(BaseRunner):

        def __init__(self):
            super().__init__()
            self.model = SequentialLinear(50, 2)

        def predict(self, x_, y_): # noqa
            return x_, y_

    cp_1 = Trainer.checkpoint_tuple(
        id='cp_1',
        iterations=111,
        model_state=SequentialLinear(50, 2).state_dict(),
    )
    cp_2 = Trainer.checkpoint_tuple(
        id='cp_2',
        iterations=111,
        model_state=SequentialLinear(50, 2).state_dict(),
    )

    # save checkpoint
    p = Persist('test_model_1', increment=False, only_best_states=False)
    p.before_proc(trainer=_Trainer())
    p.on_checkpoint(cp_1, trainer=_Trainer())
    p.on_checkpoint(cp_2, trainer=_Trainer())
    assert (Path('.').resolve() / 'test_model_1' / 'checkpoints' / 'cp_1.pth.s').exists()
    assert (Path('.').resolve() / 'test_model_1' / 'checkpoints' / 'cp_2.pth.s').exists()

    # reduced save checkpoint
    p = Persist('test_model_2', increment=False, only_best_states=True)
    p.before_proc(trainer=_Trainer())
    p.on_checkpoint(cp_1, trainer=_Trainer())
    p.on_checkpoint(cp_2, trainer=_Trainer())
    assert (Path('.').resolve() / 'test_model_2' / 'checkpoints' / 'cp.pth.s').exists()
    assert not (Path('.').resolve() / 'test_model_2' / 'checkpoints' / 'cp_1.pth.s').exists()
    assert not (Path('.').resolve() / 'test_model_2' / 'checkpoints' / 'cp_2.pth.s').exists()

    # no checkpoint will be saved
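    # ('cp_2' does not end with '_1', so only_best_states skips it)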
    p = Persist('test_model_3', increment=False, only_best_states=True)
    p.before_proc(trainer=_Trainer())
    p.on_checkpoint(cp_2, trainer=_Trainer())
    assert not (Path('.').resolve() / 'test_model_3' / 'checkpoints' / 'cp.pth.s').exists()
    assert not (Path('.').resolve() / 'test_model_3' / 'checkpoints' / 'cp_2.pth.s').exists()


if __name__ == "__main__":
    pytest.main()
21 changes: 16 additions & 5 deletions xenonpy/model/training/extension/persist.py
@@ -30,8 +30,9 @@ def __init__(self,
                 *,
                 model_class: Callable = None,
                 model_params: Union[tuple, dict, any] = None,
-                 increment=False,
-                 sync_training_step=False,
+                 increment: bool = False,
+                 sync_training_step: bool = False,
+                 only_best_states: bool = False,
                 **describe: Any):
        """
@@ -51,13 +52,16 @@
        sync_training_step
            If ``True``, will save ``trainer.training_info`` at each iteration.
            Default is ``False``, only save ``trainer.training_info`` at each epoch.
+        only_best_states
+            If ``True``, only save the checkpoints carrying the best state for each
+            criterion, i.e. those whose id ends with ``_1``. Default is ``False``.
        describe:
            Any other information to describe this model.
            This information will be saved under the model dir as ``describe.pkl.z``.
        """
        self._model_class: Callable = model_class
        self._model_params: Union[list, dict] = model_params
        self.sync_training_step = sync_training_step
+        self.only_best_states = only_best_states
        self._increment = increment
        self._describe = describe
        self._describe_ = None
@@ -103,9 +107,16 @@ def __getitem__(self, item):
        return self._checker[item]

    def on_checkpoint(self, checkpoint: Trainer.checkpoint_tuple, trainer: Trainer) -> None:
-        key = checkpoint.id
-        value = deepcopy(checkpoint._asdict())
-        self._checker.set_checkpoint(**{key: value})
+        if self.only_best_states:
+            tmp = checkpoint.id.split('_')
+            if tmp[-1] == '1':
+                key = tmp[0]
+                value = deepcopy(checkpoint._asdict())
+                self._checker.set_checkpoint(**{key: value})
+        else:
+            key = checkpoint.id
+            value = deepcopy(checkpoint._asdict())
+            self._checker.set_checkpoint(**{key: value})

    def step_forward(self, step_info: OrderedDict, trainer: Trainer) -> None:
        if self.sync_training_step:
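The branch in ``on_checkpoint`` encodes a naming convention: checkpoint ids are assumed to take the form ``<name>_<rank>``, with rank ``1`` marking the best state for that name, and the best checkpoint is stored under the leading part of the id (``cp_1`` becomes ``cp``), so later bests overwrite the same file. A standalone sketch of that key selection; the helper name is illustrative:

```python
from typing import Optional

def checkpoint_key(checkpoint_id: str, only_best_states: bool) -> Optional[str]:
    """Mirror Persist.on_checkpoint's key selection; None means 'skip'."""
    if not only_best_states:
        return checkpoint_id      # save everything under its own id
    parts = checkpoint_id.split('_')
    if parts[-1] == '1':          # rank 1 == best state for this name
        return parts[0]           # 'cp_1' -> 'cp', so later bests overwrite
    return None                   # non-best checkpoints are dropped

assert checkpoint_key('cp_1', True) == 'cp'
assert checkpoint_key('cp_2', True) is None
assert checkpoint_key('cp_2', False) == 'cp_2'
```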
