Skip to content

Commit

Permalink
Merge cdef15d into aac50aa
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasberbuer committed Mar 3, 2020
2 parents aac50aa + cdef15d commit 3046be7
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 87 deletions.
115 changes: 42 additions & 73 deletions src/vallenae/io/_database.py
Expand Up @@ -5,13 +5,13 @@
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union

from ._sql import insert_from_dict, read_sql_generator, update_from_dict
from ._sql import ConnectionWrapper, insert_from_dict, read_sql_generator, update_from_dict


def require_write_access(func):
@wraps(func)
def wrapper(self: "Database", *args, **kwargs):
if self.readonly:
if self._readonly: # pylint: disable=protected-access
raise ValueError(
"Can not write to database in read-only mode. Open database with mode='rw'"
)
Expand All @@ -32,43 +32,20 @@ def __init__(
table_prefix: str,
required_file_ext: Optional[str] = None,
):
self._connected: bool = False
self._filename: str = str(filename) # forced str conversion (e.g. for pathlib.Path)

# check file extension
if required_file_ext is not None:
file_ext = Path(self._filename).suffix
file_ext = Path(filename).suffix
if file_ext.lower() != required_file_ext.lower():
raise ValueError(
f"File {filename} does not have the required extension {required_file_ext}"
)

# check mode
valid_modes = ("ro", "rw", "rwc")
if mode not in valid_modes:
raise ValueError(f"Invalid access mode '{mode}', use: {valid_modes}")
if mode == "rwc":
if not Path(filename).exists():
self.create(filename) # call abstract method (implemented by child class)
self._readonly: bool = (mode == "ro")

# open sqlite connection
self._connection = sqlite3.connect(
f"file:{self._filename}?mode={mode}",
uri=True,
check_same_thread=(not self._readonly), # allow multithreading in read-only mode
)
self._connected = True

# set pragmas for write-mode
if not self._readonly:
self._connection.executescript(
"""
PRAGMA journal_mode = WAL;
PRAGMA locking_mode = EXCLUSIVE;
PRAGMA synchronous = OFF;
"""
)

self._readonly = (mode == "ro")
self._connection_wrapper = ConnectionWrapper(filename, mode)

self._table_prefix: str = table_prefix
self._table_main: str = f"{table_prefix}_data"
Expand Down Expand Up @@ -97,17 +74,12 @@ def create(filename: str):
@property
def filename(self) -> str:
"""Filename of database."""
return self._filename

@property
def readonly(self) -> bool:
"""Read-only mode for database connection."""
return self._readonly
return self._connection_wrapper.filename

@property
def connected(self) -> bool:
"""Check if connected to SQLite database."""
return self._connected
return self._connection_wrapper.connected

def connection(self) -> sqlite3.Connection:
"""
Expand All @@ -116,9 +88,7 @@ def connection(self) -> sqlite3.Connection:
Raises:
RuntimeError: If connection is closed
"""
if not self._connected:
raise RuntimeError("Not connected to SQLite database")
return self._connection
return self._connection_wrapper.connection()

def rows(self) -> int:
"""Number of rows in data table."""
Expand Down Expand Up @@ -146,11 +116,11 @@ def _add_columns(
"""Add columns to specified table."""
if dtype is None:
dtype = ""
con = self.connection()
columns_exist = self._columns(table)
for column in columns: # keep order of columns
if column not in columns_exist:
con.execute(f"ALTER TABLE {table} ADD COLUMN {column} {dtype}")
with self.connection() as con: # commit/rollback transaction
columns_exist = self._columns(table)
for column in columns: # keep order of columns
if column not in columns_exist:
con.execute(f"ALTER TABLE {table} ADD COLUMN {column} {dtype}")

def tables(self) -> Set[str]:
"""Get table names."""
Expand Down Expand Up @@ -187,17 +157,17 @@ def write_fieldinfo(self, field: str, info: Dict[str, Any]):
if field not in self.columns():
raise ValueError(f"Field {field} must be a column of data table")

con = self.connection()
row_dict = info
row_dict["field"] = field
try:
if field in self.fieldinfo().keys():
update_from_dict(con, self._table_fieldinfo, row_dict, "field")
else:
insert_from_dict(con, self._table_fieldinfo, row_dict)
except sqlite3.OperationalError: # missing column(s)
self._add_columns(self._table_fieldinfo, list(row_dict.keys()))
self.write_fieldinfo(field, info) # try again
with self.connection() as con: # commit/rollback transaction
try:
if field in self.fieldinfo().keys():
update_from_dict(con, self._table_fieldinfo, row_dict, "field")
else:
insert_from_dict(con, self._table_fieldinfo, row_dict)
except sqlite3.OperationalError: # missing column(s)
self._add_columns(self._table_fieldinfo, list(row_dict.keys()))
self.write_fieldinfo(field, info) # try again

def globalinfo(self) -> Dict[str, Any]:
"""Read globalinfo table."""
Expand All @@ -216,22 +186,23 @@ def try_convert_string(value: str) -> Any:
def _update_globalinfo(self):
"""Update globalinfo after writes."""
keys = self.globalinfo().keys()
if "ValidSets" in keys:
self.connection().execute(
"""
UPDATE {prefix}_globalinfo
SET Value = (SELECT MAX(rowid) FROM {prefix}_data)
WHERE Key == "ValidSets"
""".format(prefix=self._table_prefix)
)
if "TRAI" in keys:
self.connection().execute(
"""
UPDATE {prefix}_globalinfo
SET Value = (SELECT MAX(TRAI) FROM {prefix}_data)
WHERE Key == "TRAI";
""".format(prefix=self._table_prefix)
)
with self.connection() as con: # commit/rollback transaction
if "ValidSets" in keys:
con.execute(
"""
UPDATE {prefix}_globalinfo
SET Value = (SELECT MAX(rowid) FROM {prefix}_data)
WHERE Key == "ValidSets"
""".format(prefix=self._table_prefix)
)
if "TRAI" in keys:
con.execute(
"""
UPDATE {prefix}_globalinfo
SET Value = (SELECT MAX(TRAI) FROM {prefix}_data)
WHERE Key == "TRAI";
""".format(prefix=self._table_prefix)
)

def _parameter_table(self) -> Dict[int, Dict[str, Any]]:
"""Read *_params table to dict."""
Expand All @@ -258,12 +229,10 @@ def _parameter(self, param_id: int) -> Dict[str, Any]:

def close(self):
"""Close database connection."""
if self._connected:
if self.connected:
if not self._readonly:
self._update_globalinfo()
self._connection.commit() # commit remaining changes
self._connection.close()
self._connected = False
self._connection_wrapper.close()

def __del__(self):
self.close()
Expand Down
111 changes: 105 additions & 6 deletions src/vallenae/io/_sql.py
Expand Up @@ -7,35 +7,134 @@

from .types import SizedIterable


logger = logging.getLogger(__name__)


class ConnectionWrapper:
    """
    SQLite3 connection wrapper (picklable).

    The wrapper keeps filename, access mode and threading flag next to the connection.
    On pickling, the (unpicklable) connection object is dropped and pending changes are
    committed; on unpickling, the connection is reopened if it was open before.
    """

    def __init__(self, filename: str, mode: str = "ro", multithreading: bool = False):
        """
        Validate the access mode and open the SQLite connection.

        Args:
            filename: Path to the SQLite database file (converted with `str`)
            mode: Access mode, one of "ro", "rw", "rwc"; passed as `mode` URI parameter
            multithreading: If true, the connection is opened with
                `check_same_thread=False`. Always enabled for mode "ro"

        Raises:
            ValueError: If mode is not one of the valid modes
        """
        # set first, so close() / __del__ work even if validation or connecting fails
        self._connected = False

        # check mode
        valid_modes = ("ro", "rw", "rwc")
        if mode not in valid_modes:
            raise ValueError(f"Invalid access mode '{mode}', use: {valid_modes}")

        self._filename = str(filename)
        self._mode = mode
        self._multithreading = multithreading
        # enable multithreading for read-only connections
        if mode == "ro":
            self._multithreading = True

        self._connect()

    def _connect(self):
        """Open SQLite connection and set pragmas for write modes."""
        connection = sqlite3.connect(
            f"file:{self._filename}?mode={self._mode}",
            uri=True,
            check_same_thread=(not self._multithreading),
        )

        # set pragmas for write-mode
        if self._mode != "ro":
            try:
                connection.executescript(
                    """
                    PRAGMA journal_mode = WAL;
                    PRAGMA synchronous = OFF;
                    """
                )
            except Exception:
                # do not leak the half-configured connection
                connection.close()
                raise

        # publish the connection only after it is fully set up
        self._connection = connection
        self._connected = True

    @property
    def filename(self) -> str:
        """Filename of database."""
        return self._filename

    @property
    def mode(self) -> str:
        """Access mode of the connection ("ro", "rw" or "rwc")."""
        return self._mode

    @property
    def connected(self) -> bool:
        """Check if connected to SQLite database."""
        return self._connected

    def connection(self) -> sqlite3.Connection:
        """
        Get SQLite connection object.

        Raises:
            RuntimeError: If connection is closed
        """
        if not self._connected:
            raise RuntimeError("Not connected to SQLite database")
        return self._connection

    def get_readonly_connection(self) -> "ConnectionWrapper":
        """
        Return read-only ConnectionWrapper.

        Returns this instance if its mode is "ro", otherwise a new wrapper with a new
        read-only connection to the same file.
        """
        if self._mode == "ro":
            return self
        return ConnectionWrapper(self._filename, mode="ro")

    def close(self):
        """Commit remaining changes and close the connection (no-op if not connected)."""
        # getattr: the flag is missing if the instance was created without __init__
        if getattr(self, "_connected", False):
            self._connection.commit()  # commit remaining changes
            self._connection.close()
            self._connected = False

    def __del__(self):
        self.close()

    def __getstate__(self):
        """Return instance state without the connection object."""
        # commit changes, database will be reopened with __setstate__
        if self._connected:
            self._connection.commit()
        state = self.__dict__.copy()
        # remove the unpicklable sqlite3.connection;
        # the key is absent if this instance was unpickled in closed state
        state.pop("_connection", None)
        return state

    def __setstate__(self, state):
        """Restore instance state and reopen the connection if it was open before."""
        self.__dict__.update(state)
        was_connected = self._connected
        # there is no live connection until _connect() succeeded
        self._connected = False
        if was_connected:
            self._connect()


T = TypeVar("T")
class QueryIterable(SizedIterable[T]):
"""Sized iterable to query results from SQLite as dictionaries."""
"""
Sized iterable to query results from SQLite as dictionaries.
SQLite connection is stored in picklable ConnectionWrapper to be used with multiprocessing.
"""
def __init__(
self,
connection: sqlite3.Connection,
connection_wrapper: ConnectionWrapper,
query: str,
dict_to_type: Callable[[Dict[str, Any]], T],
):
self._connection = connection
self._connection_wrapper = connection_wrapper
self._query = query
self._dict_to_type = dict_to_type
self._count_result: Optional[int] = None # cache result of __len__

def __len__(self) -> int:
if self._count_result is None:
self._count_result = count_sql_results(self._connection, self._query)
self._count_result = count_sql_results(
self._connection_wrapper.connection(),
self._query
)
return self._count_result

def __iter__(self) -> Iterator[T]:
if self.__len__() == 0:
logger.debug("Empty SQLite query")

for row in read_sql_generator(self._connection, self._query):
for row in read_sql_generator(self._connection_wrapper.connection(), self._query):
yield self._dict_to_type(row)


Expand Down
24 changes: 20 additions & 4 deletions src/vallenae/io/pridb.py
Expand Up @@ -218,7 +218,11 @@ def iread_hits(
greater_equal={"vae.Time": time_start},
less={"vae.Time": time_stop},
)
return QueryIterable(self.connection(), query, HitRecord.from_sql)
return QueryIterable(
self._connection_wrapper.get_readonly_connection(),
query,
HitRecord.from_sql,
)

def iread_markers(
self,
Expand Down Expand Up @@ -246,7 +250,11 @@ def iread_markers(
greater_equal={"vae.Time": time_start},
less={"vae.Time": time_stop},
)
return QueryIterable(self.connection(), query, MarkerRecord.from_sql)
return QueryIterable(
self._connection_wrapper.get_readonly_connection(),
query,
MarkerRecord.from_sql,
)

def iread_parametric(
self,
Expand Down Expand Up @@ -276,7 +284,11 @@ def iread_parametric(
greater_equal={"vae.Time": time_start},
less={"vae.Time": time_stop},
)
return QueryIterable(self.connection(), query, ParametricRecord.from_sql)
return QueryIterable(
self._connection_wrapper.get_readonly_connection(),
query,
ParametricRecord.from_sql,
)

def iread_status(
self,
Expand Down Expand Up @@ -309,7 +321,11 @@ def iread_status(
greater_equal={"vae.Time": time_start},
less={"vae.Time": time_stop},
)
return QueryIterable(self.connection(), query, StatusRecord.from_sql)
return QueryIterable(
self._connection_wrapper.get_readonly_connection(),
query,
StatusRecord.from_sql,
)

@require_write_access
@check_monotonic_time
Expand Down
6 changes: 5 additions & 1 deletion src/vallenae/io/tradb.py
Expand Up @@ -145,7 +145,11 @@ def iread(
greater_equal={"vtr.SetID": setid_time_start},
less={"vtr.SetID": setid_time_stop},
)
return QueryIterable(con, query, TraRecord.from_sql)
return QueryIterable(
self._connection_wrapper.get_readonly_connection(),
query,
TraRecord.from_sql,
)

def read_wave(
self, trai: int, time_axis: bool = True,
Expand Down

0 comments on commit 3046be7

Please sign in to comment.