Skip to content

Commit

Permalink
Simplify parameters quoting.
Browse files Browse the repository at this point in the history
  • Loading branch information
termim committed Mar 24, 2024
1 parent 1c38137 commit e63d542
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 52 deletions.
73 changes: 22 additions & 51 deletions src/pymssql/_mssql.pyx
Expand Up @@ -2032,75 +2032,49 @@ cdef _quote_simple_value(value, charset='utf8'):
if isinstance(value, datetime.time):
return value.strftime("'%H:%M:%S.%f'").encode(charset)

return None
raise ValueError(f"Unsupported parameter type: {type(value)}")

cdef _quote_or_flatten(data, charset='utf8'):
result = _quote_simple_value(data, charset)

if result is not None:
return result

if not issubclass(type(data), (list, tuple)):
raise ValueError('expected a simple type, a tuple or a list')

quoted = []
for value in data:
value = _quote_simple_value(value, charset)

if value is None:
raise ValueError('found an unsupported type')

quoted.append(value)
return b'(' + b','.join(quoted) + b')'

# This function is supposed to take a simple value, tuple or dictionary,
# normally passed in via the params argument in the execute_* methods. It
# then quotes and flattens the arguments and returns then.
cdef _quote_data(data, charset='utf8'):
result = _quote_simple_value(data)

if result is not None:
return result

if issubclass(type(data), dict):
"""
This function is supposed to take a simple value, tuple or dictionary,
passed in via the params argument in the execute_* methods.
It then quotes and flattens the arguments and returns them.
"""
if isinstance(data, dict):
result = {}
for k, v in data.items():
result[k] = _quote_or_flatten(v, charset)
if isinstance(v, (list, tuple)):
result[k] = b'(' + b','.join([ _quote_simple_value(_v, charset) for _v in v ]) + b')'
else:
result[k] = _quote_simple_value(v, charset)
return result

if issubclass(type(data), tuple):
if isinstance(data, (list, tuple)):
result = []
for v in data:
result.append(_quote_or_flatten(v, charset))
if isinstance(v, (list, tuple)):
_v = b'(' + b','.join([ _quote_simple_value(_v, charset) for _v in v ]) + b')'
else:
_v = _quote_simple_value(v, charset)
result.append(_v)
return tuple(result)

raise ValueError('expected a simple type, a tuple or a dictionary.')
return ( _quote_simple_value(data), )

_re_pos_param = re.compile(br'(%([sd]))')
_re_name_param = re.compile(br'(%\(([^\)]+)\)(?:[sd]))')

cdef _substitute_params(toformat, params=NoParams, charset='utf-8'):

if isinstance(toformat, unicode):
if isinstance(toformat, str):
toformat = toformat.encode(charset)
elif not isinstance(toformat, bytes):
raise exceptions.ProgrammingError(f"Query should be string or bytes, got { type(toformat)}")
raise exceptions.ProgrammingError(f"Query should be string or bytes, got {type(toformat)}")

if params is NoParams:
return toformat

if params is not None and not issubclass(type(params),
(bool, int, long, float, unicode, str, bytes, bytearray, dict, tuple,
datetime.datetime, datetime.date, datetime.time, dict, decimal.Decimal, uuid.UUID)):
raise ValueError("'params' arg (%r) can be only a tuple or a dictionary." % type(params))

if charset:
quoted = _quote_data(params, charset)
else:
quoted = _quote_data(params)

# positional string substitution now requires a tuple
if hasattr(quoted, 'startswith'):
quoted = (quoted,)
quoted = _quote_data(params, charset)

if isinstance(params, dict):
""" assume name based substitutions """
Expand Down Expand Up @@ -2164,9 +2138,6 @@ cdef _substitute_params(toformat, params=NoParams, charset='utf-8'):
def quote_simple_value(value):
return _quote_simple_value(value)

def quote_or_flatten(data):
return _quote_or_flatten(data)

def quote_data(data):
return _quote_data(data)

Expand Down
15 changes: 14 additions & 1 deletion tests/test_parameters.py
Expand Up @@ -72,7 +72,6 @@ def test_unicode_params():
'\u03A8'
)
eq_(res, b"SELECT * FROM \xce\x94 WHERE name = N'\xce\xa8'")

res = substitute_params(u"testing ascii (\u0105\u010D\u0119) 1=%d 'one'=%s", (1, 'str'))
eq_(res, b"testing ascii (\xc4\x85\xc4\x8d\xc4\x99) 1=1 'one'=N'str'")

Expand All @@ -90,6 +89,20 @@ def test_keyed_param_with_d():
{'emp_id': 13})
eq_(res, b'SELECT * FROM employees WHERE id = 13')

def test_keyed_tuple_param():
res = substitute_params(
'SELECT * FROM employees WHERE id IN %(emp_ids)d',
{'emp_ids': (13, 31)})
eq_(res, b'SELECT * FROM employees WHERE id IN (13,31)')
res = substitute_params(
'SELECT * FROM employees WHERE id IN %(emp_ids)d',
{'emp_ids': (b'13', b'31')})
eq_(res, b"SELECT * FROM employees WHERE id IN ('13','31')")
res = substitute_params(
'SELECT * FROM employees WHERE id IN %(emp_ids)d',
{'emp_ids': ('13', '31')})
eq_(res, b"SELECT * FROM employees WHERE id IN (N'13',N'31')")


def test_percent_not_touched_with_no_params():
sql = "SELECT COUNT(*) FROM employees WHERE name LIKE 'J%'"
Expand Down

0 comments on commit e63d542

Please sign in to comment.