Skip to content
Permalink
Browse files

Merge pull request #4685 from strk/dbmanager-test

Add test for DBManager's PostGIS connector and plugin
  • Loading branch information
strk committed Jun 7, 2017
2 parents 9397c4f + 11ace44 commit d914a195d35520718d95dc1f232c4c62c2fbe14c
@@ -78,7 +78,7 @@ def _execute(self, cursor, sql):
if cursor is None:
cursor = self._get_cursor()
try:
cursor.execute(unicode(sql))
cursor.execute(str(sql))

except self.connection_error_types() as e:
raise ConnectionError(e)
@@ -98,7 +98,7 @@ def _execute_and_commit(self, sql):
def _get_cursor(self, name=None):
try:
if name is not None:
name = unicode(name).encode('ascii', 'replace').replace('?', "_")
name = str(name).encode('ascii', 'replace').replace('?', "_")
self._last_cursor_named_id = 0 if not hasattr(self,
'_last_cursor_named_id') else self._last_cursor_named_id + 1
return self.connection.cursor("%s_%d" % (name, self._last_cursor_named_id))
@@ -181,37 +181,39 @@ def _get_cursor_columns(self, c):

@classmethod
def quoteId(self, identifier):
if hasattr(identifier, '__iter__'):
if hasattr(identifier, '__iter__') and not isinstance(identifier, str):
ids = list()
for i in identifier:
if i is None or i == "":
continue
ids.append(self.quoteId(i))
return u'.'.join(ids)

identifier = unicode(
identifier) if identifier is not None else unicode() # make sure it's python unicode string
identifier = str(
identifier) if identifier is not None else str() # make sure it's python unicode string
return u'"%s"' % identifier.replace('"', '""')

@classmethod
def quoteString(self, txt):
""" make the string safe - replace ' with '' """
if hasattr(txt, '__iter__'):
if hasattr(txt, '__iter__') and not isinstance(txt, str):
txts = list()
for i in txt:
if i is None:
continue
txts.append(self.quoteString(i))
return u'.'.join(txts)

txt = unicode(txt) if txt is not None else unicode() # make sure it's python unicode string
txt = str(txt) if txt is not None else str() # make sure it's python unicode string
return u"'%s'" % txt.replace("'", "''")

@classmethod
def getSchemaTableName(self, table):
if not hasattr(table, '__iter__'):
if not hasattr(table, '__iter__') and not isinstance(table, str):
return (None, table)
elif len(table) < 2:
if isinstance(table, str):
table = table.split('.')
if len(table) < 2:
return (None, table[0])
else:
return (table[0], table[1])
@@ -8,3 +8,10 @@ PLUGIN_INSTALL(db_manager db_plugins/postgis/icons ${ICON_FILES})

ADD_SUBDIRECTORY(plugins)

IF(ENABLE_TESTS)
INCLUDE(UsePythonTest)
IF (ENABLE_PGTEST)
ADD_PYTHON_TEST(dbmanager-postgis-connector connector_test.py)
ADD_PYTHON_TEST(dbmanager-postgis-plugin plugin_test.py)
ENDIF (ENABLE_PGTEST)
ENDIF(ENABLE_TESTS)
@@ -24,6 +24,7 @@

from qgis.PyQt.QtCore import QRegExp
from qgis.core import QgsCredentials, QgsDataSourceURI
from functools import cmp_to_key

from ..connector import DBConnector
from ..plugin import ConnectionError, DbError, Table
@@ -60,7 +61,7 @@ def __init__(self, uri):

expandedConnInfo = self._connectionInfo()
try:
self.connection = psycopg2.connect(expandedConnInfo.encode('utf-8'))
self.connection = psycopg2.connect(expandedConnInfo)
except self.connection_error_types() as e:
err = unicode(e)
uri = self.uri()
@@ -79,7 +80,7 @@ def __init__(self, uri):

newExpandedConnInfo = uri.connectionInfo(True)
try:
self.connection = psycopg2.connect(newExpandedConnInfo.encode('utf-8'))
self.connection = psycopg2.connect(newExpandedConnInfo)
QgsCredentials.instance().put(conninfo, username, password)
except self.connection_error_types() as e:
if i == 2:
@@ -135,7 +136,7 @@ def __init__(self, uri):
self._checkRasterColumnsTable()

def _connectionInfo(self):
return unicode(self.uri().connectionInfo(True))
return str(self.uri().connectionInfo(True))

def _checkSpatial(self):
""" check whether postgis_version is present in catalog """
@@ -331,7 +332,7 @@ def getTables(self, schema=None, add_sys_tables=False):
items.append(item)
self._close_cursor(c)

return sorted(items, cmp=lambda x, y: cmp((x[2], x[1]), (y[2], y[1])))
return sorted(items, key=cmp_to_key(lambda x, y: (x[1] > y[1]) - (x[1] < y[1])))

def getVectorTables(self, schema=None):
""" get list of table with a geometry column
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

"""
***************************************************************************
connector_test.py
---------------------
Date : May 2017
Copyright : (C) 2017, Sandro Santilli
Email : strk at kbt dot io
***************************************************************************
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
***************************************************************************
"""

__author__ = 'Sandro Santilli'
__date__ = 'May 2017'
__copyright__ = '(C) 2017, Sandro Santilli'
# This will get replaced with a git SHA1 when you do a git archive
__revision__ = '$Format:%H$'

import os
import qgis
from qgis.testing import start_app, unittest
from qgis.core import QgsDataSourceURI
from qgis.utils import iface

start_app()

from db_manager.db_plugins.postgis.connector import PostGisDBConnector


class TestDBManagerPostgisConnector(unittest.TestCase):

#def setUpClass():

def _getUser(self, connector):
r = connector._execute(None, "SELECT USER")
val = connector._fetchone(r)[0]
connector._close_cursor(r)
return val

def _getDatabase(self, connector):
r = connector._execute(None, "SELECT current_database()")
val = connector._fetchone(r)[0]
connector._close_cursor(r)
return val

# See https://issues.qgis.org/issues/16625
# and https://issues.qgis.org/issues/10600
def test_dbnameLessURI(self):
c = PostGisDBConnector(QgsDataSourceURI())
self.assertIsInstance(c, PostGisDBConnector)
uri = c.uri()

# No username was passed, so we expect it to be taken
# from PGUSER or USER environment variables
expected_user = os.environ.get('PGUSER') or os.environ.get('USER')
actual_user = self._getUser(c)
self.assertEqual(actual_user, expected_user)

# No database was passed, so we expect it to be taken
# from PGDATABASE or expected user
expected_db = os.environ.get('PGDATABASE') or expected_user
actual_db = self._getDatabase(c)
self.assertEqual(actual_db, expected_db)

# TODO: add service-only test (requires a ~/.pg_service.conf file)


if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-

"""
***************************************************************************
plugin_test.py
---------------------
Date : May 2017
Copyright : (C) 2017, Sandro Santilli
Email : strk at kbt dot io
***************************************************************************
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
***************************************************************************
"""

__author__ = 'Sandro Santilli'
__date__ = 'May 2017'
__copyright__ = '(C) 2017, Sandro Santilli'
# This will get replaced with a git SHA1 when you do a git archive
__revision__ = '$Format:%H$'

import os
import re
import qgis
from qgis.testing import start_app, unittest
from qgis.core import QgsDataSourceURI
from qgis.utils import iface
from qgis.PyQt.QtCore import QObject

start_app()

from db_manager.db_plugins.postgis.plugin import PostGisDBPlugin, PGRasterTable
from db_manager.db_plugins.postgis.plugin import PGDatabase
from db_manager.db_plugins.plugin import Table
from db_manager.db_plugins.postgis.connector import PostGisDBConnector


class TestDBManagerPostgisPlugin(unittest.TestCase):

@classmethod
def setUpClass(self):
self.old_pgdatabase_env = os.environ.get('PGDATABASE')
self.testdb = os.environ.get('QGIS_PGTEST_DB') or 'qgis_test'
os.environ['PGDATABASE'] = self.testdb

# Create temporary service file
self.old_pgservicefile_env = os.environ.get('PGSERVICEFILE')
self.tmpservicefile = '/tmp/qgis-test-{}-pg_service.conf'.format(os.getpid())
os.environ['PGSERVICEFILE'] = self.tmpservicefile

f = open(self.tmpservicefile, "w")
f.write("[dbmanager]\ndbname={}\n".format(self.testdb))
# TODO: add more things if PGSERVICEFILE was already set ?
f.close()

@classmethod
def tearDownClass(self):
# Restore previous env variables if needed
if self.old_pgdatabase_env:
os.environ['PGDATABASE'] = self.old_pgdatabase_env
if self.old_pgservicefile_env:
os.environ['PGSERVICEFILE'] = self.old_pgservicefile_env
# Remove temporary service file
os.unlink(self.tmpservicefile)

# See https://issues.qgis.org/issues/16625

def test_rasterTableGdalURI(self):

def check_rasterTableGdalURI(expected_dbname):
tables = database.tables()
raster_tables_count = 0
for tab in tables:
if tab.type == Table.RasterType:
raster_tables_count += 1
gdalUri = tab.gdalUri()
m = re.search(' dbname=([^ ]*) ', gdalUri)
self.assertTrue(m)
actual_dbname = m.group(1)
self.assertEqual(actual_dbname, expected_dbname)
#print(tab.type)
#print(tab.quotedName())
#print(tab)

# We need to make sure a database is created with at
# least one raster table !
self.assertEqual(raster_tables_count, 1)

obj = QObject() # needs to be kept alive

# Test for empty URI
# See https://issues.qgis.org/issues/16625
# and https://issues.qgis.org/issues/10600

expected_dbname = self.testdb
os.environ['PGDATABASE'] = expected_dbname

database = PGDatabase(obj, QgsDataSourceURI())
self.assertIsInstance(database, PGDatabase)

uri = database.uri()
self.assertEqual(uri.host(), '')
self.assertEqual(uri.username(), '')
self.assertEqual(uri.database(), expected_dbname)
self.assertEqual(uri.service(), '')

check_rasterTableGdalURI(expected_dbname)

# Test for service-only URI
# See https://issues.qgis.org/issues/16626

os.environ['PGDATABASE'] = 'fake'
database = PGDatabase(obj, QgsDataSourceURI('service=dbmanager'))
self.assertIsInstance(database, PGDatabase)

uri = database.uri()
self.assertEqual(uri.host(), '')
self.assertEqual(uri.username(), '')
self.assertEqual(uri.database(), '')
self.assertEqual(uri.service(), 'dbmanager')

check_rasterTableGdalURI(expected_dbname)


if __name__ == '__main__':
unittest.main()
@@ -4,6 +4,7 @@ SCRIPTS="
tests/testdata/provider/testdata_pg.sql
tests/testdata/provider/testdata_pg_reltests.sql
tests/testdata/provider/testdata_pg_vectorjoin.sql
tests/testdata/provider/testdata_pg_raster.sql
"

createdb qgis_test || exit 1
@@ -0,0 +1,15 @@
-- Table: qgis_test.raster1

CREATE TABLE qgis_test."Raster1"
(
pk serial NOT NULL,
name character varying(255),
"Rast" raster
);

INSERT INTO qgis_test."Raster1" (name, "Rast") SELECT
'simple one',
ST_AddBand(
ST_MakeEmptyRaster(16, 32, 7, -5, 0.2, -0.7, 0, 0, 0),
1, '8BUI', 0.0, NULL
);

0 comments on commit d914a19

Please sign in to comment.
You can’t perform that action at this time.