Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for LIMIT BY clause #97

Merged
merged 4 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ def visit_if__func(self, func, **kw):
self.process(func.clauses.clauses[2], **kw)
)

def limit_by_clause(self, select, **kw):
text = ''
limit_by_clause = select._limit_by_clause
if limit_by_clause:
text += ' LIMIT '
if limit_by_clause.offset is not None:
text += self.process(limit_by_clause.offset, **kw) + ', '
text += self.process(limit_by_clause.limit, **kw)
limit_by_exprs = limit_by_clause.by_clauses._compiler_dispatch(
self, **kw
)
text += ' BY ' + limit_by_exprs

return text

def limit_clause(self, select, **kw):
text = ''
if select._limit_clause is not None:
Expand Down Expand Up @@ -277,6 +292,11 @@ def _compose_select_body(
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)

limit_by_clause = getattr(select, '_limit_by_clause', None)

if limit_by_clause is not None:
text += self.limit_by_clause(select, **kwargs)

if (select._limit_clause is not None or
select._offset_clause is not None):
text += self.limit_clause(select, **kwargs)
Expand Down
15 changes: 15 additions & 0 deletions clickhouse_sqlalchemy/ext/clauses.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from sqlalchemy import util, exc
from sqlalchemy.sql import type_api
from sqlalchemy.sql.elements import (
_literal_as_label_reference,
BindParameter,
ColumnElement,
ClauseList
)
from sqlalchemy.sql.selectable import _offset_or_limit_clause
from sqlalchemy.sql.visitors import Visitable


Expand All @@ -29,6 +31,19 @@ def sample_clause(element):
return SampleParam(None, element, unique=True)


class LimitByClause:

def __init__(self, by_clauses, limit, offset):
self.by_clauses = ClauseList(
*by_clauses, _literal_as_text=_literal_as_label_reference
)
self.offset = _offset_or_limit_clause(offset)
self.limit = _offset_or_limit_clause(limit)

def __bool__(self):
return bool(self.by_clauses.clauses)


class Lambda(ColumnElement):
"""Represent a lambda function, ``Lambda(lambda x: 2 * x)``."""

Expand Down
19 changes: 12 additions & 7 deletions clickhouse_sqlalchemy/orm/query.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from contextlib import contextmanager

from sqlalchemy import exc
from sqlalchemy.orm.base import _generative
import sqlalchemy.orm.query as query_module
from sqlalchemy.orm.query import Query as BaseQuery
from sqlalchemy.orm.util import _ORMJoin as _StandardORMJoin

from ..ext.clauses import (
sample_clause,
ArrayJoin,
LimitByClause,
sample_clause,
)


class Query(BaseQuery):
_with_totals = False
_final = None
_sample = None
_limit_by = None
_array_join = None

def _compile_context(self, labels=True):
Expand All @@ -24,10 +27,12 @@ def _compile_context(self, labels=True):
statement._with_totals = self._with_totals
statement._final_clause = self._final
statement._sample_clause = sample_clause(self._sample)
statement._limit_by_clause = self._limit_by
statement._array_join = self._array_join

return context

@_generative()
def with_totals(self):
if not self._group_by:
raise exc.InvalidRequestError(
Expand All @@ -37,21 +42,21 @@ def with_totals(self):

self._with_totals = True

return self

@_generative()
def array_join(self, *columns):
self._array_join = ArrayJoin(*columns)
return self

@_generative()
def final(self):
self._final = True

return self

@_generative()
def sample(self, sample):
self._sample = sample

return self
@_generative()
def limit_by(self, by_clauses, limit, offset=None):
self._limit_by = LimitByClause(by_clauses, limit, offset)

def join(self, *props, **kwargs):
type = kwargs.pop('type', None)
Expand Down
21 changes: 15 additions & 6 deletions clickhouse_sqlalchemy/sql/selectable.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from sqlalchemy.sql.base import _generative
from sqlalchemy.sql.selectable import (
Select as StandardSelect,
Join as StandardJoin,
)

from clickhouse_sqlalchemy.ext.clauses import ArrayJoin
from ..ext.clauses import sample_clause
from ..ext.clauses import (
ArrayJoin,
LimitByClause,
sample_clause,
)


__all__ = ('Select', 'select')
Expand All @@ -30,23 +34,28 @@ class Select(StandardSelect):
_with_totals = False
_final_clause = None
_sample_clause = None
_limit_by_clause = None
_array_join = None

@_generative
def with_totals(self):
self._with_totals = True
return self

@_generative
def final(self):
self._final_clause = True
return self

@_generative
def sample(self, sample):
self._sample_clause = sample_clause(sample)
return self

@_generative
def limit_by(self, by_clauses, limit, offset=None):
self._limit_by_clause = LimitByClause(by_clauses, limit, offset)

@_generative
def array_join(self, *columns):
self._array_join = ArrayJoin(*columns)
return self

def join(self, right, onclause=None, isouter=False, full=False, type=None,
strictness=None, distribution=None):
Expand Down
30 changes: 30 additions & 0 deletions tests/orm/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,36 @@ def test_final(self):
'SELECT t1.x AS t1_x FROM t1 FINAL GROUP BY t1.x'
)

def test_limit_by(self):
table = self._make_table()

query = self.session.query(table.c.x).order_by(table.c.x)\
.limit_by([table.c.x], limit=1)
self.assertEqual(
self.compile(query),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x '
'LIMIT %(param_1)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x LIMIT 1 BY t1.x'
)

def test_limit_by_with_offset(self):
table = self._make_table()

query = self.session.query(table.c.x).order_by(table.c.x)\
.limit_by([table.c.x], offset=1, limit=2)
self.assertEqual(
self.compile(query),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x '
'LIMIT %(param_1)s, %(param_2)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x AS t1_x FROM t1 ORDER BY t1.x LIMIT 1, 2 BY t1.x'
)

def test_lambda_functions(self):
query = self.session.query(
func.arrayFilter(
Expand Down
29 changes: 29 additions & 0 deletions tests/sql/test_selectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,35 @@ def test_final(self):
'SELECT t1.x FROM t1 FINAL GROUP BY t1.x'
)

def test_limit_by(self):
table = self._make_table()

query = select([table.c.x]).order_by(table.c.x)\
.limit_by([table.c.x], limit=1)
self.assertEqual(
self.compile(query),
'SELECT t1.x FROM t1 ORDER BY t1.x LIMIT %(param_1)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x FROM t1 ORDER BY t1.x LIMIT 1 BY t1.x'
)

def test_limit_by_with_offset(self):
table = self._make_table()

query = select([table.c.x]).order_by(table.c.x)\
.limit_by([table.c.x], offset=1, limit=2)
self.assertEqual(
self.compile(query),
'SELECT t1.x FROM t1 ORDER BY t1.x '
'LIMIT %(param_1)s, %(param_2)s BY t1.x'
)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT t1.x FROM t1 ORDER BY t1.x LIMIT 1, 2 BY t1.x'
)

def test_nested_type(self):
table = self._make_table(
't1',
Expand Down