1
1
import pkg_resources
2
2
import functools
3
+ import re
4
+ import sqlite3
5
+
6
+ from .row import Row
7
+
3
8
4
9
db_connectors = {}
5
10
@@ -26,76 +31,216 @@ def load():
26
31
def add_connector (name , connector ):
27
32
db_connectors [name ] = connector
28
33
29
- @staticmethod
30
- @for_each_connector
31
- def table_names (connector , path ):
32
- return connector .table_names (path )
33
-
34
- @staticmethod
35
- @for_each_connector
36
- def hidden_table_names (connector , path ):
37
- return connector .hidden_table_names (path )
34
+ class DatabaseNotSupported (Exception ):
35
+ pass
38
36
39
37
@staticmethod
40
- @for_each_connector
41
- def view_names (connector , path ):
42
- return connector .view_names (path )
38
+ def connect (path ):
39
+ for connector in db_connectors .values ():
40
+ try :
41
+ return connector .connect (path )
42
+ except :
43
+ pass
44
+ else :
45
+ raise ConnectorList .DatabaseNotSupported
46
+
47
+
48
+ class Connection :
49
+ def __init__ (self , path , connector ):
50
+ self .path = path
51
+ self .connector = connector
52
+
53
+ def execute (self , * args , ** kwargs ):
54
+ cursor = Cursor (self )
55
+ cursor .execute (* args , ** kwargs )
56
+ return cursor
57
+
58
+ def cursor (self ):
59
+ return Cursor (self )
60
+
61
+ def set_progress_handler (self , handler , n ):
62
+ pass
63
+
64
+
65
+ class OperationalError (Exception ):
66
+ pass
67
+
68
+
69
+ class Cursor :
70
+ class QueryNotSupported (Exception ):
71
+ pass
72
+
73
+ def __init__ (self , conn ):
74
+ self .conn = conn
75
+ self .connector = conn .connector
76
+ self .rows = []
77
+ self .description = ()
78
+
79
+ def execute (
80
+ self ,
81
+ sql ,
82
+ params = None ,
83
+ truncate = False ,
84
+ custom_time_limit = None ,
85
+ page_size = None ,
86
+ log_sql_errors = True ,
87
+ ):
88
+ if params is None :
89
+ params = {}
90
+ results = []
91
+ truncated = False
92
+ description = ()
93
+
94
+ # Normalize sql
95
+ sql = sql .strip ()
96
+ sql = ' ' .join (sql .split ())
97
+
98
+ if sql == "select name from sqlite_master where type='table'" or \
99
+ sql == "select name from sqlite_master where type=\" table\" " :
100
+ results = [{'name' : name } for name in self .connector .table_names ()]
101
+ elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'" :
102
+ results = [{'name' : name } for name in self .connector .hidden_table_names ()]
103
+ elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"' :
104
+ if self .connector .detect_spatialite ():
105
+ results = [{'1' : '1' }]
106
+ elif sql == "select name from sqlite_master where type='view'" :
107
+ results = [{'name' : name } for name in self .connector .view_names ()]
108
+ elif sql .startswith ("select count(*) from [" ):
109
+ match = re .search (r'select count\(\*\) from \[(.*)\]' , sql )
110
+ results = [{'count(*)' : self .connector .table_count (match .group (1 ))}]
111
+ elif sql .startswith ("select count(*) from " ):
112
+ match = re .search (r'select count\(\*\) from (.*)' , sql )
113
+ results = [{'count(*)' : self .connector .table_count (match .group (1 ))}]
114
+ elif sql .startswith ("PRAGMA table_info(" ):
115
+ match = re .search (r'PRAGMA table_info\((.*)\)' , sql )
116
+ results = self .connector .table_info (match .group (1 ))
117
+ elif sql .startswith ("select name from sqlite_master where rootpage = 0 and ( sql like \' %VIRTUAL TABLE%USING FTS%content=" ):
118
+ match = re .search (r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"' , sql )
119
+ if self .connector .detect_fts (match .group (1 )):
120
+ results = [{'name' : match .group (1 )}]
121
+ elif sql .startswith ("PRAGMA foreign_key_list([" ):
122
+ match = re .search (r'PRAGMA foreign_key_list\(\[(.*)\]\)' , sql )
123
+ results = self .connector .foreign_keys (match .group (1 ))
124
+ elif sql == "select 1 from sqlite_master where type='table' and name=?" :
125
+ if self .connector .table_exists (params [0 ]):
126
+ results = [{'1' : '1' }]
127
+ elif sql == "select sql from sqlite_master where name = :n and type=:t" :
128
+ results = [{'sql' : self .connector .table_definition (params ['t' ], params ['n' ])}]
129
+ elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null" :
130
+ results = [{'sql' : sql } for sql in self .connector .indices_definition (params ['n' ])]
131
+ else :
132
+ try :
133
+ results , truncated , description = \
134
+ self .connector .execute (
135
+ sql ,
136
+ params = params ,
137
+ truncate = truncate ,
138
+ custom_time_limit = custom_time_limit ,
139
+ page_size = page_size ,
140
+ log_sql_errors = log_sql_errors ,
141
+ )
142
+ except OperationalError as ex :
143
+ raise sqlite3 .OperationalError (* ex .args )
43
144
44
- @staticmethod
45
- @for_each_connector
46
- def table_columns (connector , path , table ):
47
- return connector .table_columns (path , table )
145
+ self .rows = [Row (result ) for result in results ]
146
+ self .description = description
48
147
49
- @staticmethod
50
- @for_each_connector
51
- def primary_keys (connector , path , table ):
52
- return connector .primary_keys (path , table )
148
+ def fetchall (self ):
149
+ return self .rows
53
150
54
- @staticmethod
55
- @for_each_connector
56
- def fts_table (connector , path , table ):
57
- return connector .fts_table (path , table )
151
+ def fetchmany (self , max ):
152
+ return self .rows [:max ]
58
153
59
- @staticmethod
60
- @for_each_connector
61
- def get_all_foreign_keys (connector , path ):
62
- return connector .get_all_foreign_keys (path )
63
-
64
- @staticmethod
65
- @for_each_connector
66
- def table_counts (connector , path , * args , ** kwargs ):
67
- return connector .table_counts (path , * args , ** kwargs )
154
+ def __getitem__ (self , index ):
155
+ return self .rows [index ]
68
156
69
157
70
158
class Connector :
71
- @staticmethod
72
- def table_names (path ):
73
- return []
74
-
75
- @staticmethod
76
- def hidden_table_names (path ):
77
- return []
78
-
79
- @staticmethod
80
- def view_names (path ):
81
- return []
82
-
83
- @staticmethod
84
- def table_columns (path , table ):
85
- return []
86
-
87
- @staticmethod
88
- def primary_keys (path , table ):
89
- return []
90
-
91
- @staticmethod
92
- def fts_table (path , table ):
93
- return None
94
-
95
- @staticmethod
96
- def get_all_foreign_keys (path ):
97
- return {}
98
-
99
- @staticmethod
100
- def table_counts (path , * args , ** kwargs ):
101
- return {}
159
+ connector_type = None
160
+ connection_class = Connection
161
+
162
+ def connect (self , path ):
163
+ return self .connection_class (path , self )
164
+
165
+ def table_names (self ):
166
+ """
167
+ Return a list of table names
168
+ """
169
+ raise NotImplementedError
170
+
171
+ def hidden_table_names (self ):
172
+ raise NotImplementedError
173
+
174
+ def detect_spatialite (self ):
175
+ """
176
+ Return boolean indicating if geometry_columns exists
177
+ """
178
+ raise NotImplementedError
179
+
180
+ def view_names (self ):
181
+ """
182
+ Return a list of view names
183
+ """
184
+ raise NotImplementedError
185
+
186
+ def table_count (self , table_name ):
187
+ """
188
+ Return an integer with the rows count of the table
189
+ """
190
+ raise NotImplementedError
191
+
192
+ def table_info (self , table_name ):
193
+ """
194
+ Return a list of dictionaries with columns description, with format:
195
+ [
196
+ {
197
+ 'idx': 0,
198
+ 'name': 'column1',
199
+ 'primary_key': False,
200
+ },
201
+ ...
202
+ ]
203
+ """
204
+ raise NotImplementedError
205
+
206
+ def detect_fts (self , table_name ):
207
+ """
208
+ Return boolean indicating if table has a corresponding FTS virtual table
209
+ """
210
+ raise NotImplementedError
211
+
212
+ def foreign_keys (self , table_name ):
213
+ """
214
+ Return a list of dictionaries with foreign keys description
215
+ id, seq, table_name, from_, to_, on_update, on_delete, match
216
+ """
217
+ raise NotImplementedError
218
+
219
+ def table_exists (self , table_name ):
220
+ """
221
+ Return boolean indicating if table exists in the database
222
+ """
223
+ raise NotImplementedError
224
+
225
+ def table_definition (self , table_type , table_name ):
226
+ """
227
+ Return string with a 'CREATE TABLE' sql definition
228
+ """
229
+ raise NotImplementedError
230
+
231
+ def indices_definition (self , table_name ):
232
+ """
233
+ Return a list of strings with 'CREATE INDEX' sql definitions
234
+ """
235
+ raise NotImplementedError
236
+
237
+ def execute (
238
+ self ,
239
+ sql ,
240
+ params = None ,
241
+ truncate = False ,
242
+ custom_time_limit = None ,
243
+ page_size = None ,
244
+ log_sql_errors = True ,
245
+ ):
246
+ raise NotImplementedError
0 commit comments