-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathmock_pg.py
108 lines (82 loc) · 2.38 KB
/
mock_pg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Mock replacement for the pg API module.

Importing this module registers it in ``sys.modules`` under the name
``pg``, so any later ``import pg`` receives this mock instead of the real
module.
"""
import sys

sys.modules.update(pg=sys.modules[__name__])
class Error(Exception):
    """Base class of the exceptions defined by this mock pg module."""


class DatabaseError(Error):
    """Base class of the database-related exceptions of this mock."""


class InternalError(DatabaseError):
    """Raised when a connection or wrapper is used while not usable."""


class ProgrammingError(DatabaseError):
    """Raised for query strings the mock connection does not recognize."""
def connect(*args, **kwargs):
    """Create and return a new mock PgConnection.

    All positional and keyword arguments are forwarded unchanged to the
    PgConnection constructor.
    """
    connection = PgConnection(*args, **kwargs)
    return connection
class PgConnection:
    """The underlying pg API connection class (mock).

    Instead of talking to a server, it records what it is asked to do:
    ``num_queries`` counts 'select ' queries and ``session`` collects
    transaction commands plus the arguments of 'set ' commands. ``status``
    and ``valid`` are True while the connection is usable.
    """

    def __init__(self, dbname=None, user=None):
        self.db, self.user = dbname, user
        self._forget_session()
        # The special database name 'error' simulates a failed connect.
        failed = dbname == 'error'
        self.status = self.valid = not failed
        if failed:
            raise InternalError

    def _forget_session(self):
        """Reset the query counter and the recorded session commands."""
        self.num_queries = 0
        self.session = []

    def close(self):
        """Invalidate the connection; raise InternalError if already invalid."""
        if not self.valid:
            raise InternalError
        self._forget_session()
        self.status = self.valid = False

    def reset(self):
        """Clear the recorded state and mark the connection usable again."""
        self._forget_session()
        self.status = self.valid = True

    def query(self, qstr):
        """Simulate running *qstr* on the connection.

        'begin', 'end', 'commit' and 'rollback' are appended to ``session``
        as-is, and 'set ...' appends its argument; both return None.
        'select ...' increments ``num_queries`` and returns the text after
        'select '. Any other string raises ProgrammingError, and an invalid
        connection raises InternalError.
        """
        if not self.valid:
            raise InternalError
        select_prefix = 'select '
        set_prefix = 'set '
        if qstr in ('begin', 'end', 'commit', 'rollback'):
            self.session.append(qstr)
        elif qstr.startswith(select_prefix):
            self.num_queries += 1
            return qstr[len(select_prefix):]
        elif qstr.startswith(set_prefix):
            self.session.append(qstr[len(set_prefix):])
        else:
            raise ProgrammingError
        return None
class DB:
    """Wrapper class for the pg API connection class.

    The underlying connection is kept in ``self.db``. It is ``None`` after
    ``close()`` or after a failed ``reopen()``. Attributes not defined on
    the wrapper are looked up on that connection.
    """

    def __init__(self, *args, **kw):
        self.db = connect(*args, **kw)
        self.dbname = self.db.db
        # Kept (name-mangled) so reopen() can reconnect with the same args.
        self.__args = args, kw

    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the underlying connection.

        Raises:
            AttributeError: if there is no open connection.
        """
        # Read ``db`` from the instance dict rather than via ``self.db``:
        # when ``db`` has not been assigned yet (e.g. an instance created by
        # ``DB.__new__``, ``copy`` or ``pickle``), ``self.db`` would re-enter
        # __getattr__ and recurse until RecursionError.
        db = self.__dict__.get('db')
        if not db:
            raise AttributeError(name)
        return getattr(db, name)

    def close(self):
        """Close the underlying connection and drop the reference to it.

        Raises:
            InternalError: if the wrapper is already closed.
        """
        if not self.db:
            raise InternalError
        self.db.close()
        self.db = None

    def reopen(self):
        """Close the current connection (if any) and connect again.

        Uses the arguments originally given to the constructor. If
        connecting fails, ``self.db`` is set to ``None`` and the exception
        is re-raised.
        """
        if self.db:
            self.close()
        try:
            self.db = connect(*self.__args[0], **self.__args[1])
        except Exception:
            self.db = None
            raise

    def query(self, qstr):
        """Run *qstr* on the underlying connection and return its result.

        Raises:
            InternalError: if the wrapper is closed.
        """
        if not self.db:
            raise InternalError
        return self.db.query(qstr)

    def get_tables(self):
        """Return the fixed value ``'test'``.

        Raises:
            InternalError: if the wrapper is closed.
        """
        if not self.db:
            raise InternalError
        return 'test'