Create helper functions for common database access patterns #203

Merged
merged 12 commits
Showing with 399 additions and 291 deletions.
  1. +142 −0 tests/test_db.py
  2. +44 −0 warehouse/db.py
  3. +213 −291 warehouse/packaging/models.py
142 tests/test_db.py
@@ -0,0 +1,142 @@
+# Copyright 2013 Donald Stufft
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import, division, print_function
+from __future__ import unicode_literals
+
+import pretend
+import pytest
+
+from warehouse import db
+
+
+@pytest.mark.parametrize(
+ ("value", "default", "expected", "eargs", "ekwargs"),
+ [
+ (1, None, 1, [], {}),
+ (None, None, None, [], {}),
+ (None, 10, 10, [], {}),
+ (1, None, 1, ["a"], {}),
+ (None, None, None, ["a"], {}),
+ (None, 10, 10, ["a"], {}),
+ (1, None, 1, [], {"a": "b"}),
+ (None, None, None, [], {"a": "b"}),
+ (None, 10, 10, [], {"a": "b"}),
+ ],
+)
+def test_scalar(value, default, expected, eargs, ekwargs):
+ result = pretend.stub(scalar=pretend.call_recorder(lambda: value))
+ execute = pretend.call_recorder(lambda q, *a, **kw: result)
+ model = pretend.stub(
+ engine=pretend.stub(
+ connect=lambda: pretend.stub(
+ __enter__=lambda: pretend.stub(execute=execute),
+ __exit__=lambda *a, **k: None,
+ ),
+ ),
+ )
+
+ sql = db.scalar("SELECT * FROM thing", default=default)
+
+ assert sql(model, *eargs, **ekwargs) == expected
+ assert execute.calls == [
+ pretend.call("SELECT * FROM thing", *eargs, **ekwargs),
+ ]
+ assert result.scalar.calls == [pretend.call()]
+
+
+@pytest.mark.parametrize(
+ ("row_func", "value", "expected", "eargs", "ekwargs"),
+ [
+ (None, [{"a": "b"}], [{"a": "b"}], [], {}),
+ (lambda r: r["a"], [{"a": "b"}], ["b"], [], {}),
+ (None, [{"a": "b"}], [{"a": "b"}], ["a"], {}),
+ (lambda r: r["a"], [{"a": "b"}], ["b"], ["a"], {}),
+ (None, [{"a": "b"}], [{"a": "b"}], [], {"a": "b"}),
+ (lambda r: r["a"], [{"a": "b"}], ["b"], [], {"a": "b"}),
+ ],
+)
+def test_rows(row_func, value, expected, eargs, ekwargs):
+ execute = pretend.call_recorder(lambda q, *a, **kw: value)
+ model = pretend.stub(
+ engine=pretend.stub(
+ connect=lambda: pretend.stub(
+ __enter__=lambda: pretend.stub(execute=execute),
+ __exit__=lambda *a, **k: None,
+ ),
+ ),
+ )
+ kwargs = {"row_func": row_func} if row_func else {}
+
+ sql = db.rows("SELECT * FROM thing", **kwargs)
+
+ assert sql(model, *eargs, **ekwargs) == expected
+ assert execute.calls == [
+ pretend.call("SELECT * FROM thing", *eargs, **ekwargs),
+ ]
+
+
+@pytest.mark.parametrize(
+ ("key_func", "value_func", "value", "expected", "eargs", "ekwargs"),
+ [
+ (None, None, [("a", "b")], {"a": "b"}, [], {}),
+ (
+ lambda r: r["a"],
+ lambda r: r["b"],
+ [{"a": 1, "b": 2}],
+ {1: 2},
+ [],
+ {},
+ ),
+ (None, None, [("a", "b")], {"a": "b"}, ["z"], {}),
+ (
+ lambda r: r["a"],
+ lambda r: r["b"],
+ [{"a": 1, "b": 2}],
+ {1: 2},
+ ["z"],
+ {},
+ ),
+ (None, None, [("a", "b")], {"a": "b"}, [], {"z": "g"}),
+ (
+ lambda r: r["a"],
+ lambda r: r["b"],
+ [{"a": 1, "b": 2}],
+ {1: 2},
+ [],
+ {"z": "g"},
+ ),
+ ],
+)
+def test_mapping(key_func, value_func, value, expected, eargs, ekwargs):
+ execute = pretend.call_recorder(lambda q, *a, **kw: value)
+ model = pretend.stub(
+ engine=pretend.stub(
+ connect=lambda: pretend.stub(
+ __enter__=lambda: pretend.stub(execute=execute),
+ __exit__=lambda *a, **k: None,
+ ),
+ ),
+ )
+ kwargs = {}
+ if key_func:
+ kwargs["key_func"] = key_func
+ if value_func:
+ kwargs["value_func"] = value_func
+
+ sql = db.mapping("SELECT * FROM thing", **kwargs)
+
+ assert sql(model, *eargs, **ekwargs) == expected
+ assert execute.calls == [
+ pretend.call("SELECT * FROM thing", *eargs, **ekwargs),
+ ]
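
Note on the stubs above: the helpers added in warehouse/db.py enter model.engine.connect() with a "with" statement, so the fake engine has to behave as a context manager whose __enter__ returns an object exposing execute(). A hand-rolled equivalent of what the pretend stubs model, as a minimal sketch (these class names are illustrative only and are not part of this PR):

# Minimal hand-rolled fake, equivalent in spirit to the pretend stubs above.
# connect() must return a context manager whose __enter__ yields an object
# with an execute() method, because the helpers use "with ... as conn:".
class FakeConnection(object):
    def __init__(self, result):
        self.result = result
        self.calls = []

    def execute(self, query, *args, **kwargs):
        # Record the call so tests can assert on the query and its arguments.
        self.calls.append((query, args, kwargs))
        return self.result


class FakeEngine(object):
    def __init__(self, connection):
        self.connection = connection

    def connect(self):
        return self  # the engine doubles as the context manager

    def __enter__(self):
        return self.connection

    def __exit__(self, *exc_info):
        return None

# A model stand-in then only needs an "engine" attribute, e.g.
# model = pretend.stub(engine=FakeEngine(FakeConnection(result)))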
44 warehouse/db.py
@@ -17,3 +17,47 @@
import sqlalchemy
metadata = sqlalchemy.MetaData()
+
+
+def scalar(query, default=None):
+ """
+ A helper function that takes a query and returns a function that will query
+ the database and return a scalar.
+ """
+ def inner(model, *args, **kwargs):
+ with model.engine.connect() as conn:
+ val = conn.execute(query, *args, **kwargs).scalar()
+
+ if default is not None and val is None:
+ return default
+ else:
+ return val
+
+ return inner
+
+
+def rows(query, row_func=dict):
+ """
+ A helper function that takes a query and returns a function that will query
+ the database and return a list of rows with the row_func applied to each.
+ """
+ def inner(model, *args, **kwargs):
+ with model.engine.connect() as conn:
+ return [row_func(r) for r in conn.execute(query, *args, **kwargs)]
+
+ return inner
+
+
+def mapping(query, key_func=lambda r: r[0], value_func=lambda r: r[1]):
+ """
+ A helper function that takes a query, a key_func, and a value_func and will
+ create a mapping that maps each row to a key: value pair.
+ """
+ def inner(model, *args, **kwargs):
+ with model.engine.connect() as conn:
+ return {
+ key_func(r): value_func(r)
+ for r in conn.execute(query, *args, **kwargs)
+ }
+
+ return inner
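
Taken together, the helpers turn a SQL string into a plain function whose first argument is the model (used only to reach model.engine); any remaining positional and keyword arguments are forwarded to execute(). A rough usage sketch, assuming a pre-2.0 SQLAlchemy engine that accepts raw SQL strings; the sqlite table and stand-in class below are made up for illustration and are not part of this PR:

# Sketch only: any object with an "engine" attribute works as the model,
# and the sqlite table here exists purely for this example.
import sqlalchemy

from warehouse import db


class StandInModel(object):
    def __init__(self, engine):
        self.engine = engine


model = StandInModel(sqlalchemy.create_engine("sqlite://"))  # in-memory DB
model.engine.execute("CREATE TABLE packages (name TEXT)")
model.engine.execute("INSERT INTO packages VALUES ('first-package')")

count_packages = db.scalar("SELECT COUNT(*) FROM packages", default=0)
all_project_names = db.rows(
    "SELECT name FROM packages ORDER BY lower(name)",
    row_func=lambda r: r["name"],
)

assert count_packages(model) == 1
assert all_project_names(model) == ["first-package"]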
504 warehouse/packaging/models.py
@@ -19,7 +19,7 @@
import urlparse
import logging
-from warehouse import models
+from warehouse import db, models
from warehouse.packaging.tables import ReleaseDependencyKind
@@ -28,65 +28,53 @@
class Model(models.Model):
- def get_project_count(self):
- query = "SELECT COUNT(*) FROM packages"
+ get_project_count = db.scalar(
+ "SELECT COUNT(*) FROM packages"
+ )
- with self.engine.connect() as conn:
- return conn.execute(query).scalar()
-
- def get_download_count(self):
- query = "SELECT SUM(downloads) FROM release_files"
-
- with self.engine.connect() as conn:
- return conn.execute(query).scalar() or 0
+ get_download_count = db.scalar(
+ "SELECT SUM(downloads) FROM release_files",
+ default=0,
+ )
- def get_recently_updated(self, num=10):
+ get_recently_updated = db.rows(
# We only consider releases made in the last 7 days, otherwise we have
- # to do a Sequence Scan against the entire table and it takes 5+
- # seconds to complete. This shouldn't be a big deal as it is highly
- # unlikely we'll have a week without at least 10 releases.
- query = \
- """ SELECT *
- FROM (
- SELECT DISTINCT ON (name) name, version, summary, created
- FROM releases
- WHERE created >= now() - interval '7 days'
- ORDER BY name, created DESC
- ) r
- ORDER BY r.created DESC
- LIMIT %(num)s
- """
-
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, num=num)]
-
- def get_releases_since(self, since):
- query = \
- """ SELECT name, version, created, summary
+ # to do a Sequence Scan against the entire table and it takes 5+
+ # seconds to complete. This shouldn't be a big deal as it is highly
+ # unlikely we'll have a week without at least 10 releases.
+ """ SELECT *
+ FROM (
+ SELECT DISTINCT ON (name) name, version, summary, created
FROM releases
- WHERE created > %(since)s
- ORDER BY created DESC
- """
-
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, since=since)]
-
- def get_changed_since(self, since):
- query = \
- """SELECT name, max(submitted_date) FROM journals
- WHERE submitted_date > %(since)s
- GROUP BY name
- ORDER BY max(submitted_date) DESC
- """
-
- with self.engine.connect() as conn:
- return [r[0] for r in conn.execute(query, since=since)]
-
- def all_projects(self):
- query = "SELECT name FROM packages ORDER BY lower(name)"
-
- with self.engine.connect() as conn:
- return [r["name"] for r in conn.execute(query)]
+ WHERE created >= now() - interval '7 days'
+ ORDER BY name, created DESC
+ ) r
+ ORDER BY r.created DESC
+ LIMIT 10
+ """
+ )
+
+ get_releases_since = db.rows(
+ """ SELECT name, version, created, summary
+ FROM releases
+ WHERE created > %s
+ ORDER BY created DESC
+ """
+ )
+
+ get_changed_since = db.rows(
+ """ SELECT name, max(submitted_date) FROM journals
+ WHERE submitted_date > %s
+ GROUP BY name
+ ORDER BY max(submitted_date) DESC
+ """,
+ row_func=lambda r: r[0]
+ )
+
+ all_projects = db.rows(
+ "SELECT name FROM packages ORDER BY lower(name)",
+ row_func=lambda r: r["name"]
+ )
def get_top_projects(self, num=None):
query = \
@@ -101,151 +89,113 @@ def get_top_projects(self, num=None):
with self.engine.connect() as conn:
return [tuple(r) for r in conn.execute(query, limit=num)]
- def get_project(self, name):
- query = \
- """ SELECT name
- FROM packages
- WHERE normalized_name = lower(
- regexp_replace(%(name)s, '_', '-', 'ig')
- )
- """
-
- with self.engine.connect() as conn:
- return conn.execute(query, name=name).scalar()
-
- def get_projects_for_user(self, username):
- query = \
- """ SELECT DISTINCT ON (lower(name)) name, summary
- FROM (
- SELECT package_name
- FROM roles
- WHERE user_name = %(username)s
- ) roles
- INNER JOIN (
- SELECT name, summary
- FROM releases
- ORDER BY _pypi_ordering DESC
- ) releases
- ON (releases.name = roles.package_name)
- ORDER BY lower(name)
- """
-
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, username=username)]
-
- def get_users_for_project(self, project):
- query = \
- """ SELECT DISTINCT ON (u.username) u.username, u.email
- FROM (
- SELECT username, email
- FROM accounts_user
- LEFT OUTER JOIN accounts_email ON (
- accounts_email.user_id = accounts_user.id
- )
- ) u
- INNER JOIN roles ON (u.username = roles.user_name)
- WHERE roles.package_name = %(project)s
- """
-
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, project=project)]
-
- def get_roles_for_project(self, project):
- query = \
- """ SELECT user_name, role_name
- FROM roles
- WHERE package_name = %(project)s
- ORDER BY role_name, user_name
- """
-
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, project=project)]
+ get_project = db.scalar(
+ """ SELECT name
+ FROM packages
+ WHERE normalized_name = lower(
+ regexp_replace(%s, '_', '-', 'ig')
+ )
+ """
+ )
- def get_roles_for_user(self, user):
- query = \
- """ SELECT package_name, role_name
+ get_projects_for_user = db.rows(
+ """ SELECT DISTINCT ON (lower(name)) name, summary
+ FROM (
+ SELECT package_name
FROM roles
- WHERE user_name = %(user)s
- ORDER BY package_name, role_name
- """
-
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, user=user)]
-
- def get_hosting_mode(self, name):
- query = "SELECT hosting_mode FROM packages WHERE name = %(project)s"
-
- with self.engine.connect() as conn:
- return conn.execute(query, project=name).scalar()
-
- def get_release_urls(self, name):
- query = \
- """ SELECT version, home_page, download_url
+ WHERE user_name = %s
+ ) roles
+ INNER JOIN (
+ SELECT name, summary
FROM releases
- WHERE name = %(project)s
- ORDER BY version DESC
- """
-
- with self.engine.connect() as conn:
- return {
- r["version"]: (r["home_page"], r["download_url"])
- for r in conn.execute(query, project=name)
- }
-
- def get_external_urls(self, name):
- query = \
- """ SELECT DISTINCT ON (url) url
- FROM description_urls
- WHERE name = %(project)s
- ORDER BY url
- """
-
- with self.engine.connect() as conn:
- return [r["url"] for r in conn.execute(query, project=name)]
-
- def get_file_urls(self, name):
- query = \
- """ SELECT name, filename, python_version, md5_digest
- FROM release_files
- WHERE name = %(project)s
- ORDER BY filename DESC
- """
-
- with self.engine.connect() as conn:
- results = conn.execute(query, project=name)
-
- return [
- {
- "filename": r["filename"],
- "url": urlparse.urljoin(
- "/".join([
- "../../packages",
- r["python_version"],
- r["name"][0],
- r["name"],
- r["filename"],
- ]),
- "#md5={}".format(r["md5_digest"]),
- ),
- }
- for r in results
- ]
-
- def get_project_for_filename(self, filename):
- query = "SELECT name FROM release_files WHERE filename = %(filename)s"
-
- with self.engine.connect() as conn:
- return conn.execute(query, filename=filename).scalar()
+ ORDER BY _pypi_ordering DESC
+ ) releases
+ ON (releases.name = roles.package_name)
+ ORDER BY lower(name)
+ """
+ )
+
+ get_users_for_project = db.rows(
+ """ SELECT DISTINCT ON (u.username) u.username, u.email
+ FROM (
+ SELECT username, email
+ FROM accounts_user
+ LEFT OUTER JOIN accounts_email ON (
+ accounts_email.user_id = accounts_user.id
+ )
+ ) u
+ INNER JOIN roles ON (u.username = roles.user_name)
+ WHERE roles.package_name = %s
+ """
+ )
+
+ get_roles_for_project = db.rows(
+ """ SELECT user_name, role_name
+ FROM roles
+ WHERE package_name = %s
+ ORDER BY role_name, user_name
+ """
+ )
+
+ get_roles_for_user = db.rows(
+ """ SELECT package_name, role_name
+ FROM roles
+ WHERE user_name = %s
+ ORDER BY package_name, role_name
+ """
+ )
+
+ get_hosting_mode = db.scalar(
+ "SELECT hosting_mode FROM packages WHERE name = %s"
+ )
+
+ get_release_urls = db.mapping(
+ """ SELECT version, home_page, download_url
+ FROM releases
+ WHERE name = %s
+ ORDER BY version DESC
+ """,
+ key_func=lambda r: r["version"],
+ value_func=lambda r: (r["home_page"], r["download_url"]),
+ )
+
+ get_external_urls = db.rows(
+ """ SELECT DISTINCT ON (url) url
+ FROM description_urls
+ WHERE name = %s
+ ORDER BY url
+ """,
+ row_func=lambda r: r["url"]
+ )
+
+ get_file_urls = db.rows(
+ """ SELECT name, filename, python_version, md5_digest
+ FROM release_files
+ WHERE name = %s
+ ORDER BY filename DESC
+ """,
+ lambda r: {
+ "filename": r["filename"],
+ "url": urlparse.urljoin(
+ "/".join([
+ "../../packages",
+ r["python_version"],
+ r["name"][0],
+ r["name"],
+ r["filename"],
+ ]),
+ "#md5={}".format(r["md5_digest"]),
+ ),
+ }
+ )
- def get_filename_md5(self, filename):
- query = \
- """ SELECT md5_digest
- FROM release_files
- WHERE filename = %(filename)s
- """
+ get_project_for_filename = db.scalar(
+ "SELECT name FROM release_files WHERE filename = %s"
+ )
- with self.engine.connect() as conn:
- return conn.execute(query, filename=filename).scalar()
+ get_filename_md5 = db.scalar(
+ "SELECT md5_digest FROM release_files WHERE filename = %s"
+ )
def get_last_serial(self, name=None):
if name is not None:
@@ -253,26 +203,20 @@ def get_last_serial(self, name=None):
else:
query = "SELECT MAX(id) FROM journals"
- with self.engine.connect() as conn:
- return conn.execute(query, name=name).scalar()
+ return db.scalar(query)(self, name=name)
- def get_projects_with_serial(self):
- # return list of dict(name: max id)
- query = "SELECT name, max(id) FROM journals GROUP BY name"
+ get_projects_with_serial = db.mapping(
+ "SELECT name, max(id) FROM journals GROUP BY name",
+ )
- with self.engine.connect() as conn:
- return dict(r for r in conn.execute(query))
-
- def get_project_versions(self, project):
- query = \
- """ SELECT version
- FROM releases
- WHERE name = %(project)s
- ORDER BY _pypi_ordering DESC
- """
-
- with self.engine.connect() as conn:
- return [r["version"] for r in conn.execute(query, project=project)]
+ get_project_versions = db.rows(
+ """ SELECT version
+ FROM releases
+ WHERE name = %s
+ ORDER BY _pypi_ordering DESC
+ """,
+ row_func=lambda r: r["version"]
+ )
def get_downloads(self, project, version):
query = \
@@ -368,37 +312,26 @@ def get_release(self, project, version):
return result
- def get_releases(self, project):
- # Get the release data
- query = \
- """ SELECT
- name, version, author, author_email, maintainer,
- maintainer_email, home_page, license, summary, keywords,
- platform, download_url, created
- FROM releases
- WHERE name = %(project)s
- ORDER BY _pypi_ordering DESC
- """
-
- with self.engine.connect() as conn:
- results = [dict(r) for r in conn.execute(query, project=project)]
-
- return results
-
- def get_full_latest_releases(self):
- query = \
- """ SELECT DISTINCT ON (name)
- name, version, author, author_email, maintainer,
- maintainer_email, home_page, license, summary, description,
- keywords, platform, download_url, created
- FROM releases
- ORDER BY name, _pypi_ordering DESC
- """
-
- with self.engine.connect() as conn:
- results = [dict(r) for r in conn.execute(query)]
-
- return results
+ get_releases = db.rows(
+ """ SELECT
+ name, version, author, author_email, maintainer,
+ maintainer_email, home_page, license, summary, keywords,
+ platform, download_url, created
+ FROM releases
+ WHERE name = %s
+ ORDER BY _pypi_ordering DESC
+ """
+ )
+
+ get_full_latest_releases = db.rows(
+ """ SELECT DISTINCT ON (name)
+ name, version, author, author_email, maintainer,
+ maintainer_email, home_page, license, summary, description,
+ keywords, platform, download_url, created
+ FROM releases
+ ORDER BY name, _pypi_ordering DESC
+ """
+ )
def get_download_counts(self, project):
def _make_key(precision, datetime, key):
@@ -461,34 +394,30 @@ def _make_key(precision, datetime, key):
"last_month": last_30,
}
- def get_classifiers(self, project, version):
- query = \
- """ SELECT classifier
- FROM release_classifiers
- INNER JOIN trove_classifiers ON (
- release_classifiers.trove_id = trove_classifiers.id
- )
- WHERE name = %(project)s AND version = %(version)s
- ORDER BY classifier
- """
-
- with self.engine.connect() as conn:
- return [
- r["classifier"]
- for r in conn.execute(query, project=project, version=version)
- ]
+ get_classifiers = db.rows(
+ """ SELECT classifier
+ FROM release_classifiers
+ INNER JOIN trove_classifiers ON (
+ release_classifiers.trove_id = trove_classifiers.id
+ )
+ WHERE name = %s AND version = %s
+ ORDER BY classifier
+ """,
+ row_func=lambda r: r["classifier"]
+ )
def get_classifier_ids(self, classifiers):
- placeholders = ', '.join(['%s'] * len(classifiers))
query = \
- """SELECT classifier, id
- FROM trove_classifiers
- WHERE classifier IN (%s)
- """ % placeholders
+ """ SELECT classifier, id
+ FROM trove_classifiers
+ WHERE classifier IN %(classifiers)s
+ """
with self.engine.connect() as conn:
- return dict((r['classifier'], r['id'])
- for r in conn.execute(query, *classifiers))
+ return {
+ r["classifier"]: r["id"]
+ for r in conn.execute(query, classifiers=tuple(classifiers))
+ }
def search_by_classifier(self, selected_classifiers):
# Note: selected_classifiers is a list of ids from trove_classifiers
@@ -541,33 +470,26 @@ def get_documentation_url(self, project):
project
) + "/"
- def get_bugtrack_url(self, project):
- query = "SELECT bugtrack_url FROM packages WHERE name = %(project)s"
-
- with self.engine.connect() as conn:
- return conn.execute(query, project=project).scalar()
+ get_bugtrack_url = db.scalar(
+ "SELECT bugtrack_url FROM packages WHERE name = %s"
+ )
- #
- # Mirroring support
- #
- def get_changelog(self, since):
- query = '''SELECT name, version, submitted_date, action, id
+ get_changelog = db.rows(
+ """ SELECT name, version, submitted_date, action, id
FROM journals
- WHERE journals.submitted_date > %(since)s
+ WHERE journals.submitted_date > %s
ORDER BY submitted_date DESC
- '''
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, since=since)]
+ """
+ )
- def get_last_changelog_serial(self):
- with self.engine.connect() as conn:
- return conn.execute('SELECT max(id) FROM journals').scalar()
+ get_last_changelog_serial = db.scalar(
+ "SELECT max(id) FROM journals"
+ )
- def get_changelog_serial(self, since):
- query = '''SELECT name, version, submitted_date, action, id
+ get_changelog_serial = db.rows(
+ """ SELECT name, version, submitted_date, action, id
FROM journals
- WHERE journals.id > %(since)s
+ WHERE journals.id > %s
ORDER BY submitted_date DESC
- '''
- with self.engine.connect() as conn:
- return [dict(r) for r in conn.execute(query, since=since)]
+ """
+ )
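
One detail that makes this refactor work: db.scalar, db.rows, and db.mapping return ordinary functions, so assigning them at class level turns them into methods; attribute access on an instance binds the instance as the first argument, which is exactly the model parameter the inner functions expect. A minimal sketch of that mechanism (the class and table name below are illustrative only):

# Illustrative only; "widgets" is a made-up table and this class is not part
# of Warehouse. The point is that the function returned by db.scalar is bound
# like a normal method, so self is passed in as the model argument.
from warehouse import db


class Example(object):
    def __init__(self, engine):
        self.engine = engine

    get_widget_count = db.scalar("SELECT COUNT(*) FROM widgets", default=0)


# Example(engine).get_widget_count() is equivalent to
# db.scalar("SELECT COUNT(*) FROM widgets", default=0)(Example(engine))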