Skip to content

Commit

Permalink
Merge pull request #97 from ods/limit_by
Browse files Browse the repository at this point in the history
Add support for LIMIT BY clause
  • Loading branch information
xzkostyan committed Jun 19, 2020
2 parents 0ff0bee + 5ad7f69 commit 5890383
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 13 deletions.
20 changes: 20 additions & 0 deletions clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ def visit_if__func(self, func, **kw):
self.process(func.clauses.clauses[2], **kw)
)

def limit_by_clause(self, select, **kw):
text = ''
limit_by_clause = select._limit_by_clause
if limit_by_clause:
text += ' LIMIT '
if limit_by_clause.offset is not None:
text += self.process(limit_by_clause.offset, **kw) + ', '
text += self.process(limit_by_clause.limit, **kw)
limit_by_exprs = limit_by_clause.by_clauses._compiler_dispatch(
self, **kw
)
text += ' BY ' + limit_by_exprs

return text

def limit_clause(self, select, **kw):
text = ''
if select._limit_clause is not None:
Expand Down Expand Up @@ -277,6 +292,11 @@ def _compose_select_body(
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)

limit_by_clause = getattr(select, '_limit_by_clause', None)

if limit_by_clause is not None:
text += self.limit_by_clause(select, **kwargs)

if (select._limit_clause is not None or
select._offset_clause is not None):
text += self.limit_clause(select, **kwargs)
Expand Down
15 changes: 15 additions & 0 deletions clickhouse_sqlalchemy/ext/clauses.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from sqlalchemy import util, exc
from sqlalchemy.sql import type_api
from sqlalchemy.sql.elements import (
_literal_as_label_reference,
BindParameter,
ColumnElement,
ClauseList
)
from sqlalchemy.sql.selectable import _offset_or_limit_clause
from sqlalchemy.sql.visitors import Visitable


Expand All @@ -29,6 +31,19 @@ def sample_clause(element):
return SampleParam(None, element, unique=True)


class LimitByClause:

def __init__(self, by_clauses, limit, offset):
self.by_clauses = ClauseList(
*by_clauses, _literal_as_text=_literal_as_label_reference
)
self.offset = _offset_or_limit_clause(offset)
self.limit = _offset_or_limit_clause(limit)

def __bool__(self):
return bool(self.by_clauses.clauses)


class Lambda(ColumnElement):
"""Represent a lambda function, ``Lambda(lambda x: 2 * x)``."""

Expand Down
19 changes: 12 additions & 7 deletions clickhouse_sqlalchemy/orm/query.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from contextlib import contextmanager

from sqlalchemy import exc
from sqlalchemy.orm.base import _generative
import sqlalchemy.orm.query as query_module
from sqlalchemy.orm.query import Query as BaseQuery
from sqlalchemy.orm.util import _ORMJoin as _StandardORMJoin

from ..ext.clauses import (
sample_clause,
ArrayJoin,
LimitByClause,
sample_clause,
)


class Query(BaseQuery):
_with_totals = False
_final = None
_sample = None
_limit_by = None
_array_join = None

def _compile_context(self, labels=True):
Expand All @@ -24,10 +27,12 @@ def _compile_context(self, labels=True):
statement._with_totals = self._with_totals
statement._final_clause = self._final
statement._sample_clause = sample_clause(self._sample)
statement._limit_by_clause = self._limit_by
statement._array_join = self._array_join

return context

@_generative()
def with_totals(self):
if not self._group_by:
raise exc.InvalidRequestError(
Expand All @@ -37,21 +42,21 @@ def with_totals(self):

self._with_totals = True

return self

@_generative()
def array_join(self, *columns):
self._array_join = ArrayJoin(*columns)
return self

@_generative()
def final(self):
self._final = True

return self

@_generative()
def sample(self, sample):
self._sample = sample

return self
@_generative()
def limit_by(self, by_clauses, limit, offset=None):
self._limit_by = LimitByClause(by_clauses, limit, offset)

def join(self, *props, **kwargs):
type = kwargs.pop('type', None)
Expand Down
21 changes: 15 additions & 6 deletions clickhouse_sqlalchemy/sql/selectable.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from sqlalchemy.sql.base import _generative
from sqlalchemy.sql.selectable import (
Select as StandardSelect,
Join as StandardJoin,
)

from clickhouse_sqlalchemy.ext.clauses import ArrayJoin
from ..ext.clauses import sample_clause
from ..ext.clauses import (
ArrayJoin,
LimitByClause,
sample_clause,
)


__all__ = ('Select', 'select')
Expand All @@ -30,23 +34,28 @@ class Select(StandardSelect):
_with_totals = False
_final_clause = None
_sample_clause = None
_limit_by_clause = None
_array_join = None

@_generative
def with_totals(self):
self._with_totals = True
return self

@_generative
def final(self):
self._final_clause = True
return self

@_generative
def sample(self, sample):
self._sample_clause = sample_clause(sample)
return self

@_generative
def limit_by(self, by_clauses, limit, offset=None):
self._limit_by_clause = LimitByClause(by_clauses, limit, offset)

@_generative
def array_join(self, *columns):
self._array_join = ArrayJoin(*columns)
return self

def join(self, right, onclause=None, isouter=False, full=False, type=None,
strictness=None, distribution=None):
Expand Down
30 changes: 30 additions & 0 deletions tests/orm/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,36 @@ def test_final(self):
'SELECT t1.x AS t1_x FROM t1 FINAL GROUP BY t1.x'
)

def test_limit_by(self):
table = self._make_table()

query = self.session.query(table.c.x).order_by(table.c.x)\
.limit_by([table.c.x], limit=1)
self.assertEqual(
self.compile(query),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x '
'LIMIT %(param_1)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x LIMIT 1 BY t1.x'
)

def test_limit_by_with_offset(self):
table = self._make_table()

query = self.session.query(table.c.x).order_by(table.c.x)\
.limit_by([table.c.x], offset=1, limit=2)
self.assertEqual(
self.compile(query),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x '
'LIMIT %(param_1)s, %(param_2)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x LIMIT 1, 2 BY t1.x'
)

def test_lambda_functions(self):
query = self.session.query(
func.arrayFilter(
Expand Down
29 changes: 29 additions & 0 deletions tests/sql/test_selectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,35 @@ def test_final(self):
'SELECT t1.x FROM t1 FINAL GROUP BY t1.x'
)

def test_limit_by(self):
table = self._make_table()

query = select([table.c.x]).order_by(table.c.x)\
.limit_by([table.c.x], limit=1)
self.assertEqual(
self.compile(query),
'SELECT t1.x FROM t1 ORDER BY t1.x LIMIT %(param_1)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x FROM t1 ORDER BY t1.x LIMIT 1 BY t1.x'
)

def test_limit_by_with_offset(self):
table = self._make_table()

query = select([table.c.x]).order_by(table.c.x)\
.limit_by([table.c.x], offset=1, limit=2)
self.assertEqual(
self.compile(query),
'SELECT t1.x FROM t1 ORDER BY t1.x '
'LIMIT %(param_1)s, %(param_2)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x FROM t1 ORDER BY t1.x LIMIT 1, 2 BY t1.x'
)

def test_nested_type(self):
table = self._make_table(
't1',
Expand Down

0 comments on commit 5890383

Please sign in to comment.