Skip to content

Commit

Permalink
Issue #21 updated: migrate dbs when creating a DiffDatabaseMapping.
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuel Marin committed Feb 4, 2019
1 parent 7af067a commit dd50307
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 53 deletions.
12 changes: 6 additions & 6 deletions spinedatabase_api/database_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def single_object_parameter_value(self, id=None, parameter_id=None, object_id=No
if id:
return qry.filter(self.ParameterValue.id == id)
if parameter_id and object_id:
return qry.filter(self.ParameterValue.parameter_id == parameter_id).\
return qry.filter(self.ParameterValue.parameter_definition_id == parameter_id).\
filter(self.ParameterValue.object_id == object_id)
return self.empty_list()

Expand Down Expand Up @@ -501,7 +501,7 @@ def parameter_value_list(self, id_list=None, object_id=None, relationship_id=Non
"""Return parameter values."""
qry = self.session.query(
self.ParameterValue.id,
self.ParameterValue.parameter_id,
self.ParameterValue.parameter_definition_id,
self.ParameterValue.object_id,
self.ParameterValue.relationship_id,
self.ParameterValue.index,
Expand Down Expand Up @@ -539,7 +539,7 @@ def object_parameter_value_list(self, parameter_name=None):
self.ParameterValue.time_pattern,
self.ParameterValue.time_series_id,
self.ParameterValue.stochastic_model_id
).filter(parameter_list.c.id == self.ParameterValue.parameter_id).\
).filter(parameter_list.c.id == self.ParameterValue.parameter_definition_id).\
filter(self.ParameterValue.object_id == object_list.c.id).\
filter(parameter_list.c.object_class_id == object_class_list.c.id)
if parameter_name:
Expand Down Expand Up @@ -568,7 +568,7 @@ def relationship_parameter_value_list(self, parameter_name=None):
self.ParameterValue.time_pattern,
self.ParameterValue.time_series_id,
self.ParameterValue.stochastic_model_id
).filter(parameter_list.c.id == self.ParameterValue.parameter_id).\
).filter(parameter_list.c.id == self.ParameterValue.parameter_definition_id).\
filter(self.ParameterValue.relationship_id == wide_relationship_list.c.id).\
filter(parameter_list.c.relationship_class_id == wide_relationship_class_list.c.id)
if parameter_name:
Expand Down Expand Up @@ -602,7 +602,7 @@ def unvalued_object_parameter_list(self, object_id):
object_ = self.single_object(id=object_id).one_or_none()
if not object_:
return self.empty_list()
valued_parameter_ids = self.session.query(self.ParameterValue.parameter_id).\
valued_parameter_ids = self.session.query(self.ParameterValue.parameter_definition_id).\
filter_by(object_id=object_id)
return self.parameter_list(object_class_id=object_.class_id).\
filter(~self.ParameterDefinition.id.in_(valued_parameter_ids))
Expand All @@ -623,7 +623,7 @@ def unvalued_relationship_parameter_list(self, relationship_id):
relationship = self.single_wide_relationship(id=relationship_id).one_or_none()
if not relationship:
return self.empty_list()
valued_parameter_ids = self.session.query(self.ParameterValue.parameter_id).\
valued_parameter_ids = self.session.query(self.ParameterValue.parameter_definition_id).\
filter_by(relationship_id=relationship_id)
return self.parameter_list().filter_by(relationship_class_id=relationship.class_id).\
filter(~self.ParameterDefinition.id.in_(valued_parameter_ids))
Expand Down
53 changes: 28 additions & 25 deletions spinedatabase_api/diff_database_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,10 @@ def single_object_parameter_value(self, id=None, parameter_id=None, object_id=No
if parameter_id and object_id:
return qry.filter(or_(
and_(
self.ParameterValue.parameter_id == parameter_id,
self.ParameterValue.parameter_definition_id == parameter_id,
self.ParameterValue.object_id == object_id),
and_(
self.DiffParameterValue.parameter_id == parameter_id,
self.DiffParameterValue.parameter_definition_id == parameter_id,
self.DiffParameterValue.object_id == object_id)))
return self.empty_list()

Expand Down Expand Up @@ -760,7 +760,7 @@ def parameter_value_list(self, id_list=None, object_id=None, relationship_id=Non
).filter(~self.ParameterValue.id.in_(self.touched_item_id["parameter_value"]))
diff_qry = self.session.query(
self.DiffParameterValue.id,
self.DiffParameterValue.parameter_id,
self.DiffParameterValue.parameter_definition_id,
self.DiffParameterValue.object_id,
self.DiffParameterValue.relationship_id,
self.DiffParameterValue.index,
Expand Down Expand Up @@ -798,7 +798,7 @@ def object_parameter_value_list(self, object_class_id=None, parameter_name=None)
self.ParameterValue.time_pattern,
self.ParameterValue.time_series_id,
self.ParameterValue.stochastic_model_id
).filter(parameter_list.c.id == self.ParameterValue.parameter_id).\
).filter(parameter_list.c.id == self.ParameterValue.parameter_definition_id).\
filter(self.ParameterValue.object_id == object_list.c.id).\
filter(parameter_list.c.object_class_id == object_class_list.c.id).\
filter(~self.ParameterValue.id.in_(self.touched_item_id["parameter_value"]))
Expand All @@ -817,7 +817,7 @@ def object_parameter_value_list(self, object_class_id=None, parameter_name=None)
self.DiffParameterValue.time_pattern,
self.DiffParameterValue.time_series_id,
self.DiffParameterValue.stochastic_model_id
).filter(parameter_list.c.id == self.DiffParameterValue.parameter_id).\
).filter(parameter_list.c.id == self.DiffParameterValue.parameter_definition_id).\
filter(self.DiffParameterValue.object_id == object_list.c.id).\
filter(parameter_list.c.object_class_id == object_class_list.c.id)
if object_class_id:
Expand Down Expand Up @@ -851,7 +851,7 @@ def relationship_parameter_value_list(self, relationship_class_id=None, paramete
self.ParameterValue.time_pattern,
self.ParameterValue.time_series_id,
self.ParameterValue.stochastic_model_id
).filter(parameter_list.c.id == self.ParameterValue.parameter_id).\
).filter(parameter_list.c.id == self.ParameterValue.parameter_definition_id).\
filter(self.ParameterValue.relationship_id == wide_relationship_list.c.id).\
filter(parameter_list.c.relationship_class_id == wide_relationship_class_list.c.id).\
filter(~self.ParameterValue.id.in_(self.touched_item_id["parameter_value"]))
Expand All @@ -873,7 +873,7 @@ def relationship_parameter_value_list(self, relationship_class_id=None, paramete
self.DiffParameterValue.time_pattern,
self.DiffParameterValue.time_series_id,
self.DiffParameterValue.stochastic_model_id
).filter(parameter_list.c.id == self.DiffParameterValue.parameter_id).\
).filter(parameter_list.c.id == self.DiffParameterValue.parameter_definition_id).\
filter(self.DiffParameterValue.relationship_id == wide_relationship_list.c.id).\
filter(parameter_list.c.relationship_class_id == wide_relationship_class_list.c.id)
if relationship_class_id:
Expand Down Expand Up @@ -901,7 +901,7 @@ def unvalued_object_list(self, parameter_id):
if not parameter:
return self.empty_list()
valued_object_ids = self.session.query(self.ParameterValue.object_id).\
filter_by(parameter_id=parameter_id)
filter_by(parameter_definition_id=parameter_id)
return self.object_list().filter_by(class_id=parameter.object_class_id).\
filter(~self.Object.id.in_(valued_object_ids))

Expand All @@ -920,7 +920,7 @@ def unvalued_relationship_list(self, parameter_id):
if not parameter:
return self.empty_list()
valued_relationship_ids = self.session.query(self.ParameterValue.relationship_id).\
filter_by(parameter_id=parameter_id)
filter_by(parameter_definition_id=parameter_id)
return self.wide_relationship_list().filter_by(class_id=parameter.relationship_class_id).\
filter(~self.Relationship.id.in_(valued_relationship_ids))

Expand Down Expand Up @@ -1389,10 +1389,10 @@ def check_parameter_values_for_insert(self, *kwargs_list, raise_intgr_error=True
checked_kwargs_list = list()
# Per's suggestions
object_parameter_values = {
(x.object_id, x.parameter_id) for x in self.parameter_value_list() if x.object_id
(x.object_id, x.parameter_definition_id) for x in self.parameter_value_list() if x.object_id
}
relationship_parameter_values = {
(x.relationship_id, x.parameter_id) for x in self.parameter_value_list() if x.relationship_id
(x.relationship_id, x.parameter_definition_id) for x in self.parameter_value_list() if x.relationship_id
}
parameter_definition_dict = {
x.id: {
Expand Down Expand Up @@ -1436,16 +1436,16 @@ def check_parameter_values_for_update(self, *kwargs_list, raise_intgr_error=True
checked_kwargs_list = list()
parameter_value_dict = {
x.id: {
"parameter_id": x.parameter_id,
"parameter_definition_id": x.parameter_definition_id,
"object_id": x.object_id,
"relationship_id": x.relationship_id
} for x in self.parameter_value_list()}
# Per's suggestions
object_parameter_values = {
(x.object_id, x.parameter_id) for x in self.parameter_value_list() if x.object_id
(x.object_id, x.parameter_definition_id) for x in self.parameter_value_list() if x.object_id
}
relationship_parameter_values = {
(x.relationship_id, x.parameter_id) for x in self.parameter_value_list() if x.relationship_id
(x.relationship_id, x.parameter_definition_id) for x in self.parameter_value_list() if x.relationship_id
}
parameter_definition_dict = {
x.id: {
Expand Down Expand Up @@ -1479,9 +1479,9 @@ def check_parameter_values_for_update(self, *kwargs_list, raise_intgr_error=True
object_id = updated_kwargs.get("object_id", None)
relationship_id = updated_kwargs.get("relationship_id", None)
if object_id:
object_parameter_values.remove((object_id, updated_kwargs['parameter_id']))
object_parameter_values.remove((object_id, updated_kwargs['parameter_definition_id']))
elif relationship_id:
relationship_parameter_values.remove((relationship_id, updated_kwargs['parameter_id']))
relationship_parameter_values.remove((relationship_id, updated_kwargs['parameter_definition_id']))
except KeyError:
msg = "Parameter value not found."
if raise_intgr_error:
Expand All @@ -1505,9 +1505,9 @@ def check_parameter_values_for_update(self, *kwargs_list, raise_intgr_error=True
object_id = updated_kwargs.get("object_id", None)
relationship_id = updated_kwargs.get("relationship_id", None)
if object_id:
object_parameter_values.add((object_id, updated_kwargs['parameter_id']))
object_parameter_values.add((object_id, updated_kwargs['parameter_definition_id']))
elif relationship_id:
relationship_parameter_values.add((relationship_id, updated_kwargs['parameter_id']))
relationship_parameter_values.add((relationship_id, updated_kwargs['parameter_definition_id']))
except SpineIntegrityError as e:
if raise_intgr_error:
raise e
Expand All @@ -1519,7 +1519,7 @@ def check_parameter_value(
parameter_definition_dict, object_dict, relationship_dict):
"""Raise a SpineIntegrityError if the parameter value given by `kwargs` violates any integrity constraints."""
try:
parameter_id = kwargs["parameter_id"]
parameter_definition_id = kwargs["parameter_definition_id"]
except KeyError:
raise SpineIntegrityError("Missing parameter identifier.")
try:
Expand Down Expand Up @@ -1549,7 +1549,7 @@ def check_parameter_value(
parameter_name = parameter_definition['name']
raise SpineIntegrityError("Incorrect object '{}' for "
"parameter '{}'.".format(object_name, parameter_name))
if (object_id, parameter_id) in object_parameter_values:
if (object_id, parameter_definition_id) in object_parameter_values:
object_name = object_dict[object_id]['name']
parameter_name = parameter_definition['name']
raise SpineIntegrityError("The value of parameter '{}' for object '{}' is "
Expand All @@ -1564,7 +1564,7 @@ def check_parameter_value(
parameter_name = parameter_definition['name']
raise SpineIntegrityError("Incorrect relationship '{}' for "
"parameter '{}'.".format(relationship_name, parameter_name))
if (relationship_id, parameter_id) in relationship_parameter_values:
if (relationship_id, parameter_definition_id) in relationship_parameter_values:
relationship_name = relationship_dict[relationship_id]['name']
parameter_name = parameter_definition['name']
raise SpineIntegrityError("The value of parameter '{}' for relationship '{}' is "
Expand Down Expand Up @@ -1979,8 +1979,8 @@ def _add_parameters(self, *kwargs_list):
parameters (list): added instances
"""
next_id = self.next_id_with_lock()
if next_id.parameter_id:
id = next_id.parameter_id
if next_id.parameter_definition_id:
id = next_id.parameter_definition_id
else:
max_id = self.session.query(func.max(self.ParameterDefinition.id)).scalar()
id = max_id + 1 if max_id else 1
Expand Down Expand Up @@ -2009,6 +2009,9 @@ def add_parameter_values(self, *kwargs_list, raise_intgr_error=True):
parameter_values (list): added instances
intgr_error_log (list): list of integrity error messages
"""
# FIXME: this should be removed once the 'parameter_definition_id' comes in the kwargs
for kwargs in kwargs_list:
kwargs["parameter_definition_id"] = kwargs["parameter_id"]
checked_kwargs_list, intgr_error_log = self.check_parameter_values_for_insert(
*kwargs_list, raise_intgr_error=raise_intgr_error)
new_item_list = self._add_parameter_values(*checked_kwargs_list)
Expand Down Expand Up @@ -2807,9 +2810,9 @@ def _remove_cascade_parameter_definitions(self, ids, diff_ids, removed_item_id,
removed_diff_item_id.setdefault("parameter_definition", set()).update(diff_ids)
# parameter_value
item_list = self.session.query(self.ParameterValue.id).\
filter(self.ParameterValue.parameter_id.in_(ids))
filter(self.ParameterValue.parameter_definition_id.in_(ids))
diff_item_list = self.session.query(self.DiffParameterValue.id).\
filter(self.DiffParameterValue.parameter_id.in_(ids + diff_ids))
filter(self.DiffParameterValue.parameter_definition_id.in_(ids + diff_ids))
self._remove_cascade_parameter_values(
[x.id for x in item_list],
[x.id for x in diff_item_list],
Expand Down
41 changes: 24 additions & 17 deletions test/test_DiffDatabaseMapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@
:author: P. Vennström (VTT)
:date: 29.11.2018
"""
#import sys
#sys.path.append('/spinedatabase_api')

import os
import unittest
import logging
import sys
from spinedatabase_api.diff_database_mapping import DiffDatabaseMapping, SpineIntegrityError
from spinedatabase_api.helpers import create_new_spine_database
from sqlalchemy.util import KeyedTuple
import unittest
from unittest import mock
import logging
import sys
from sqlalchemy.orm import Session

class TestDiffDatabaseMapping(unittest.TestCase):
Expand All @@ -35,14 +34,21 @@ def setUpClass(cls):
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG,
format='%(asctime)s %(levelname)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
engine = create_new_spine_database('sqlite://')
cls.db_map = DiffDatabaseMapping("", username='UnitTest', create_all=False)
cls.db_map.engine = engine
cls.db_map.engine.connect()
cls.db_map.session = Session(cls.db_map.engine, autoflush=False)
cls.db_map.create_mapping()
cls.db_map.create_diff_tables_and_mapping()
cls.db_map.init_next_id()
try:
os.remove("temp.sqlite")
except OSError:
pass
db_url = 'sqlite:///temp.sqlite'
engine = create_new_spine_database(db_url)
cls.db_map = DiffDatabaseMapping(db_url, username='UnitTest')

@classmethod
def tearDownClass(cls):
"""Overridden method. Runs once after all tests in this class."""
try:
os.remove("temp.sqlite")
except OSError:
pass

def setUp(self):
"""Overridden method. Runs before each test. Makes instances of TreeViewForm and GraphViewForm classes.
Expand Down Expand Up @@ -432,11 +438,11 @@ def test_add_parameter_values(self):
self.fail("add_parameter_values() raised SpineIntegrityError unexpectedly")
parameter_values = self.db_map.session.query(self.db_map.DiffParameterValue).all()
self.assertEqual(len(parameter_values), 2)
self.assertEqual(parameter_values[0].parameter_id, 1)
self.assertEqual(parameter_values[0].parameter_definition_id, 1)
self.assertEqual(parameter_values[0].object_id, 1)
self.assertIsNone(parameter_values[0].relationship_id)
self.assertEqual(parameter_values[0].value, 'orange')
self.assertEqual(parameter_values[1].parameter_id, 2)
self.assertEqual(parameter_values[1].parameter_definition_id, 2)
self.assertIsNone(parameter_values[1].object_id)
self.assertEqual(parameter_values[1].relationship_id, 1)
self.assertEqual(parameter_values[1].value, '125')
Expand Down Expand Up @@ -535,7 +541,7 @@ def test_add_same_parameter_value_twice(self):
raise_intgr_error=False)
parameter_values = self.db_map.session.query(self.db_map.DiffParameterValue).all()
self.assertEqual(len(parameter_values), 1)
self.assertEqual(parameter_values[0].parameter_id, 1)
self.assertEqual(parameter_values[0].parameter_definition_id, 1)
self.assertEqual(parameter_values[0].object_id, 1)
self.assertIsNone(parameter_values[0].relationship_id)
self.assertEqual(parameter_values[0].value, 'orange')
Expand All @@ -559,7 +565,8 @@ def test_add_existing_parameter_value(self):
]
mock_parameter_value_list.return_value = [
KeyedTuple(
[1, 1, 1, None, 'orange'], labels=["id", "parameter_id", "object_id", "relationship_id", "value"])
[1, 1, 1, None, 'orange'],
labels=["id", "parameter_definition_id", "object_id", "relationship_id", "value"])
]
with self.assertRaises(SpineIntegrityError):
self.db_map.add_parameter_values({'parameter_id': 1, 'object_id': 1, 'value': 'blue'})
Expand Down
Loading

0 comments on commit dd50307

Please sign in to comment.