Skip to content

Commit

Permalink
Allow Persist to skip saving the connected model and checkpoints (#241)
Browse files Browse the repository at this point in the history
* rebase to master

* fix test case
  • Loading branch information
TsumiNa committed Oct 3, 2021
1 parent ad04f74 commit 245f18e
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 106 deletions.
29 changes: 14 additions & 15 deletions tests/models/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,60 +229,60 @@ def test_tensor_converter_3():
tensor_ = torch.from_numpy(np_)

converter = TensorConverter()
y, y_ = converter.output_proc(tensor_, None, training=True)
y, y_ = converter.output_proc(tensor_, None, is_training=True)
assert y_ is None
assert isinstance(y, torch.Tensor)
assert y.shape == (2, 3)
assert torch.equal(y, tensor_)

y, y_ = converter.output_proc(tensor_, tensor_, training=True)
y, y_ = converter.output_proc(tensor_, tensor_, is_training=True)
assert isinstance(y, torch.Tensor)
assert isinstance(y_, torch.Tensor)
assert y.equal(y_)
assert y.shape == (2, 3)
assert torch.equal(y, tensor_)

y, _ = converter.output_proc((tensor_,), None, training=True)
y, _ = converter.output_proc((tensor_,), None, is_training=True)
assert isinstance(y, tuple)
assert isinstance(y[0], torch.Tensor)
assert torch.equal(y[0], tensor_)

y, y_ = converter.output_proc(tensor_, tensor_, training=False)
y, y_ = converter.output_proc(tensor_, tensor_, is_training=False)
assert isinstance(y, np.ndarray)
assert isinstance(y_, np.ndarray)
assert np.all(y == y_)
assert y.shape == (2, 3)
assert np.all(y == tensor_.numpy())

y, _ = converter.output_proc((tensor_,), None, training=False)
y, _ = converter.output_proc((tensor_,), None, is_training=False)
assert isinstance(y, tuple)
assert isinstance(y[0], np.ndarray)
assert np.all(y[0] == tensor_.numpy())

converter = TensorConverter(argmax=True)
y, y_ = converter.output_proc(tensor_, tensor_, training=False)
y, y_ = converter.output_proc(tensor_, tensor_, is_training=False)
assert isinstance(y, np.ndarray)
assert isinstance(y_, np.ndarray)
assert y.shape == (2,)
assert y_.shape == (2, 3)
assert np.all(y == np.argmax(np_, 1))

y, y_ = converter.output_proc((tensor_, tensor_), None, training=False)
y, y_ = converter.output_proc((tensor_, tensor_), None, is_training=False)
assert isinstance(y, tuple)
assert y_ is None
assert y[0].shape == (2,)
assert y[0].shape == y[1].shape
assert np.all(y[0] == np.argmax(np_, 1))

converter = TensorConverter(probability=True)
y, y_ = converter.output_proc(tensor_, tensor_, training=False)
y, y_ = converter.output_proc(tensor_, tensor_, is_training=False)
assert isinstance(y, np.ndarray)
assert isinstance(y_, np.ndarray)
assert y.shape == (2, 3)
assert y_.shape == (2, 3)
assert np.all(y == softmax(np_, 1))

y, y_ = converter.output_proc((tensor_, tensor_), None, training=False)
y, y_ = converter.output_proc((tensor_, tensor_), None, is_training=False)
assert isinstance(y, tuple)
assert y_ is None
assert y[0].shape == (2, 3)
Expand Down Expand Up @@ -345,7 +345,6 @@ def predict(self, x_, y_): # noqa

step_info = OrderedDict(train_loss=0, i_epoch=1)
val.step_forward(trainer=_Trainer(), step_info=step_info) # noqa
print(step_info)
assert step_info['val_accuracy'] == classification_metrics(y, x)['accuracy']
assert set(step_info.keys()) == {
'i_epoch', 'val_accuracy', 'val_f1', 'val_precision', 'val_recall', 'val_macro_f1', 'val_macro_precision',
Expand Down Expand Up @@ -417,24 +416,24 @@ def predict(self, x_, y_): # noqa
# save checkpoint
p = Persist('test_model_1', increment=False, only_best_states=False)
p.before_proc(trainer=_Trainer())
p.on_checkpoint(cp_1, trainer=_Trainer())
p.on_checkpoint(cp_2, trainer=_Trainer())
p.on_checkpoint(cp_1)
p.on_checkpoint(cp_2)
assert (Path('.').resolve() / 'test_model_1' / 'checkpoints' / 'cp_1.pth.s').exists()
assert (Path('.').resolve() / 'test_model_1' / 'checkpoints' / 'cp_2.pth.s').exists()

# reduced save checkpoint
p = Persist('test_model_2', increment=False, only_best_states=True)
p.before_proc(trainer=_Trainer())
p.on_checkpoint(cp_1, trainer=_Trainer())
p.on_checkpoint(cp_2, trainer=_Trainer())
p.on_checkpoint(cp_1)
p.on_checkpoint(cp_2)
assert (Path('.').resolve() / 'test_model_2' / 'checkpoints' / 'cp.pth.s').exists()
assert not (Path('.').resolve() / 'test_model_2' / 'checkpoints' / 'cp_1.pth.s').exists()
assert not (Path('.').resolve() / 'test_model_2' / 'checkpoints' / 'cp_2.pth.s').exists()

# no checkpoint will be saved
p = Persist('test_model_3', increment=False, only_best_states=True)
p.before_proc(trainer=_Trainer())
p.on_checkpoint(cp_2, trainer=_Trainer())
p.on_checkpoint(cp_2)
assert not (Path('.').resolve() / 'test_model_3' / 'checkpoints' / 'cp.pth.s').exists()
assert not (Path('.').resolve() / 'test_model_3' / 'checkpoints' / 'cp_2.pth.s').exists()

Expand Down
45 changes: 20 additions & 25 deletions xenonpy/model/training/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
# license that can be found in the LICENSE file.

from inspect import signature
from typing import Iterable
from typing import Tuple, Any
from typing import Union, Dict
from typing import NamedTuple, Tuple, Any, OrderedDict, Union, Dict, Iterable

import torch
from sklearn.base import BaseEstimator
Expand All @@ -18,29 +16,34 @@


class BaseExtension(object):
def before_proc(self, *dependence) -> None:

def before_proc(self, trainer: 'BaseRunner' = None, is_training: bool = True, *_dependence: 'BaseExtension') -> None:
pass

def input_proc(self, x_in, y_in, *dependence) -> Tuple[Any, Any]:
def input_proc(self, x_in, y_in, *_dependence: 'BaseExtension') -> Tuple[Any, Any]:
return x_in, y_in

def step_forward(self, *dependence) -> None:
def step_forward(self, step_info: OrderedDict[Any, int], trainer: 'BaseRunner' = None, is_training: bool = True,
*_dependence: 'BaseExtension') -> None:
pass

def output_proc(self, y_pred, y_true, *dependence) -> Tuple[Any, Any]:
def output_proc(self, y_pred, y_true, trainer: 'BaseRunner' = None, is_training: bool = True,
*_dependence: 'BaseExtension') -> Tuple[Any, Any]:
return y_pred, y_true

def after_proc(self, *dependence) -> None:
def after_proc(self, trainer: 'BaseRunner' = None, is_training: bool = True, *_dependence: 'BaseExtension') -> None:
pass

def on_reset(self, *dependence) -> None:
def on_reset(self, trainer: 'BaseRunner' = None, is_training: bool = True, *_dependence: 'BaseExtension') -> None:
pass

def on_checkpoint(self, *dependence) -> None:
def on_checkpoint(self, checkpoint: NamedTuple, trainer: 'BaseRunner' = None, is_training: bool = True,
*_dependence: 'BaseExtension') -> None:
pass


class BaseOptimizer(object):

def __init__(self, optimizer, **kwargs):
self._kwargs = kwargs
self._optimizer = optimizer
Expand All @@ -62,6 +65,7 @@ def __call__(self, params: Iterable) -> Optimizer:


class BaseLRScheduler(object):

def __init__(self, lr_scheduler, **kwargs):
self._kwargs = kwargs
self._lr_scheduler = lr_scheduler
Expand Down Expand Up @@ -114,8 +118,7 @@ def check_device(cuda: Union[bool, str, torch.device]) -> torch.device:
else:
raise RuntimeError(
'wrong device identifier'
'see also: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device'
)
'see also: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device')

if isinstance(cuda, torch.device):
return cuda
Expand All @@ -129,10 +132,7 @@ def device(self, v):
self._device = self.check_device(v)

def _make_inject(self, injects, kwargs):
_kwargs = {
k: self._extensions[k][0]
for k in injects if k in self._extensions
}
_kwargs = {k: self._extensions[k][0] for k in injects if k in self._extensions}
_kwargs.update({k: kwargs[k] for k in injects if k in kwargs})
return _kwargs

Expand Down Expand Up @@ -183,23 +183,18 @@ def extend(self, *extension: BaseExtension) -> 'BaseRunner':
Extension.
"""

def _get_keyword_params(func) -> list:
sig = signature(func)
return [
p.name for p in sig.parameters.values()
if p.kind == p.POSITIONAL_OR_KEYWORD
]
return [p.name for p in sig.parameters.values() if p.kind == p.POSITIONAL_OR_KEYWORD]

# merge exts to named_exts
for ext in extension:
name = camel_to_snake(ext.__class__.__name__)
methods = [
'before_proc', 'input_proc', 'step_forward', 'output_proc',
'after_proc', 'on_reset', 'on_checkpoint'
]
dependencies = [
_get_keyword_params(getattr(ext, m)) for m in methods
'before_proc', 'input_proc', 'step_forward', 'output_proc', 'after_proc', 'on_reset', 'on_checkpoint'
]
dependencies = [_get_keyword_params(getattr(ext, m)) for m in methods]
dependency_inject = {k: v for k, v in zip(methods, dependencies)}

self._extensions[name] = (ext, dependency_inject)
Expand Down
12 changes: 3 additions & 9 deletions xenonpy/model/training/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def __init__(
self._make_file_index()

@classmethod
@deprecated(
'This method is rotten and will be removed in v1.0.0, use `Checker(<model path>)` instead')
@deprecated('This method is rotten and will be removed in v1.0.0, use `Checker(<model path>)` instead')
def load(cls, model_path):
return cls(model_path)

Expand All @@ -91,7 +90,6 @@ def model_name(self):
@property
def model_structure(self):
structure = self['model_structure']
print(structure)
return structure

@property
Expand Down Expand Up @@ -154,8 +152,7 @@ def model(self, model: Module):
raise TypeError(f'except `torch.nn.Module` object but got {type(model)}')

@property
@deprecated('This property is rotten and will be removed in v1.0.0, use `checker.model` instead'
)
@deprecated('This property is rotten and will be removed in v1.0.0, use `checker.model` instead')
def trained_model(self):
if (self._path / 'trained_model.@1.pkl.z').exists():
return torch.load(str(self._path / 'trained_model.@1.pkl.z'), map_location=self._device)
Expand Down Expand Up @@ -211,10 +208,7 @@ def final_state(self, state: OrderedDict):

def _make_file_index(self):

for f in [
f for f in self._path.iterdir()
if f.match('*.pkl.*') or f.match('*.pd.*') or f.match('*.pth.*')
]:
for f in [f for f in self._path.iterdir() if f.match('*.pkl.*') or f.match('*.pd.*') or f.match('*.pth.*')]:
# select data
fn = '.'.join(f.name.split('.')[:-2])
self._files[fn] = str(f)
Expand Down
34 changes: 26 additions & 8 deletions xenonpy/model/training/extension/persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

from collections import OrderedDict
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path
from platform import version as sys_ver
from sys import version as py_ver
from typing import Union, Callable, Any
from typing import Union, Callable, Any, OrderedDict

import numpy as np
import torch

from xenonpy import __version__
from xenonpy.model.training import Trainer, Checker
from xenonpy.model.training.base import BaseExtension
from xenonpy.model.training.base import BaseExtension, BaseRunner

__all__ = ['Persist']

Expand All @@ -33,6 +32,7 @@ def __init__(self,
increment: bool = False,
sync_training_step: bool = False,
only_best_states: bool = False,
no_model_saving: bool = False,
**describe: Any):
"""
Expand All @@ -54,6 +54,8 @@ def __init__(self,
Default is ``False``, only save ``trainer.training_info`` at each epoch.
only_best_states
If ``True``, will only save the models with the best states in terms of each of the criteria.
no_model_saving
If ``True``, the connected model (``trainer.model``) and checkpoints will not be saved.
Default is ``False``.
describe:
Any other information to describe this model.
These information will be saved under model dir by name ``describe.pkl.z``.
Expand All @@ -69,6 +71,7 @@ def __init__(self,
self._tmp_args: list = []
self._tmp_kwargs: dict = {}
self._epoch_count = 0
self._no_model_saving = no_model_saving
self.path = path

@property
Expand All @@ -77,6 +80,10 @@ def describe(self):
raise ValueError('can not access property `describe` before training')
return self._checker.describe

@property
def no_model_saving(self):
return self._no_model_saving

@property
def path(self):
if self._checker is None:
Expand Down Expand Up @@ -106,7 +113,13 @@ def __call__(self, handle: Any = None, **kwargs: Any):
def __getitem__(self, item):
return self._checker[item]

def on_checkpoint(self, checkpoint: Trainer.checkpoint_tuple, trainer: Trainer) -> None:
def on_checkpoint(self,
checkpoint: Trainer.checkpoint_tuple,
_trainer: BaseRunner = None,
_is_training: bool = True,
*_dependence: 'BaseExtension') -> None:
if self._no_model_saving:
return None
if self.only_best_states:
tmp = checkpoint.id.split('_')
if tmp[-1] == '1':
Expand All @@ -118,7 +131,11 @@ def on_checkpoint(self, checkpoint: Trainer.checkpoint_tuple, trainer: Trainer)
value = deepcopy(checkpoint._asdict())
self._checker.set_checkpoint(**{key: value})

def step_forward(self, step_info: OrderedDict, trainer: Trainer) -> None:
def step_forward(self,
step_info: OrderedDict[Any, int],
trainer: Trainer = None,
_is_training: bool = True,
*_dependence: BaseExtension) -> None:
if self.sync_training_step:
training_info = trainer.training_info
if training_info is not None:
Expand All @@ -131,13 +148,14 @@ def step_forward(self, step_info: OrderedDict, trainer: Trainer) -> None:
self._epoch_count = epoch
self._checker(training_info=training_info)

def before_proc(self, trainer: Trainer) -> None:
def before_proc(self, trainer: Trainer = None, _is_training: bool = True, *_dependence: BaseExtension) -> None:
self._checker = Checker(self._path, increment=self._increment)
if self._model_class is not None:
self._checker(model_class=self._model_class)
if self._model_params is not None:
self._checker(model_params=self._model_params)
self._checker.model = trainer.model
if not self._no_model_saving:
self._checker.model = trainer.model
self._describe_ = dict(
python=py_ver,
system=sys_ver(),
Expand All @@ -152,7 +170,7 @@ def before_proc(self, trainer: Trainer) -> None:
)
self._checker(describe=self._describe_)

def after_proc(self, trainer: Trainer) -> None:
def after_proc(self, trainer: Trainer = None, _is_training: bool = True, *_dependence: 'BaseExtension') -> None:
self._describe_.update(finish=datetime.now().strftime('%Y/%m/%d %H:%M:%S'),
time_elapsed=str(timedelta(seconds=trainer.timer.elapsed)))
self._checker.final_state = trainer.model.state_dict()
Expand Down

0 comments on commit 245f18e

Please sign in to comment.