Skip to content

Commit

Permalink
Fix: cursor handling for multiple cursors in a single transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
ReneBakkerCineca committed Feb 20, 2023
1 parent 82457f2 commit ed9e3b5
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 50 deletions.
39 changes: 20 additions & 19 deletions lwetl/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
+ LOGIN_ALIAS
"""

import logging
import os
import sys
import yaml
Expand All @@ -40,6 +41,9 @@
from .utils import verified_boolean
from .security import decrypt

# define a logger
LOGGER = logging.getLogger(os.path.basename(__file__).split('.')[0])

MODULE_DIR = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.path.pardir))
HOME_DIR = os.path.expanduser('~')
WORK_DIR = os.getcwd()
Expand Down Expand Up @@ -194,8 +198,7 @@ def print_info():
except PermissionError:
pass
except yaml.YAMLError as pe:
print('ERROR: cannot parse the configuration file {}'.format(fn), file=sys.stderr)
print(pe, file=sys.stderr)
LOGGER.error('Cannot parse the configuration file {}: {}'.format(fn, pe))
sys.exit(1)

if (len(configuration) == 0) or (count_cfg_files <= 1):
Expand All @@ -211,7 +214,7 @@ def print_info():
except (PermissionError, FileNotFoundError, FileExistsError):
home_cfg_dir = None
if home_cfg_dir is None:
print('FATAL ERROR: no configuration found. Looked for:\n- ' + '\n- '.join(CFG_FILES))
LOGGER.critical('FATAL: no configuration found. Looked for:\n- ' + '\n- '.join(CFG_FILES))
sys.exit(1)
else:
from shutil import copyfile
Expand All @@ -220,7 +223,7 @@ def print_info():
for trg_file in [os.path.join(home_cfg_dir, f) for f in ['config-example.yml', 'config.yml']]:
copyfile(src_file, trg_file)
os.chmod(trg_file, S_IREAD | S_IWRITE)
print('INFO: Sample configuration files installed in: ' + home_cfg_dir)
LOGGER.info('Sample configuration files installed in: ' + home_cfg_dir)

# add environment variables
for var_name, value in configuration.get('env', {}).items():
Expand All @@ -236,11 +239,11 @@ def print_info():
JDBC_DRIVERS = dict()
for jdbc_type, cfg in configuration.get('drivers', {}).items():
if 'jar' not in cfg:
print('ERROR in definition of driver type {}: jar file not specified.'.format(jdbc_type))
LOGGER.error('Error in definition of driver type {}: jar file not specified.'.format(jdbc_type))
elif 'class' not in cfg:
print('ERROR in definition of driver type {}: driver class not specified.'.format(jdbc_type))
LOGGER.error('Error in definition of driver type {}: driver class not specified.'.format(jdbc_type))
elif 'url' not in cfg:
print('ERROR in definition of driver type {}: url not specified.'.format(jdbc_type))
LOGGER.error('Error in definition of driver type {}: url not specified.'.format(jdbc_type))
elif os.path.isfile(cfg['jar']):
JDBC_DRIVERS[jdbc_type] = cfg
else:
Expand All @@ -266,15 +269,14 @@ def print_info():
dst_file = os.path.join(lib_dir, jar_file)
try:
urlretrieve(cfg['jar'], dst_file)
print('INFO: {} downloaded to: {}'.format(jar_file, lib_dir))
LOGGER.info('{} downloaded to: {}'.format(jar_file, lib_dir))
except (HTTPError, URLError) as http_error:
print('ERROR - failed to retrieve: ' + cfg['jar'])
print(http_error)
LOGGER.error('Failed to retrieve {}: {}'.format(cfg['jar'], http_error))
if os.path.isfile(dst_file):
JDBC_DRIVERS[jdbc_type] = merge({'jar': dst_file}, cfg)
break
if jdbc_type not in JDBC_DRIVERS:
print('WARNING - no driver found for: ' + jdbc_type)
LOGGER.warning('No driver found for: ' + jdbc_type)

JAR_FILES = []
for cfg in JDBC_DRIVERS.values():
Expand All @@ -285,11 +287,11 @@ def print_info():
JDBC_SERVERS = dict()
for service, cfg in configuration.get('servers', {}).items():
if 'type' not in cfg:
print('ERROR in definition of service {}: database type not specified.'.format(service))
LOGGER.error('Error in definition of service {}: database type not specified.'.format(service))
elif cfg['type'] not in JDBC_DRIVERS:
print('ERROR in definition of service {}: unknown driver type {}.'.format(service, cfg['type']))
LOGGER.error('Error in definition of service {}: unknown driver type {}.'.format(service, cfg['type']))
elif 'url' not in cfg:
print('ERROR in definition of service {}: url not specified.'.format(service))
LOGGER.error('Error in definition of service {}: url not specified.'.format(service))
else:
JDBC_SERVERS[service.lower()] = cfg

Expand All @@ -303,11 +305,10 @@ def print_info():
import regex
except ImportError:
regex = None
print('''
WARNING: tnsnames.ora can only be parsed if the regex module is installed.
- use 'pip install regex' to install
This message may be removed by setting the environment variable IGNORE_TNS to true
(either in the system, or in the env section of the configuration file).''')
LOGGER.warning('''tnsnames.ora can only be parsed if the regex module is installed.
Use 'pip install regex' to install.
This message may be removed by setting the environment variable IGNORE_TNS to true
(either in the system, or in the env section of the configuration file).''')

if regex:
with open(tns, 'r') as fh:
Expand Down
88 changes: 61 additions & 27 deletions lwetl/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@
Main jdbc connection
"""

import sys
import logging
import os

from collections import OrderedDict
from decimal import Decimal, InvalidOperation

from jpype import JPackage
from jaydebeapi import Cursor, Error, DatabaseError, connect

from typing import Union
from typing import List, Union

from .config_parser import JDBC_DRIVERS, JAR_FILES, parse_login, parse_dummy_login
from .exceptions import DriverNotFoundException, SQLExecuteException, CommitException
from .runtime_statistics import RuntimeStatistics
from .utils import *

# define a logger
LOGGER = logging.getLogger(os.path.basename(__file__).split('.')[0])

# marker (attribute) to trace chained connections
PARENT_CONNECTION = '_lwetl_jdbc'

Expand Down Expand Up @@ -60,8 +64,8 @@ def func_wrapper(*args, **kwargs):
cursor = argl[1]

if isinstance(cursor, Cursor):
if cursor not in self.cursors:
print('WARNING: the specified cursor not found in this instance. Possible commit?')
if not self.has_cursor(cursor):
LOGGER.debug('The specified cursor not found in this instance. Possible commit?')
cursor = None
elif cursor is not None:
raise ValueError('Illegal cursor specifier.')
Expand Down Expand Up @@ -222,7 +226,7 @@ def default_transformer(v):
return v

def oracle_lob_to_bytes(self, lob):
# print(type(lob).__name__)
# LOGGER.debug(type(lob).__name__)
return self.byte_array_to_bytes(lob.getBytes(1, int(lob.length())))

# noinspection PyMethodMayBeStatic
Expand Down Expand Up @@ -314,7 +318,7 @@ def __call__(self, row):
# noinspection PyCallingNonCallable
values.append(func(value))
except Exception as e:
print('ERROR - cannot parse {}: {}'.format(value, str(e)))
LOGGER.error('Cannot parse {}: {}'.format(value, str(e)))
parse_exception = e
if parse_exception is not None:
raise parse_exception
Expand Down Expand Up @@ -358,6 +362,13 @@ def __call__(self, row):
return dd


class CursorStorage:

def __init__(self, cursor: Cursor, keep: bool = False):
self.cursor = cursor
self.keep = keep


class DummyJdbc:
"""
Dummy JDBC connection.
Expand Down Expand Up @@ -392,6 +403,9 @@ class Jdbc:
"""

CURSOR_MODE_USE_CURRENT = 0x01
CURSOR_MODE_DELETE_ON_COMMIT = 0x02

def __init__(self, login: str, auto_commit=False, upper_case=True):
"""
Init the jdbc connection.
Expand Down Expand Up @@ -423,7 +437,7 @@ def __init__(self, login: str, auto_commit=False, upper_case=True):
connection_error = ConnectionError(error_msg)
if ':' in error_msg:
error_msg = error_msg.split(':', 1)[-1]
print('ERROR - jdbc connection failed: ' + error_msg, file=sys.stderr)
LOGGER.error('Jdbc connection failed: ' + error_msg)

if connection_error is not None:
raise connection_error
Expand All @@ -433,25 +447,33 @@ def __init__(self, login: str, auto_commit=False, upper_case=True):

# cursor handling
self.counter = 0
self.cursors = []
self.cursors = [] # type: List[CursorStorage]
self.current_cursor = None

# noinspection PyBroadException
def __del__(self):
if self.connection:
for cursor in self.cursors:
for cs in self.cursors:
try:
cursor.close()
cs.cursor.close()
except Exception:
pass
try:
self.connection.close()
except Exception:
pass

def has_cursor(self, cursor: Cursor):
return cursor in [cs.cursor for cs in self.cursors]

def close_all_cursors(self):
for c in list(self.cursors):
self.close(c)
cursor_list = []
for cs in self.cursors:
if cs.keep:
cursor_list.append(cs)
else:
self.close(cs.cursor)
self.cursors = cursor_list

@default_cursor(False)
def close(self, cursor: Union[Cursor, str, None] = None):
Expand All @@ -470,12 +492,8 @@ def close(self, cursor: Union[Cursor, str, None] = None):
except Error:
close_ok = False
else:
self.cursors.remove(cursor)
if cursor == self.current_cursor:
if len(self.cursors) > 0:
self.current_cursor = self.cursors[-1]
else:
self.current_cursor = None
self.current_cursor = None
# noinspection PyBroadException
try:
# for garbage collection
Expand All @@ -485,7 +503,8 @@ def close(self, cursor: Union[Cursor, str, None] = None):
return close_ok

def execute(self, sql: str, parameters: Union[list, tuple] = None,
cursor: Union[Cursor, None] = None, use_current_cursor: bool = True) -> Cursor:
cursor: Union[Cursor, None] = None,
use_current_cursor: bool = True, keep_cursor: bool = False) -> Cursor:
"""
Execute a query
@param sql: str query to execute
Expand All @@ -494,6 +513,7 @@ def execute(self, sql: str, parameters: Union[list, tuple] = None,
@param cursor: to use for execution of the sql command. Create a new one if None (default)
@param use_current_cursor: if set to False, a None cursor will trigger the creation of a new cursor.
Otherwise, the default cursor will be used, if present.
@param keep_cursor: if set to true, the cursor will not be closed upon a commit or rollback.
@return: Cursor of the execution
@raise SQLExecutionError on an execution exception
Expand Down Expand Up @@ -525,15 +545,18 @@ def string2java_string(sql_or_list):
elif not isinstance(sql, str):
raise TypeError('Query (sql) must be a string.')

if (self.current_cursor is None) and (len(self.cursors) > 0):
self.current_cursor = self.cursors[-1].cursor

if cursor is None:
if use_current_cursor:
cursor = self.current_cursor
elif cursor not in self.cursors:
elif not self.has_cursor(cursor):
cursor = None
if cursor is None:
self.counter += 1
cursor = self.connection.cursor()
self.cursors.append(cursor)
self.cursors.append(CursorStorage(cursor, keep_cursor))
self.current_cursor = cursor

while sql.strip().endswith(';'):
Expand All @@ -558,9 +581,9 @@ def string2java_string(sql_or_list):
if error_message.startswith(prefix):
error_message = error_message[len(prefix):]
if error_message is not None:
print(sql, file=sys.stderr)
LOGGER.error(sql)
if isinstance(parameters, (list, tuple)):
print(parameters, file=sys.stderr)
LOGGER.error(str(parameters))
raise SQLExecuteException(error_message)

if not hasattr(cursor, PARENT_CONNECTION):
Expand Down Expand Up @@ -624,10 +647,11 @@ def get_data(self, cursor: Cursor = None, return_type=tuple,
fetch_error = error

if fetch_error is not None:
print('Fetch error in batch {} of size {}.'.format(batch_nr, array_size), file=sys.stderr)
LOGGER.error('Fetch error in batch {} of size {}.'.format(batch_nr, array_size))
error_msg = str(fetch_error)
print(error_msg, file=sys.stderr)
raise SQLExecuteException('Failed to fetch data in batch {}: {}'.format(batch_nr, error_msg))
LOGGER.error(error_msg)
error_msg = 'Failed to fetch data in batch {}: {}'.format(batch_nr, error_msg)
raise SQLExecuteException(error_msg)

if len(results) == 0:
self.close(cursor)
Expand All @@ -640,9 +664,14 @@ def get_data(self, cursor: Cursor = None, return_type=tuple,
break

def commit(self):
"""
Commit the current transaction. This will commit all open cursors.
All open cursors are closed (also in auto-commit mode)
"""

commit_error = None
with self.statistics as stt:
for c in [cc for cc in self.cursors if cc.rowcount > 0]:
for c in [cc.cursor for cc in self.cursors if cc.cursor.rowcount > 0]:
stt.add_row_count(c.rowcount)
if not self.auto_commit:
try:
Expand All @@ -654,6 +683,11 @@ def commit(self):
raise CommitException(str(commit_error))

def rollback(self):
"""
Roll-back the current transaction. This will affect all open cursors.
All open cursors are closed (also in auto-commit mode)
"""

if not self.auto_commit:
self.connection.rollback()
self.close_all_cursors()
Expand All @@ -670,7 +704,7 @@ def query(self, sql: str, parameters=None, return_type=tuple, max_rows=0, array_
@param array_size: batch size for which results are buffered when retrieving from the database
@return: iterator of the specified return type, or the return type if max_rows=1
"""
cur = self.execute(sql, parameters, cursor=None)
cur = self.execute(sql, parameters, cursor=None, keep_cursor=True)
if cur.rowcount >= 0:
raise ValueError('The provided SQL is for updates, not to query. Use Execute method instead.')
return self.get_data(cur, return_type=return_type, include_none=False, max_rows=max_rows, array_size=array_size)
Expand Down
6 changes: 5 additions & 1 deletion lwetl/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@

import base64
import getpass
import logging
import os
import random
import sys

from cryptography.fernet import Fernet
from .exceptions import DecryptionError

# define a logger
LOGGER = logging.getLogger(os.path.basename(__file__).split('.')[0])

KEY = None


Expand Down Expand Up @@ -86,7 +90,7 @@ def decrypt(s: str, key=None, raise_error=False):
if raise_error:
raise DecryptionError('Cannot decrypt.')
else:
print('Password decryption error. Wrong password? {}'.format(e), file=sys.stderr)
LOGGER.critical('Password decryption error. Wrong password? {}'.format(e))
sys.exit(1)
return s2[2:2 + int(s2[0:2], 16) - 128]

Expand Down

0 comments on commit ed9e3b5

Please sign in to comment.