Skip to content

Commit

Permalink
joins support
Browse files Browse the repository at this point in the history
  • Loading branch information
antonio_antuan committed Mar 21, 2019
1 parent c0d2952 commit b89f9ec
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 92 deletions.
72 changes: 40 additions & 32 deletions clickhouse_sqlalchemy/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
to_list,
)

from clickhouse_sqlalchemy import Table
from .. import types
from ..util import compat


# Column specifications
colspecs = {}

Expand Down Expand Up @@ -142,36 +142,39 @@ def visit_extract(self, extract, **kw):
return column

def visit_join(self, join, asfrom=False, **kwargs):
join_type = " "

if join.global_:
join_type += "GLOBAL "

if join.any:
join_type += "ANY "

if join.all:
join_type += "ALL "

if join.full:
join_type += "FULL OUTER JOIN "
elif join.isouter:
join_type += "LEFT OUTER JOIN "
join_stmt = join.left._compiler_dispatch(self, asfrom=asfrom, **kwargs)
join_type = join.type # need to make variable to prevent leaks in some debuggers
if join_type is None:
if join.isouter:
join_type = 'LEFT OUTER'
else:
join_type = 'INNER'
elif join_type is not None:
join_type = join_type.upper()
if join.isouter and 'INNER' in join_type:
raise exc.CompileError('can\'t compile join with specified INNER type and isouter=True')
# isouter=False by default, disable that checking
# elif not join.isouter and 'OUTER' in join.type:
# raise exc.CompileError('can\'t compile join with specified OUTER type and isouter=False')
if join.full and 'FULL' not in join_type:
join_type = 'FULL ' + join_type

if join.strictness:
join_type = join.strictness.upper() + ' ' + join_type

if join.distribution:
join_type = join.distribution.upper() + ' ' + join_type

if join_type is not None:
join_stmt += ' ' + join_type.upper() + ' JOIN '

join_stmt += join.right._compiler_dispatch(self, asfrom=asfrom, **kwargs)

if isinstance(join.onclause, elements.Tuple):
join_stmt += ' USING ' + join.onclause._compiler_dispatch(self, **kwargs)
else:
join_type += "INNER JOIN "

if not isinstance(join.onclause, elements.Tuple):
raise exc.CompileError(
"Only tuple elements are supported. "
"Got: %s" % type(join.onclause)
)

return (
join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
join_type +
join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
" USING " + join.onclause._compiler_dispatch(self, **kwargs)
)
join_stmt += ' ON ' + join.onclause._compiler_dispatch(self, **kwargs)
return join_stmt

def visit_array_join(self, array_join, **kwargs):
return ' \nARRAY JOIN {columns}'.format(
Expand Down Expand Up @@ -526,7 +529,7 @@ class ClickHouseDialect(default.DefaultDialect):
construct_arguments = [
(schema.Table, {
'data': [],
'cluster': None
'cluster': None,
})
]

Expand All @@ -544,6 +547,11 @@ def has_table(self, connection, table_name, schema=None):
return True
return False

def reflecttable(self, connection, table, include_columns, exclude_columns, **opts):
    """Reflect *table* through a ClickHouse ``Table`` built from it.

    The incoming table is removed from its metadata, rebuilt with
    ``Table._make_from_standard`` and the rebuilt table is handed to the
    parent dialect's ``reflecttable`` together with the untouched
    ``include_columns``, ``exclude_columns`` and ``opts``.
    """
    # Detach the original before the replacement is created from the same
    # name and metadata (presumably to avoid a duplicate-name clash --
    # confirm against MetaData behaviour).
    table.metadata.remove(table)
    super().reflecttable(
        connection,
        Table._make_from_standard(table),
        include_columns,
        exclude_columns,
        **opts
    )

@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
query = 'DESCRIBE TABLE {}'.format(table_name)
Expand Down Expand Up @@ -681,5 +689,5 @@ def _check_unicode_description(self, connection):
return True

def _get_server_version_info(self, connection):
version = connection.scalar('select version()')
version = connection.scalar('select version() format TabSeparatedWithNamesAndTypes;')
return tuple(int(x) for x in version.split('.'))
95 changes: 55 additions & 40 deletions clickhouse_sqlalchemy/orm/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from contextlib import contextmanager
from sqlalchemy import exc
import sqlalchemy.orm.query as query_module
from sqlalchemy.orm.query import Query as BaseQuery

from sqlalchemy.orm.util import _ORMJoin as _StandardORMJoin
from clickhouse_sqlalchemy.sql.selectable import Join
from ..ext.clauses import (
sample_clause,
ArrayJoin,
Expand Down Expand Up @@ -43,45 +46,57 @@ def sample(self, sample):
return self

def join(self, *props, **kwargs):
global_ = kwargs.pop('global_', False)

any_ = kwargs.pop('any', None)
all_ = kwargs.pop('all', None)

if all_ is None and any_ is None:
raise ValueError("ANY or ALL must be specified")

type = kwargs.pop('type', None)
strictness = kwargs.pop('strictness', None)
distribution = kwargs.pop('distribution', None)
rv = super(Query, self).join(*props, **kwargs)
diff = set(rv._from_obj) - set(self._from_obj)

assert len(diff) < 2

if diff:
orm_join = diff.pop()
orm_join.any = any_
orm_join.all = all_
orm_join.global_ = global_

return rv
joined = list(set(rv._from_obj) - set(self._from_obj))[0]
new = _ORMJoin._from_standard(joined,
type=type,
strictness=strictness,
distribution=distribution)

@contextmanager
def replace_join():
original = query_module.orm_join
query_module.orm_join = new
yield
query_module.orm_join = original

with replace_join():
return super(Query, self).join(*props, **kwargs)

def outerjoin(self, *props, **kwargs):
global_ = kwargs.pop('global_', False)

any_ = kwargs.pop('any', None)
all_ = kwargs.pop('all', None)

if all_ is None and any_ is None:
raise ValueError("ANY or ALL must be specified")

rv = super(Query, self).outerjoin(*props, **kwargs)
diff = set(rv._from_obj) - set(self._from_obj)

assert len(diff) < 2

if diff:
orm_join = diff.pop()
orm_join.any = any_
orm_join.all = all_
orm_join.global_ = global_

return rv
kwargs['type'] = kwargs.get('type') or 'LEFT OUTER'
return self.join(*props, **kwargs)


class _ORMJoin(_StandardORMJoin):
    """ORM join that carries the extra JOIN modifiers set by ``Query.join``.

    Besides what the standard ORM join stores, an instance keeps three
    attributes: ``type``, ``strictness`` and ``distribution``.
    """

    @classmethod
    def _from_standard(cls, standard_join, type, strictness, distribution):
        """Build an instance from the sides and ON clause of *standard_join*.

        Only ``left``, ``right`` and ``onclause`` are read from
        *standard_join*; the modifiers come from the arguments.
        """
        return cls(
            standard_join.left,
            standard_join.right,
            standard_join.onclause,
            type=type,
            strictness=strictness,
            distribution=distribution,
        )

    def __init__(self, left, right, onclause=None, type=None,
                 strictness=None, distribution=None):
        """Create the join and record its modifiers.

        :param left: left side of the join.
        :param right: right side of the join.
        :param onclause: join condition, passed through to the parent.
        :param type: JOIN type; required.
        :param strictness: stored as given, any falsy value becomes ``None``.
        :param distribution: stored as given.
        :raises ValueError: if *type* is ``None``.
        """
        if type is None:
            raise ValueError('JOIN type must be specified, '
                             'expected one of: '
                             'INNER, RIGHT OUTER, LEFT OUTER, FULL OUTER, CROSS')
        # NOTE(review): the positional tail is presumably
        # isouter, full, _left_memo, _right_memo -- confirm against the
        # installed SQLAlchemy version.
        super().__init__(left, right, onclause, False, False, None, None)
        self.type = type
        self.strictness = strictness or None
        self.distribution = distribution

    def __call__(self, *args, **kwargs):
        """Return this instance, ignoring all arguments.

        ``Query.join`` temporarily installs an instance in place of
        ``sqlalchemy.orm.query.orm_join``; being callable lets the instance
        stand in for that factory.
        """
        return self
25 changes: 25 additions & 0 deletions clickhouse_sqlalchemy/sql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
_bind_or_error,
)

from clickhouse_sqlalchemy.sql.selectable import (
Join,
Select,
)
from . import ddl


Expand All @@ -13,3 +17,24 @@ def drop(self, bind=None, checkfirst=False, if_exists=False):
bind._run_visitor(ddl.SchemaDropper,
self,
checkfirst=checkfirst, if_exists=if_exists)

def join(self, right, onclause=None, isouter=False, full=False, type=None, strictness=None, distribution=None):
    """Return a ``Join`` with this table on the left and *right* on the right.

    Every keyword argument is forwarded to ``Join`` unchanged.
    """
    join_options = dict(
        onclause=onclause,
        type=type,
        isouter=isouter,
        full=full,
        strictness=strictness,
        distribution=distribution,
    )
    return Join(self, right, **join_options)

def select(self, whereclause=None, **params):
    """Return a ``Select`` over this table.

    *whereclause* and any extra keyword arguments are passed straight
    through to ``Select``.
    """
    selected = [self]
    return Select(selected, whereclause, **params)

@classmethod
def _make_from_standard(cls, std_table):
    """Create a ``cls`` table mirroring *std_table*.

    The new table is constructed from the name and metadata of *std_table*,
    then a fixed set of attributes is copied over by reference (no deep
    copy). Columns and constraints are not copied here.
    """
    ch_table = cls(std_table.name, std_table.metadata)
    # Copied in this order, one plain attribute assignment each.
    mirrored_attributes = (
        'schema',
        'fullname',
        'implicit_returning',
        'comment',
        'info',
        '_prefixes',
        'dialect_options',
    )
    for attribute in mirrored_attributes:
        setattr(ch_table, attribute, getattr(std_table, attribute))
    return ch_table
27 changes: 26 additions & 1 deletion clickhouse_sqlalchemy/sql/selectable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from sqlalchemy.sql.selectable import Select as StandardSelect
from sqlalchemy.sql.selectable import (
Select as StandardSelect,
Join as StandardJoin,
)

from clickhouse_sqlalchemy.ext.clauses import ArrayJoin
from ..ext.clauses import sample_clause
Expand All @@ -7,6 +10,21 @@
__all__ = ('Select', 'select')


class Join(StandardJoin):
    """Join that also stores ``type``, ``strictness`` and ``distribution``."""

    def __init__(self, left, right,
                 onclause=None, isouter=False, full=False,
                 type=None, strictness=None, distribution=None):
        """Create the join and record the extra modifiers.

        :param type: JOIN type; upper-cased when given, ``None`` otherwise.
        :param strictness: stored as given, any falsy value becomes ``None``.
        :param distribution: stored as given.

        The remaining arguments are forwarded to the parent constructor.
        """
        # Normalise before calling the parent so a non-string type fails
        # before any join state is built.
        normalized_type = None if type is None else type.upper()
        super().__init__(left, right, onclause, isouter=isouter, full=full)
        self.type = normalized_type
        self.distribution = distribution
        self.strictness = strictness if strictness else None


class Select(StandardSelect):
_with_totals = False
_sample_clause = None
Expand All @@ -24,5 +42,12 @@ def array_join(self, *columns):
self._array_join = ArrayJoin(*columns)
return self

def join(self, right, onclause=None, isouter=False, full=False, type=None, strictness=None, distribution=None):
    """Return a ``Join`` with this select on the left and *right* on the right.

    Every keyword argument is forwarded to ``Join`` unchanged.
    """
    join_options = dict(
        onclause=onclause,
        type=type,
        isouter=isouter,
        full=full,
        strictness=strictness,
        distribution=distribution,
    )
    return Join(self, right, **join_options)


select = Select
join = Join
50 changes: 34 additions & 16 deletions tests/orm/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,29 @@ def create_tables(self, num):
Column('y', types.Int32, primary_key=True),
) for i in range(num)]

def test_unsupported_expressoin(self):
t1, t2 = self.create_tables(2)

query = session.query(t1.c.x).join(t2, literal(True), any=True)
with self.assertRaises(exc.CompileError) as ex:
self.compile(query)

self.assertIn('Only tuple elements are supported', str(ex.exception))

def test_joins(self):
t1, t2 = self.create_tables(2)

query = session.query(t1.c.x, t2.c.x) \
.join(t2, tuple_(t1.c.x, t1.c.y), any=True)
.join(
t2,
t1.c.x == t1.c.y,
type='inner',
strictness='any')

self.assertEqual(
self.compile(query),
"SELECT x AS t0_x, x AS t1_x FROM t0 "
"ANY INNER JOIN t1 USING x, y"
"ANY INNER JOIN t1 ON x = y"
)

query = session.query(t1.c.x, t2.c.x) \
.join(t2, tuple_(t1.c.x, t1.c.y), all=True)
.join(
t2,
tuple_(t1.c.x, t1.c.y),
type='inner',
strictness='all'
)

self.assertEqual(
self.compile(query),
Expand All @@ -148,7 +148,11 @@ def test_joins(self):
)

query = session.query(t1.c.x, t2.c.x) \
.join(t2, tuple_(t1.c.x, t1.c.y), all=True, global_=True)
.join(t2,
tuple_(t1.c.x, t1.c.y),
type='inner',
strictness='all',
distribution='global')

self.assertEqual(
self.compile(query),
Expand All @@ -157,7 +161,12 @@ def test_joins(self):
)

query = session.query(t1.c.x, t2.c.x) \
.outerjoin(t2, tuple_(t1.c.x, t1.c.y), all=True, global_=True)
.outerjoin(t2,
tuple_(t1.c.x, t1.c.y),
type='left outer',
strictness='all',
distribution='global'
)

self.assertEqual(
self.compile(query),
Expand All @@ -166,7 +175,13 @@ def test_joins(self):
)

query = session.query(t1.c.x, t2.c.x) \
.outerjoin(t2, tuple_(t1.c.x, t1.c.y), all=True, global_=True)
.outerjoin(
t2,
tuple_(t1.c.x, t1.c.y),
type='LEFT OUTER',
strictness='ALL',
distribution='GLOBAL'
)

self.assertEqual(
self.compile(query),
Expand All @@ -175,7 +190,10 @@ def test_joins(self):
)

query = session.query(t1.c.x, t2.c.x) \
.outerjoin(t2, tuple_(t1.c.x, t1.c.y), all=True, full=True)
.outerjoin(t2,
tuple_(t1.c.x, t1.c.y),
strictness='ALL',
type='FULL OUTER')

self.assertEqual(
self.compile(query),
Expand Down
Loading

0 comments on commit b89f9ec

Please sign in to comment.