Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

array join support #44

Merged
merged 4 commits into from
Feb 14, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import enum

from sqlalchemy import schema, types as sqltypes, exc, util as sa_util
from sqlalchemy.engine import default, reflection
from sqlalchemy.sql import (
compiler, expression, type_api, literal_column, elements
)
from sqlalchemy.sql.elements import Label
from sqlalchemy.types import DATE, DATETIME, FLOAT
from sqlalchemy.util import warn
from sqlalchemy.util.compat import inspect_getfullargspec
Expand Down Expand Up @@ -195,6 +195,15 @@ def _compose_select_body(
else:
text += self.default_from()

if select._array_join:
text += ' \nARRAY JOIN {columns}'.format(
columns=', '.join(
col.element.name if isinstance(col, Label) else col.name
for col in select._array_join

)
)

sample_clause = getattr(select, '_sample_clause', None)

if sample_clause is not None:
Expand Down
10 changes: 9 additions & 1 deletion clickhouse_sqlalchemy/ext/clauses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from sqlalchemy import util, exc
from sqlalchemy.sql import type_api
from sqlalchemy.sql.elements import BindParameter, ColumnElement
from sqlalchemy.sql.elements import (
BindParameter,
ColumnElement,
ClauseList,
)
from sqlalchemy.sql.visitors import Visitable


Expand Down Expand Up @@ -36,3 +40,7 @@ def __init__(self, func):

self.type = type_api.NULLTYPE
self.func = func


class ArrayJoin(ClauseList):
__visit_name__ = 'ARRAY_JOIN'
6 changes: 6 additions & 0 deletions clickhouse_sqlalchemy/orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
class Query(BaseQuery):
_with_totals = False
_sample = None
_array_join = None

def _compile_context(self, labels=True):
context = super(Query, self)._compile_context(labels=labels)
statement = context.statement

statement._with_totals = self._with_totals
statement._sample_clause = sample_clause(self._sample)
statement._array_join = self._array_join

return context

Expand All @@ -28,6 +30,10 @@ def with_totals(self):

return self

def array_join(self, *columns):
self._array_join = columns
return self

def sample(self, sample):
self._sample = sample

Expand Down
5 changes: 5 additions & 0 deletions clickhouse_sqlalchemy/sql/selectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class Select(StandardSelect):
_with_totals = False
_sample_clause = None
_array_join = None

def with_totals(self):
self._with_totals = True
Expand All @@ -18,5 +19,9 @@ def sample(self, sample):
self._sample_clause = sample_clause(sample)
return self

def array_join(self, *columns):
self._array_join = columns
return self


select = Select
37 changes: 33 additions & 4 deletions tests/orm/test_select.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from sqlalchemy import Column, exc, func, literal
from sqlalchemy import (
Column,
exc,
func,
literal,
literal_column,
)
from sqlalchemy import text
from sqlalchemy import tuple_

Expand All @@ -9,12 +15,13 @@


class SelectTestCase(BaseTestCase):
def create_table(self):
def create_table(self, *columns):
metadata = self.metadata()

return Table(
't1', metadata,
Column('x', types.Int32, primary_key=True)
Column('x', types.Int32, primary_key=True),
*columns
)

def test_select(self):
Expand Down Expand Up @@ -46,6 +53,27 @@ def test_group_by_query(self):

self.assertIn('with_totals', str(ex.exception))

def make_count_query(self, base_query):
return base_query.from_self(func.count(literal_column('*')))

def test_array_join(self):
table = self.create_table(
Column('nested.array_column', types.Array(types.Int8)),
Column('nested.another_array_column', types.Array(types.Int8))
)
first_label = table.c['nested.array_column'].label('from_array')
second_not_label = table.c['nested.another_array_column']
query = session.query(first_label, second_not_label)\
.array_join(first_label, second_not_label)
self.assertEqual(
self.compile(query),
'SELECT '
'"nested.array_column" AS from_array, '
'"nested.another_array_column" AS "t1_nested.another_array_column"'
' FROM t1 '
'ARRAY JOIN nested.array_column, nested.another_array_column'
)

def test_sample(self):
table = self.create_table()

Expand All @@ -54,7 +82,8 @@ def test_sample(self):
self.compile(query),
'SELECT x AS t1_x FROM t1 SAMPLE %(param_1)s GROUP BY x'
)

q = self.make_count_query(query)
self.compile(q)
antonio-antuan marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT x AS t1_x FROM t1 SAMPLE 0.1 GROUP BY x'
Expand Down