Skip to content

Commit

Permalink
Add _VersionManager class.
Browse files Browse the repository at this point in the history
  • Loading branch information
sile committed Mar 29, 2019
1 parent 0aef9b4 commit 187ff9c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 31 deletions.
90 changes: 62 additions & 28 deletions optuna/storages/rdb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import six
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine # NOQA
from sqlalchemy.exc import IntegrityError
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import orm
Expand Down Expand Up @@ -47,7 +48,6 @@ def __init__(self, url, connect_args=None, skip_compatibility_check=False):
connect_args = connect_args or {}

url = self._fill_storage_url_template(url)
self.url = url

try:
self.engine = create_engine(url, connect_args=connect_args)
Expand All @@ -61,10 +61,9 @@ def __init__(self, url, connect_args=None, skip_compatibility_check=False):

self.logger = optuna.logging.get_logger(__name__)

self._init_version_info_model()
self._init_alembic()
self._version_manager = _VersionManager(url, self.engine, self.scoped_session)
if not skip_compatibility_check:
self._check_table_schema_compatibility()
self._version_manager._check_table_schema_compatibility()

def create_new_study_id(self, study_name=None):
# type: (Optional[str]) -> int
Expand Down Expand Up @@ -539,21 +538,24 @@ def _fill_storage_url_template(template):

return template.format(SCHEMA_VERSION=models.SCHEMA_VERSION)

def _commit_with_integrity_check(self, session):
@staticmethod
def _commit_with_integrity_check(session):
# type: (orm.Session) -> bool

try:
session.commit()
except IntegrityError as e:
self.logger.debug(
logger = optuna.logging.get_logger(__name__)
logger.debug(
'Ignoring {}. This happens due to a timing issue among threads/processes/nodes. '
'Another one might have committed a record with the same key(s).'.format(repr(e)))
session.rollback()
return False

return True

def _commit(self, session):
@staticmethod
def _commit(session):
# type: (orm.Session) -> None

try:
Expand Down Expand Up @@ -599,38 +601,37 @@ def upgrade(self):
# type: () -> None
"""Upgrade the storage schema."""

config = self._create_alembic_config()
alembic.command.upgrade(config, 'head')
self._version_manager._upgrade()

def get_current_version(self):
# type: () -> str
"""Return the schema version currently used by this storage."""

context = alembic.migration.MigrationContext.configure(self.engine.connect())
version = context.get_current_revision()
assert version is not None

return version
return self._version_manager._get_current_version()

def get_head_version(self):
# type: () -> str
"""Return the latest schema version."""

script = self._create_alembic_script()
return script.get_current_head()

def _get_base_version(self):
# type: () -> str

script = self._create_alembic_script()
return script.get_base()
return self._version_manager._get_head_version()

def get_all_versions(self):
# type: () -> List[str]
"""Return the schema version list."""

script = self._create_alembic_script()
return list(r.revision for r in script.walk_revisions())
return self._version_manager._get_all_versions()


class _VersionManager(object):
def __init__(self, url, engine, scoped_session):
# type: (str, Engine, orm.Session) -> None

self.url = url
self.engine = engine
self.scoped_session = scoped_session

self._init_version_info_model()
self._init_alembic()

def _init_version_info_model(self):
# type: () -> None
Expand Down Expand Up @@ -660,7 +661,7 @@ def _init_alembic(self):
return

if self._is_alembic_supported():
revision = self.get_head_version()
revision = self._get_head_version()
else:
# The storage has been created before alembic is introduced.
revision = self._get_base_version()
Expand All @@ -684,15 +685,15 @@ def _check_table_schema_compatibility(self):
version_info = models.VersionInfoModel.find(session)
assert version_info is not None

current_version = self.get_current_version()
head_version = self.get_head_version()
current_version = self._get_current_version()
head_version = self._get_head_version()
if current_version == head_version:
return

message = 'The runtime optuna version {} is no longer compatible with the table schema ' \
'(set up by optuna {}). '.format(version.__version__,
version_info.library_version)
known_versions = self.get_all_versions()
known_versions = self._get_all_versions()
if current_version in known_versions:
message += 'Please execute `$ optuna storage upgrade --storage $STORAGE_URL`' \
'for upgrading the storage.'
Expand All @@ -702,6 +703,39 @@ def _check_table_schema_compatibility(self):

raise RuntimeError(message)

def _get_current_version(self):
# type: () -> str

context = alembic.migration.MigrationContext.configure(self.engine.connect())
version = context.get_current_revision()
assert version is not None

return version

def _get_head_version(self):
# type: () -> str

script = self._create_alembic_script()
return script.get_current_head()

def _get_base_version(self):
# type: () -> str

script = self._create_alembic_script()
return script.get_base()

def _get_all_versions(self):
# type: () -> List[str]

script = self._create_alembic_script()
return [r.revision for r in script.walk_revisions()]

def _upgrade(self):
# type: () -> None

config = self._create_alembic_config()
alembic.command.upgrade(config, 'head')

def _is_alembic_supported(self):
# type: () -> bool

Expand Down
6 changes: 3 additions & 3 deletions tests/storages_tests/rdb_tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,18 @@ def test_check_table_schema_compatibility():
session = storage.scoped_session()

# The schema version of a newly created storage is always up-to-date.
storage._check_table_schema_compatibility()
storage._version_manager._check_table_schema_compatibility()

# `SCHEMA_VERSION` has not been used for compatibility check since alembic was introduced.
version_info = session.query(VersionInfoModel).one()
version_info.schema_version = SCHEMA_VERSION - 1
session.commit()

storage._check_table_schema_compatibility()
storage._version_manager._check_table_schema_compatibility()

# TODO(ohta): Remove the following comment out when the second revision is introduced.
# with pytest.raises(RuntimeError):
# storage._set_alembic_revision(storage._get_base_version())
# storage._set_alembic_revision(storage._version_manager._get_base_version())
# storage._check_table_schema_compatibility()


Expand Down

0 comments on commit 187ff9c

Please sign in to comment.