Skip to content

Commit

Permalink
Cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
Greger Stolt Nilsen committed Jun 19, 2023
1 parent 007cfec commit b5c811c
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 16 deletions.
3 changes: 1 addition & 2 deletions clickhouse_sqlalchemy/drivers/compilers/typecompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def visit_map(self, type_, **kw):
self.process(key_type, **kw),
self.process(value_type, **kw)
)

def visit_point(self, type_, **kw):
return 'Point'

Expand All @@ -143,4 +143,3 @@ def visit_polygon(self, type_, *kw):

def visit_multipolygon(self, type_, *kw):
return 'MultiPolygon'

6 changes: 6 additions & 0 deletions clickhouse_sqlalchemy/drivers/http/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,29 @@ def nullable_converter(subtype_str, x):
def nothing_converter(x):
return None


POINT_RE = re.compile(r'(-?\d*\.?\d+)')
RING_RE = re.compile(r'(\(.*?\))')
POLYGON_RE = re.compile(r'(\[.*?\])')
MULTIPOLYGON_RE = re.compile(r'\[\[.*?\]\]')


def point_converter(x):
return tuple([float(f) for f in POINT_RE.findall(x[1:-1])])


def ring_converter(x):
return [point_converter(f) for f in RING_RE.findall(x[1:-1])]


def polygon_converter(x):
return [ring_converter(f) for f in POLYGON_RE.findall(x[1:-1])]


def multipolygon_converter(x):
return [polygon_converter(f) for f in MULTIPOLYGON_RE.findall(x[1:-1])]


converters = {
'Int8': int,
'UInt8': int,
Expand Down
1 change: 0 additions & 1 deletion clickhouse_sqlalchemy/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,3 @@
from .geo import Ring
from .geo import Polygon
from .geo import MultiPolygon

3 changes: 3 additions & 0 deletions clickhouse_sqlalchemy/types/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
class Point(types.UserDefinedType):
__visit_name__ = "point"


class Ring(types.UserDefinedType):
__visit_name__ = "ring"


class Polygon(types.UserDefinedType):
__visit_name__ = "polygon"


class MultiPolygon(types.UserDefinedType):
__visit_name__ = "multipolygon"
21 changes: 8 additions & 13 deletions tests/types/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tests.testcase import BaseTestCase
from tests.util import with_native_and_http_sessions


@with_native_and_http_sessions
class GeoPointTestCase(BaseTestCase):
table = Table(
Expand All @@ -18,7 +19,7 @@ def test_create_table(self):
self.compile(CreateTable(self.table)),
'CREATE TABLE test (p Point) ENGINE = Memory'
)

def test_select_insert(self):
a = (10.1, 12.3)

Expand All @@ -38,8 +39,6 @@ def test_select_where_point(self):
self.table.c.p == (10.1, 12.3)).scalar(), a)




@with_native_and_http_sessions
class GeoRingTestCase(BaseTestCase):
table = Table(
Expand All @@ -53,7 +52,7 @@ def test_create_table(self):
self.compile(CreateTable(self.table)),
'CREATE TABLE test (r Ring) ENGINE = Memory'
)

def test_select_insert(self):
a = [(0, 0), (10, 0), (10, 10), (0, 10)]

Expand All @@ -79,7 +78,8 @@ def test_create_table(self):
)

def test_select_insert(self):
a = [[(20, 20), (50, 20), (50, 50), (20, 50)], [(30, 30), (50, 50), (50, 30)]]
a = [[(20, 20), (50, 20), (50, 50), (20, 50)],
[(30, 30), (50, 50), (50, 30)]]

with self.create_table(self.table):
self.session.execute(self.table.insert(), [{'pg': a}])
Expand All @@ -88,8 +88,6 @@ def test_select_insert(self):
self.assertEqual(res, a)




@with_native_and_http_sessions
class GeoMultiPolygonTestCase(BaseTestCase):
table = Table(
Expand All @@ -104,16 +102,13 @@ def test_create_table(self):
'CREATE TABLE test (mpg MultiPolygon) ENGINE = Memory'
)


def test_select_insert(self):
a = [[[(0, 0), (10, 0), (10, 10), (0, 10)]], [[(20, 20), (50, 20), (50, 50), (20, 50)],[(30, 30), (50, 50), (50, 30)]]]
a = [[[(0, 0), (10, 0), (10, 10), (0, 10)]],
[[(20, 20), (50, 20), (50, 50), (20, 50)],
[(30, 30), (50, 50), (50, 30)]]]

with self.create_table(self.table):
self.session.execute(self.table.insert(), [{'mpg': a}])
qres = self.session.query(self.table.c.mpg)
res = qres.scalar()
self.assertEqual(res, a)




0 comments on commit b5c811c

Please sign in to comment.