Skip to content

Commit

Permalink
Merge pull request #33225 from strk/dbmanager-reconnect-button
Browse files Browse the repository at this point in the history
DBManager PostgreSQL backend using core APIs instead of psycopg2
  • Loading branch information
strk committed Jan 20, 2020
2 parents 875c03e + 5b4e581 commit d39b6ac
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 81 deletions.
236 changes: 166 additions & 70 deletions python/plugins/db_manager/db_plugins/postgis/connector.py
Expand Up @@ -26,8 +26,20 @@

from functools import cmp_to_key

from qgis.PyQt.QtCore import QRegExp, QFile, QCoreApplication
from qgis.core import Qgis, QgsCredentials, QgsDataSourceUri, QgsCoordinateReferenceSystem
from qgis.PyQt.QtCore import (
QRegExp,
QFile,
QCoreApplication,
QVariant
)
from qgis.core import (
Qgis,
QgsCredentials,
QgsVectorLayer,
QgsDataSourceUri,
QgsProviderRegistry,
QgsProviderConnectionException,
)

from ..connector import DBConnector
from ..plugin import ConnectionError, DbError, Table
Expand All @@ -44,16 +56,140 @@ def classFactory():
return PostGisDBConnector


class CursorAdapter():

def _debug(self, msg):
pass
#print("XXX CursorAdapter[" + hex(id(self)) + "]: " + msg)

def __init__(self, connection, sql=None):
self._debug("Created with sql: " + str(sql))
self.connection = connection
self.sql = sql
self.result = None
self.cursor = 0
self.closed = False
if (self.sql != None):
self._execute()

def _toStrResultSet(self, res):
newres = []
for rec in res:
newrec = []
for col in rec:
if type(col) == type(QVariant(None)):
if (str(col) == 'NULL'):
col = None
else:
col = str(col) # force to string
newrec.append(col)
newres.append(newrec)
return newres

def _execute(self, sql=None):
if self.sql == sql and self.result != None:
return
if (sql != None):
self.sql = sql
if (self.sql == None):
return
self._debug("execute called with sql " + self.sql)
try:
self.result = self._toStrResultSet(self.connection.executeSql(self.sql))
except QgsProviderConnectionException as e:
raise DbError(e, self.sql)
self._debug("execute returned " + str(len(self.result)) + " rows")
self.cursor = 0

self._description = None # reset description

@property
def description(self):

if self._description is None:

uri = QgsDataSourceUri(self.connection.uri())

# TODO: make this part provider-agnostic
uri.setTable('(SELECT row_number() OVER () AS __rid__, * FROM (' + self.sql + ') as foo)')
uri.setKeyColumn('__rid__')
# TODO: fetch provider name from connection (QgsAbstractConnectionProvider)
# TODO: re-use the VectorLayer for fetching rows in batch mode
vl = QgsVectorLayer(uri.uri(False), 'dbmanager_cursor', 'postgres')

fields = vl.fields()
self._description = []
for i in range(1, len(fields)): # skip first field (__rid__)
f = fields[i]
self._description.append([
f.name(), # name
f.type(), # type_code
f.length(), # display_size
f.length(), # internal_size
f.precision(), # precision
None, # scale
True # null_ok
])
self._debug("get_description returned " + str(len(self._description)) + " cols")

return self._description

def fetchone(self):
self._execute()
if len(self.result) - self.cursor:
res = self.result[self.cursor]
++self.cursor
return res
return None

def fetchmany(self, size):
self._execute()
if self.result is None:
self._debug("fetchmany: none result after _execute (self.sql is " + str(self.sql) + ", returning []")
return []
leftover = len(self.result) - self.cursor
self._debug("fetchmany: cursor: " + str(self.cursor) + " leftover: " + str(leftover) + " requested: " + str(size))
if leftover < 1:
return []
if size > leftover:
size = leftover
stop = self.cursor + size
res = self.result[self.cursor:stop]
self.cursor = stop
self._debug("fetchmany: new cursor: " + str(self.cursor) + " reslen: " + str(len(self.result)))
return res

def fetchall(self):
self._execute()
res = self.result[self.cursor:]
self.cursor = len(self.result)
return res

def scroll(self, pos, mode='relative'):
self._execute()
if pos < 0:
self._debug("scroll pos is negative: " + str(pos))
if mode == 'relative':
self.cursor = self.cursor + pos
elif mode == 'absolute':
self.cursor = pos

def close(self):
self.result = None
self.closed = True


class PostGisDBConnector(DBConnector):

def __init__(self, uri):
DBConnector.__init__(self, uri)
"""Creates a new PostgreSQL connector
self.host = uri.host() or os.environ.get('PGHOST')
self.port = uri.port() or os.environ.get('PGPORT')
:param uri: data source URI
:type uri: QgsDataSourceUri
"""
DBConnector.__init__(self, uri)

username = uri.username() or os.environ.get('PGUSER')
password = uri.password() or os.environ.get('PGPASSWORD')

# Do not get db and user names from the env if service is used
if not uri.service():
Expand All @@ -62,42 +198,13 @@ def __init__(self, uri):
self.dbname = uri.database() or os.environ.get('PGDATABASE') or username
uri.setDatabase(self.dbname)

expandedConnInfo = self._connectionInfo()
try:
self.connection = psycopg2.connect(expandedConnInfo)
except self.connection_error_types() as e:
# get credentials if cached or asking to the user no more than 3 times
err = str(e)
uri = self.uri()
conninfo = uri.connectionInfo(False)

for i in range(3):
(ok, username, password) = QgsCredentials.instance().get(conninfo, username, password, err)
if not ok:
raise ConnectionError(QCoreApplication.translate('db_manager', 'Could not connect to database as user {user}').format(user=username))

if username:
uri.setUsername(username)

if password:
uri.setPassword(password)

newExpandedConnInfo = uri.connectionInfo(True)
try:
self.connection = psycopg2.connect(newExpandedConnInfo)
QgsCredentials.instance().put(conninfo, username, password)
except self.connection_error_types() as e:
if i == 2:
raise ConnectionError(e)
err = str(e)
finally:
# clear certs for each time trying to connect
self._clearSslTempCertsIfAny(newExpandedConnInfo)
finally:
# clear certs of the first connection try
self._clearSslTempCertsIfAny(expandedConnInfo)

self.connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
#self.connName = connName
#self.user = uri.username() or os.environ.get('USER')
#self.passwd = uri.password()
self.host = uri.host()

md = QgsProviderRegistry.instance().providerMetadata('postgres')
self.core_connection = md.createConnection(uri.database())

c = self._execute(None, u"SELECT current_user,current_database()")
self.user, self.dbname = self._fetchone(c)
Expand Down Expand Up @@ -752,19 +859,15 @@ def renameTable(self, table, new_table):
if new_table == tablename:
return

c = self._get_cursor()

sql = u"ALTER TABLE %s RENAME TO %s" % (self.quoteId(table), self.quoteId(new_table))
self._execute(c, sql)
self._executeSql(sql)

# update geometry_columns if PostGIS is enabled
if self.has_geometry_columns and not self.is_geometry_columns_view:
schema_where = u" AND f_table_schema=%s " % self.quoteString(schema) if schema is not None else ""
sql = u"UPDATE geometry_columns SET f_table_name=%s WHERE f_table_name=%s %s" % (
self.quoteString(new_table), self.quoteString(tablename), schema_where)
self._execute(c, sql)

self._commit()
self._executeSql(sql)

def commentTable(self, schema, tablename, comment=None):
if comment is None:
Expand Down Expand Up @@ -1034,33 +1137,26 @@ def execution_error_types(self):
def connection_error_types(self):
return psycopg2.InterfaceError, psycopg2.OperationalError

# moved into the parent class: DbConnector._execute()
# def _execute(self, cursor, sql):
# pass

# moved into the parent class: DbConnector._execute_and_commit()
# def _execute_and_commit(self, sql):
# pass

# moved into the parent class: DbConnector._get_cursor()
# def _get_cursor(self, name=None):
# pass
def _execute(self, cursor, sql):
if cursor != None:
cursor._execute(sql)
return cursor
return CursorAdapter(self.core_connection, sql)

# moved into the parent class: DbConnector._fetchall()
# def _fetchall(self, c):
# pass
def _executeSql(self, sql):
return self.core_connection.executeSql(sql)

# moved into the parent class: DbConnector._fetchone()
# def _fetchone(self, c):
# pass
def _get_cursor(self, name=None):
#if name is not None:
# print("XXX _get_cursor called with a Name: " + name)
return CursorAdapter(self.core_connection, name)

# moved into the parent class: DbConnector._commit()
# def _commit(self):
# pass
def _commit(self):
pass

# moved into the parent class: DbConnector._rollback()
# def _rollback(self):
# pass
def _rollback(self):
pass

# moved into the parent class: DbConnector._get_cursor_columns()
# def _get_cursor_columns(self, c):
Expand Down
3 changes: 3 additions & 0 deletions python/plugins/db_manager/db_plugins/postgis/plugin.py
Expand Up @@ -184,6 +184,9 @@ def hasLowercaseFieldNamesOption(self):
def supportsComment(self):
return True

def executeSql(self, sql):
return self.connector._executeSql(sql)


class PGSchema(Schema):

Expand Down
Expand Up @@ -33,21 +33,20 @@

current_path = os.path.dirname(__file__)


# The load function is called when the "db" database or either one of its
# children db objects (table o schema) is selected by the user.
# @param db is the selected database
# @param mainwindow is the DBManager mainwindow


def load(db, mainwindow):
# check whether the selected database supports topology
# (search for topology.topology)
sql = u"""SELECT count(*)
FROM pg_class AS cls JOIN pg_namespace AS nsp ON nsp.oid = cls.relnamespace
WHERE cls.relname = 'topology' AND nsp.nspname = 'topology'"""
c = db.connector._get_cursor()
db.connector._execute(c, sql)
res = db.connector._fetchone(c)
if res is None or int(res[0]) <= 0:
res = db.executeSql(sql)
if res is None or len(res) < 1 or int(res[0][0]) <= 0:
return

# add the action to the DBManager menu
Expand Down Expand Up @@ -78,10 +77,8 @@ def run(item, action, mainwindow):

if item.schema() is not None:
sql = u"SELECT srid FROM topology.topology WHERE name = %s" % quoteStr(item.schema().name)
c = db.connector._get_cursor()
db.connector._execute(c, sql)
res = db.connector._fetchone(c)
isTopoSchema = res is not None
res = db.executeSql(sql)
isTopoSchema = len(res) > 0

if not isTopoSchema:
mainwindow.infoBar.pushMessage("Invalid topology",
Expand All @@ -90,11 +87,11 @@ def run(item, action, mainwindow):
mainwindow.iface.messageTimeout())
return False

if (res[0] < 0):
if (res[0][0] < 0):
mainwindow.infoBar.pushMessage("WARNING", u'Topology "{0}" is registered as having a srid of {1} in topology.topology, we will assume 0 (for unknown)'.format(item.schema().name, res[0]), Qgis.Warning, mainwindow.iface.messageTimeout())
toposrid = '0'
else:
toposrid = str(res[0])
toposrid = str(res[0][0])

# load layers into the current project
toponame = item.schema().name
Expand Down
4 changes: 4 additions & 0 deletions src/providers/postgres/qgspostgresproviderconnection.cpp
Expand Up @@ -263,6 +263,10 @@ QList<QVariantList> QgsPostgresProviderConnection::executeSqlPrivate( const QStr
{
vType = QVariant::Bool;
}
else
{
QgsDebugMsg( QStringLiteral( "Unhandled PostgreSQL type %1" ).arg( typName ) );
}
}
typeMap[ rowIdx ] = vType;
}
Expand Down

0 comments on commit d39b6ac

Please sign in to comment.