Skip to content

Commit

Permalink
Merge pull request #5 from fmigneault/fix-version-compare
Browse files Browse the repository at this point in the history
fix version compare
  • Loading branch information
plstcharles committed Apr 25, 2020
2 parents 8075a55 + 471e1c4 commit 804d890
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 18 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ Changelog
`Unreleased <http://github.com/plstcharles/thelper/tree/master>`_ (latest)
----------------------------------------------------------------------------------

* Fix version comparison check when validating configuration and/or checkpoint against package version.
Version can now have a release part which was not considered.

`0.5.0-rc <http://github.com/plstcharles/thelper/tree/0.5.0-rc>`_ (%Y/%m/%d)
----------------------------------------------------------------------------------

* Update this changelog to use rst links (renders on github and readthedocs)
* Add ``infer`` mode for classification of geo-referenced rasters
* Add ``Dockerfile-geo`` to build thelper with pre-installed geo packages
* Add geo-related build instructions to travis-ci build steps
* Add geo-related build instructions to travis-ci build steps
* Add auto-documentation of makefile targets and docker related targets

`0.4.7 <http://github.com/plstcharles/thelper/tree/0.4.7>`_ (2019/11/20)
Expand Down
21 changes: 21 additions & 0 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,24 @@ def test_load_checkpoint_not_default_cpu_with_devices(mock_torch_load, mock_thel
checkpoint_file = 'dummy-checkpoint.pth'
thelper.utils.load_checkpoint(checkpoint_file)
mock_torch_load.assert_called_once_with(checkpoint_file, map_location=None)


def test_check_version_correct_parsing_and_not_future():
versions_tests = [
# not-future, check, required, expected check parts, expected required parts
("0.1.0", "0.2.0", True, [0, 1, 0, ''], [0, 2, 0]),
("0.2.0", "0.2.1", True, [0, 2, 0, ''], [0, 2, 1]),
("0.3.1", "0.2.2", False, [0, 3, 1, ''], [0, 2, 2]),
("0.4.8", "0.5.0-rc", True, [0, 4, 8, ''], [0, 5, 0, 'rc']),
("0.5.0a0", "0.5.0-rc", True, [0, 5, 0, 'a0'], [0, 5, 0, 'rc']) # invalid parsing of '-' switches result around
]
for i, ver_test in enumerate(versions_tests):
ver_check, ver_req, exp_ver_ok, exp_ver_check, exp_ver_req = ver_test
res_ver_ok, res_ver_check, res_ver_req = thelper.utils.check_version(ver_check, ver_req)
assert res_ver_ok == exp_ver_ok, f"supposed to get {exp_ver_ok} for [{ver_check}] and [{ver_req}] (test: {i})"
assert len(res_ver_check) == 4, "missing parts in parsing result"
assert len(res_ver_req) == 4, "missing parts in parsing result"
assert all(list(rv == ev for rv, ev in zip(res_ver_check, exp_ver_check))), \
f"check version parsing mismatches the expected result ({res_ver_check} != {exp_ver_check}) (test: {i})"
assert all(list(rv == ev for rv, ev in zip(res_ver_req, exp_ver_req))), \
f"required version parsing mismatches the expected result ({res_ver_req} != {exp_ver_req}) (test: {i})"
63 changes: 46 additions & 17 deletions thelper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import re
import sys
import time
from distutils.version import LooseVersion
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -30,7 +31,7 @@
import thelper.typedefs # noqa: F401

if TYPE_CHECKING:
from typing import Any, AnyStr, Callable, Dict, List, Optional, Type, Union # noqa: F401
from typing import Any, AnyStr, Callable, Dict, List, Optional, Tuple, Type, Union # noqa: F401
from types import FunctionType # noqa: F401

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -276,6 +277,40 @@ def load_checkpoint(ckpt, # type: thelper.typedefs.Checkpoi
return ckptdata


def check_version(version_check, version_required):
# type: (AnyStr, AnyStr) -> Tuple[bool, List[Union[int, AnyStr]], List[Union[int, AnyStr]]]
"""Verifies that the checked version is not greater than the required one (ie: not a future version).
Version format is ``MAJOR[.MINOR[.PATCH[[-]<RELEASE>]]]``.
Note that for ``RELEASE`` part, comparison depends on alphabetical order if all other previous parts were equal
(i.e.: ``alpha`` will be lower than ``beta``, which in turn is lower than ``rc`` and so on). The ``-`` is optional
and will be removed for comparison (i.e.: ``0.5.0-rc`` is exactly the same as ``0.5.0rc`` and the additional ``-``
will not result in evaluating ``0.5.0a0`` as a greater version because of ``-`` being lower ascii than ``a``).
Args:
version_check: the version string that needs to be verified and compared for lower than the required version.
version_required: the control version against which the check is done.
Returns:
Tuple of the validated check, and lists of both parsed version parts as ``[MAJOR, MINOR, PATCH, 'RELEASE']``.
The returned lists are *guaranteed* to be formed of 4 elements, adding 0 or '' as applicable for missing parts.
"""
v_check = LooseVersion(version_check)
v_req = LooseVersion(version_required)
l_check = [0, 0, 0, '']
l_req = [0, 0, 0, '']
for ver_list, ver_parse in [(l_check, v_check), (l_req, v_req)]:
for v in [0, 1, 2]:
ver_list[v] = 0 if len(ver_parse.version) < v + 1 else ver_parse.version[v]
if len(ver_parse.version) >= 4:
release_idx = 4 if len(ver_parse.version) >= 5 and ver_parse.version[3] == '-' else 3
ver_list[3] = ''.join(str(v) for v in ver_parse.version[release_idx:])
# check with re-parsed version after fixing release dash
v_ok = LooseVersion('.'.join(str(v) for v in l_check)) <= LooseVersion('.'.join(str(v) for v in l_req))
return v_ok, l_check, l_req


def migrate_checkpoint(ckptdata, # type: thelper.typedefs.CheckpointContentType
): # type: (...) -> thelper.typedefs.CheckpointContentType
"""Migrates the content of an incompatible or outdated checkpoint to the current version of the framework.
Expand All @@ -292,18 +327,16 @@ def migrate_checkpoint(ckptdata, # type: thelper.typedefs.CheckpointContentType
"""
if not isinstance(ckptdata, dict):
raise AssertionError("unexpected ckptdata type")
from thelper import __version__ as curr_ver
curr_ver = [int(num) for num in curr_ver.split(".")]
from thelper import __version__ as curr_ver_str
ckpt_ver_str = ckptdata["version"] if "version" in ckptdata else "0.0.0"
ckpt_ver = [int(num) for num in ckpt_ver_str.split(".")]
if (ckpt_ver[0] > curr_ver[0] or (ckpt_ver[0] == curr_ver[0] and ckpt_ver[1] > curr_ver[1]) or
(ckpt_ver[0:2] == curr_ver[0:2] and ckpt_ver[2] > curr_ver[2])):
raise AssertionError("cannot migrate checkpoints from future versions!")
ok_ver, ckpt_ver, curr_ver = check_version(ckpt_ver_str, curr_ver_str)
if not ok_ver:
raise AssertionError("cannot migrate checkpoints from future versions! You need to update your thelper package")
if "config" not in ckptdata:
raise AssertionError("checkpoint migration requires config")
old_config = ckptdata["config"]
new_config = migrate_config(copy.deepcopy(old_config), ckpt_ver_str)
if ckpt_ver == [0, 0, 0]:
if ckpt_ver[:3] == [0, 0, 0]:
logger.warning("trying to migrate checkpoint data from v0.0.0; all bets are off")
else:
logger.info("trying to migrate checkpoint data from v%s" % ckpt_ver_str)
Expand Down Expand Up @@ -359,15 +392,11 @@ def migrate_config(config, # type: thelper.typedefs.ConfigDict
"""
if not isinstance(config, dict):
raise AssertionError("unexpected config type")
if not isinstance(cfg_ver_str, str) or len(cfg_ver_str.split(".")) != 3:
raise AssertionError("unexpected checkpoint version formatting")
from thelper import __version__ as curr_ver
curr_ver = [int(num) for num in curr_ver.split(".")]
cfg_ver = [int(num) for num in cfg_ver_str.split(".")]
if (cfg_ver[0] > curr_ver[0] or (cfg_ver[0] == curr_ver[0] and cfg_ver[1] > curr_ver[1]) or
(cfg_ver[0:2] == curr_ver[0:2] and cfg_ver[2] > curr_ver[2])):
raise AssertionError("cannot migrate configs from future versions!")
if cfg_ver == [0, 0, 0]:
from thelper import __version__ as curr_ver_str
ok_ver, cfg_ver, curr_ver = check_version(cfg_ver_str, curr_ver_str)
if not ok_ver:
raise AssertionError("cannot migrate checkpoints from future versions! You need to update your thelper package")
if cfg_ver[:3] == [0, 0, 0]:
logger.warning("trying to migrate config from v0.0.0; all bets are off")
else:
logger.info("trying to migrate config from v%s" % cfg_ver_str)
Expand Down

0 comments on commit 804d890

Please sign in to comment.