Skip to content

Commit

Permalink
array join support
Browse files Browse the repository at this point in the history
  • Loading branch information
antonio_antuan committed Feb 2, 2019
1 parent d753192 commit e325fb6
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 6 deletions.
11 changes: 10 additions & 1 deletion clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import enum

from sqlalchemy import schema, types as sqltypes, exc, util as sa_util
from sqlalchemy.engine import default, reflection
from sqlalchemy.sql import (
compiler, expression, type_api, literal_column, elements
)
from sqlalchemy.sql.elements import Label
from sqlalchemy.types import DATE, DATETIME, FLOAT
from sqlalchemy.util import warn
from sqlalchemy.util.compat import inspect_getfullargspec
Expand Down Expand Up @@ -195,6 +195,15 @@ def _compose_select_body(
else:
text += self.default_from()

if select._array_join:
text += ' \nARRAY JOIN {columns}'.format(
columns=', '.join(
col.element.name if isinstance(col, Label) else col.name
for col in select._array_join

)
)

sample_clause = getattr(select, '_sample_clause', None)

if sample_clause is not None:
Expand Down
11 changes: 10 additions & 1 deletion clickhouse_sqlalchemy/ext/clauses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from sqlalchemy import util, exc
from sqlalchemy.sql import type_api
from sqlalchemy.sql.elements import BindParameter, ColumnElement
from sqlalchemy.sql.elements import (
BindParameter,
ColumnElement,
ClauseElement,
ClauseList,
)
from sqlalchemy.sql.visitors import Visitable


Expand Down Expand Up @@ -36,3 +41,7 @@ def __init__(self, func):

self.type = type_api.NULLTYPE
self.func = func


class ArrayJoin(ClauseList):
__visit_name__ = 'ARRAY_JOIN'
6 changes: 6 additions & 0 deletions clickhouse_sqlalchemy/orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
class Query(BaseQuery):
_with_totals = False
_sample = None
_array_join = None

def _compile_context(self, labels=True):
context = super(Query, self)._compile_context(labels=labels)
statement = context.statement

statement._with_totals = self._with_totals
statement._sample_clause = sample_clause(self._sample)
statement._array_join = self._array_join

return context

Expand All @@ -28,6 +30,10 @@ def with_totals(self):

return self

def array_join(self, *columns):
self._array_join = columns
return self

def sample(self, sample):
self._sample = sample

Expand Down
5 changes: 5 additions & 0 deletions clickhouse_sqlalchemy/sql/selectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class Select(StandardSelect):
_with_totals = False
_sample_clause = None
_array_join = None

def with_totals(self):
self._with_totals = True
Expand All @@ -18,5 +19,9 @@ def sample(self, sample):
self._sample_clause = sample_clause(sample)
return self

def array_join(self, *columns):
self._array_join = columns
return self


select = Select
37 changes: 33 additions & 4 deletions tests/orm/test_select.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from sqlalchemy import Column, exc, func, literal
from sqlalchemy import (
Column,
exc,
func,
literal,
literal_column,
)
from sqlalchemy import text
from sqlalchemy import tuple_

Expand All @@ -9,12 +15,13 @@


class SelectTestCase(BaseTestCase):
def create_table(self):
def create_table(self, *columns):
metadata = self.metadata()

return Table(
't1', metadata,
Column('x', types.Int32, primary_key=True)
Column('x', types.Int32, primary_key=True),
*columns
)

def test_select(self):
Expand Down Expand Up @@ -46,6 +53,27 @@ def test_group_by_query(self):

self.assertIn('with_totals', str(ex.exception))

def make_count_query(self, base_query):
return base_query.from_self(func.count(literal_column('*')))

def test_array_join(self):
table = self.create_table(
Column('nested.array_column', types.Array(types.Int8)),
Column('nested.another_array_column', types.Array(types.Int8))
)
first_label = table.c['nested.array_column'].label('from_array')
second_not_label = table.c['nested.another_array_column']
query = session.query(first_label, second_not_label)\
.array_join(first_label, second_not_label)
self.assertEqual(
self.compile(query),
'SELECT '
'"nested.array_column" AS from_array, '
'"nested.another_array_column" AS "t1_nested.another_array_column"'
' FROM t1 '
'ARRAY JOIN nested.array_column, nested.another_array_column'
)

def test_sample(self):
table = self.create_table()

Expand All @@ -54,7 +82,8 @@ def test_sample(self):
self.compile(query),
'SELECT x AS t1_x FROM t1 SAMPLE %(param_1)s GROUP BY x'
)

q = self.make_count_query(query)
self.compile(q)
self.assertEqual(
self.compile(query, literal_binds=True),
'SELECT x AS t1_x FROM t1 SAMPLE 0.1 GROUP BY x'
Expand Down

0 comments on commit e325fb6

Please sign in to comment.