Skip to content

Commit 5c00383

Browse files
committed
Overwriting Connector class is enough to operate with
datasette-connectors
1 parent 52416a7 commit 5c00383

File tree

7 files changed

+382
-273
lines changed

7 files changed

+382
-273
lines changed

datasette_connectors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .connectors import Connector, OperationalError

datasette_connectors/connectors.py

Lines changed: 209 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import pkg_resources
22
import functools
3+
import re
4+
import sqlite3
5+
6+
from .row import Row
7+
38

49
db_connectors = {}
510

@@ -26,76 +31,216 @@ def load():
2631
def add_connector(name, connector):
2732
db_connectors[name] = connector
2833

29-
@staticmethod
30-
@for_each_connector
31-
def table_names(connector, path):
32-
return connector.table_names(path)
33-
34-
@staticmethod
35-
@for_each_connector
36-
def hidden_table_names(connector, path):
37-
return connector.hidden_table_names(path)
34+
class DatabaseNotSupported(Exception):
35+
pass
3836

3937
@staticmethod
40-
@for_each_connector
41-
def view_names(connector, path):
42-
return connector.view_names(path)
38+
def connect(path):
39+
for connector in db_connectors.values():
40+
try:
41+
return connector.connect(path)
42+
except:
43+
pass
44+
else:
45+
raise ConnectorList.DatabaseNotSupported
46+
47+
48+
class Connection:
49+
def __init__(self, path, connector):
50+
self.path = path
51+
self.connector = connector
52+
53+
def execute(self, *args, **kwargs):
54+
cursor = Cursor(self)
55+
cursor.execute(*args, **kwargs)
56+
return cursor
57+
58+
def cursor(self):
59+
return Cursor(self)
60+
61+
def set_progress_handler(self, handler, n):
62+
pass
63+
64+
65+
class OperationalError(Exception):
66+
pass
67+
68+
69+
class Cursor:
70+
class QueryNotSupported(Exception):
71+
pass
72+
73+
def __init__(self, conn):
74+
self.conn = conn
75+
self.connector = conn.connector
76+
self.rows = []
77+
self.description = ()
78+
79+
def execute(
80+
self,
81+
sql,
82+
params=None,
83+
truncate=False,
84+
custom_time_limit=None,
85+
page_size=None,
86+
log_sql_errors=True,
87+
):
88+
if params is None:
89+
params = {}
90+
results = []
91+
truncated = False
92+
description = ()
93+
94+
# Normalize sql
95+
sql = sql.strip()
96+
sql = ' '.join(sql.split())
97+
98+
if sql == "select name from sqlite_master where type='table'" or \
99+
sql == "select name from sqlite_master where type=\"table\"":
100+
results = [{'name': name} for name in self.connector.table_names()]
101+
elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'":
102+
results = [{'name': name} for name in self.connector.hidden_table_names()]
103+
elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"':
104+
if self.connector.detect_spatialite():
105+
results = [{'1': '1'}]
106+
elif sql == "select name from sqlite_master where type='view'":
107+
results = [{'name': name} for name in self.connector.view_names()]
108+
elif sql.startswith("select count(*) from ["):
109+
match = re.search(r'select count\(\*\) from \[(.*)\]', sql)
110+
results = [{'count(*)': self.connector.table_count(match.group(1))}]
111+
elif sql.startswith("select count(*) from "):
112+
match = re.search(r'select count\(\*\) from (.*)', sql)
113+
results = [{'count(*)': self.connector.table_count(match.group(1))}]
114+
elif sql.startswith("PRAGMA table_info("):
115+
match = re.search(r'PRAGMA table_info\((.*)\)', sql)
116+
results = self.connector.table_info(match.group(1))
117+
elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="):
118+
match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql)
119+
if self.connector.detect_fts(match.group(1)):
120+
results = [{'name': match.group(1)}]
121+
elif sql.startswith("PRAGMA foreign_key_list(["):
122+
match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql)
123+
results = self.connector.foreign_keys(match.group(1))
124+
elif sql == "select 1 from sqlite_master where type='table' and name=?":
125+
if self.connector.table_exists(params[0]):
126+
results = [{'1': '1'}]
127+
elif sql == "select sql from sqlite_master where name = :n and type=:t":
128+
results = [{'sql': self.connector.table_definition(params['t'], params['n'])}]
129+
elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null":
130+
results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])]
131+
else:
132+
try:
133+
results, truncated, description = \
134+
self.connector.execute(
135+
sql,
136+
params=params,
137+
truncate=truncate,
138+
custom_time_limit=custom_time_limit,
139+
page_size=page_size,
140+
log_sql_errors=log_sql_errors,
141+
)
142+
except OperationalError as ex:
143+
raise sqlite3.OperationalError(*ex.args)
43144

44-
@staticmethod
45-
@for_each_connector
46-
def table_columns(connector, path, table):
47-
return connector.table_columns(path, table)
145+
self.rows = [Row(result) for result in results]
146+
self.description = description
48147

49-
@staticmethod
50-
@for_each_connector
51-
def primary_keys(connector, path, table):
52-
return connector.primary_keys(path, table)
148+
def fetchall(self):
149+
return self.rows
53150

54-
@staticmethod
55-
@for_each_connector
56-
def fts_table(connector, path, table):
57-
return connector.fts_table(path, table)
151+
def fetchmany(self, max):
152+
return self.rows[:max]
58153

59-
@staticmethod
60-
@for_each_connector
61-
def get_all_foreign_keys(connector, path):
62-
return connector.get_all_foreign_keys(path)
63-
64-
@staticmethod
65-
@for_each_connector
66-
def table_counts(connector, path, *args, **kwargs):
67-
return connector.table_counts(path, *args, **kwargs)
154+
def __getitem__(self, index):
155+
return self.rows[index]
68156

69157

70158
class Connector:
71-
@staticmethod
72-
def table_names(path):
73-
return []
74-
75-
@staticmethod
76-
def hidden_table_names(path):
77-
return []
78-
79-
@staticmethod
80-
def view_names(path):
81-
return []
82-
83-
@staticmethod
84-
def table_columns(path, table):
85-
return []
86-
87-
@staticmethod
88-
def primary_keys(path, table):
89-
return []
90-
91-
@staticmethod
92-
def fts_table(path, table):
93-
return None
94-
95-
@staticmethod
96-
def get_all_foreign_keys(path):
97-
return {}
98-
99-
@staticmethod
100-
def table_counts(path, *args, **kwargs):
101-
return {}
159+
connector_type = None
160+
connection_class = Connection
161+
162+
def connect(self, path):
163+
return self.connection_class(path, self)
164+
165+
def table_names(self):
166+
"""
167+
Return a list of table names
168+
"""
169+
raise NotImplementedError
170+
171+
def hidden_table_names(self):
172+
raise NotImplementedError
173+
174+
def detect_spatialite(self):
175+
"""
176+
Return boolean indicating if geometry_columns exists
177+
"""
178+
raise NotImplementedError
179+
180+
def view_names(self):
181+
"""
182+
Return a list of view names
183+
"""
184+
raise NotImplementedError
185+
186+
def table_count(self, table_name):
187+
"""
188+
Return an integer with the rows count of the table
189+
"""
190+
raise NotImplementedError
191+
192+
def table_info(self, table_name):
193+
"""
194+
Return a list of dictionaries with columns description, with format:
195+
[
196+
{
197+
'idx': 0,
198+
'name': 'column1',
199+
'primary_key': False,
200+
},
201+
...
202+
]
203+
"""
204+
raise NotImplementedError
205+
206+
def detect_fts(self, table_name):
207+
"""
208+
Return boolean indicating if table has a corresponding FTS virtual table
209+
"""
210+
raise NotImplementedError
211+
212+
def foreign_keys(self, table_name):
213+
"""
214+
Return a list of dictionaries with foreign keys description
215+
id, seq, table_name, from_, to_, on_update, on_delete, match
216+
"""
217+
raise NotImplementedError
218+
219+
def table_exists(self, table_name):
220+
"""
221+
Return boolean indicating if table exists in the database
222+
"""
223+
raise NotImplementedError
224+
225+
def table_definition(self, table_type, table_name):
226+
"""
227+
Return string with a 'CREATE TABLE' sql definition
228+
"""
229+
raise NotImplementedError
230+
231+
def indices_definition(self, table_name):
232+
"""
233+
Return a list of strings with 'CREATE INDEX' sql definitions
234+
"""
235+
raise NotImplementedError
236+
237+
def execute(
238+
self,
239+
sql,
240+
params=None,
241+
truncate=False,
242+
custom_time_limit=None,
243+
page_size=None,
244+
log_sql_errors=True,
245+
):
246+
raise NotImplementedError

datasette_connectors/monkey.py

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import asyncio
12
import threading
23
import sqlite3
4+
35
import datasette.views.base
6+
from datasette.tracer import trace
47
from datasette.database import Database
8+
from datasette.database import Results
59

610
from .connectors import ConnectorList
711

@@ -13,36 +17,6 @@ def patch_datasette():
1317
Monkey patching for original Datasette
1418
"""
1519

16-
async def table_names(self):
17-
try:
18-
return await self.original_table_names()
19-
except sqlite3.DatabaseError:
20-
return ConnectorList.table_names(self.path)
21-
22-
Database.original_table_names = Database.table_names
23-
Database.table_names = table_names
24-
25-
26-
async def hidden_table_names(self):
27-
try:
28-
return await self.original_hidden_table_names()
29-
except sqlite3.DatabaseError:
30-
return ConnectorList.hidden_table_names(self.path)
31-
32-
Database.original_hidden_table_names = Database.hidden_table_names
33-
Database.hidden_table_names = hidden_table_names
34-
35-
36-
async def view_names(self):
37-
try:
38-
return await self.original_view_names()
39-
except sqlite3.DatabaseError:
40-
return ConnectorList.view_names(self.path)
41-
42-
Database.original_view_names = Database.view_names
43-
Database.view_names = view_names
44-
45-
4620
async def table_columns(self, table):
4721
try:
4822
return await self.original_table_columns(table)
@@ -73,21 +47,33 @@ async def fts_table(self, table):
7347
Database.fts_table = fts_table
7448

7549

76-
async def get_all_foreign_keys(self):
50+
def connect(self, write=False):
7751
try:
78-
return await self.original_get_all_foreign_keys()
52+
# Check if it's a sqlite database
53+
conn = self.original_connect(write=write)
54+
conn.execute("select name from sqlite_master where type='table'")
55+
return conn
7956
except sqlite3.DatabaseError:
80-
return ConnectorList.get_all_foreign_keys(self.path)
57+
conn = ConnectorList.connect(self.path)
58+
return conn
59+
60+
Database.original_connect = Database.connect
61+
Database.connect = connect
8162

82-
Database.original_get_all_foreign_keys = Database.get_all_foreign_keys
83-
Database.get_all_foreign_keys = get_all_foreign_keys
8463

64+
async def execute_fn(self, fn):
65+
def in_thread():
66+
conn = getattr(connections, self.name, None)
67+
if not conn:
68+
conn = self.connect()
69+
if isinstance(conn, sqlite3.Connection):
70+
self.ds._prepare_connection(conn, self.name)
71+
setattr(connections, self.name, conn)
72+
return fn(conn)
8573

86-
async def table_counts(self, *args, **kwargs):
87-
counts = await self.original_table_counts(**kwargs)
88-
# If all tables has None as counts, an error had ocurred
89-
if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0:
90-
return ConnectorList.table_counts(self.path, *args, **kwargs)
74+
return await asyncio.get_event_loop().run_in_executor(
75+
self.ds.executor, in_thread
76+
)
9177

92-
Database.original_table_counts = Database.table_counts
93-
Database.table_counts = table_counts
78+
Database.original_execute_fn = Database.execute_fn
79+
Database.execute_fn = execute_fn

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def get_long_description():
2222
url='https://github.com/pytables/datasette-connectors',
2323
license='Apache License, Version 2.0',
2424
packages=['datasette_connectors'],
25-
install_requires=['datasette==0.48'],
25+
install_requires=[
26+
'datasette==0.48',
27+
],
2628
tests_require=[
2729
'pytest',
2830
'aiohttp',

0 commit comments

Comments
 (0)