Skip to content

Commit

Permalink
Fix percent escape binary mod
Browse files Browse the repository at this point in the history
  • Loading branch information
xzkostyan committed Oct 21, 2017
1 parent 0465287 commit 552100e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
16 changes: 10 additions & 6 deletions src/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,19 @@


class ClickHouseIdentifierPreparer(compiler.IdentifierPreparer):
def quote_identifier(self, value):
# Never quote identifiers.
return self._escape_identifier(value)

def quote(self, ident, force=None):
return ident
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%')


class ClickHouseCompiler(compiler.SQLCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)

def post_process_text(self, text):
return text.replace('%', '%%')

def visit_count_func(self, fn, **kw):
# count accepts zero arguments.
return 'count%s' % self.process(fn.clause_expr, **kw)
Expand Down
27 changes: 24 additions & 3 deletions tests/drivers/test_escaping.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from decimal import Decimal
from datetime import date

from sqlalchemy import literal
from sqlalchemy import Column, literal

from src import types, engines, Table
from src.drivers.escaper import Escaper
from tests.session import session
from tests.testcase import BaseTestCase


class EscapingTestCase(BaseTestCase):
def compile(self, clause, **kwargs):
def escaped_compile(self, clause, **kwargs):
return str(self._compile(clause, **kwargs))

def test_select_escaping(self):
query = session.query(literal('\t'))
self.assertEqual(
self.compile(session.query(literal('\t')), literal_binds=True),
self.escaped_compile(query, literal_binds=True),
"SELECT '\t' AS param_1"
)

Expand All @@ -36,3 +38,22 @@ def test_escaper(self):
e.escape('str')

self.assertIn('Unsupported param format', str(ex.exception))

def test_escape_binary_mod(self):
query = session.query(literal(1) % literal(2))
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT 1 %% 2 AS anon_1'
)

table = Table(
't', self.metadata(),
Column('x', types.Int32, primary_key=True),
engines.Memory()
)

query = session.query(table.c.x % table.c.x)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT x %% x AS anon_1 FROM t'
)

0 comments on commit 552100e

Please sign in to comment.