Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,7 @@ def db_request(pyramid_request, db_session):

@pytest.fixture
def enable_organizations(db_request):
flag = db_request.db.query(AdminFlag).get(
AdminFlagValue.DISABLE_ORGANIZATIONS.value
)
flag = db_request.db.get(AdminFlag, AdminFlagValue.DISABLE_ORGANIZATIONS.value)
flag.enabled = False
yield
flag.enabled = True
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/accounts/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,8 +1420,8 @@ def _find_service(service=None, name=None, context=None):

def test_register_fails_with_admin_flag_set(self, db_request):
# This flag was already set via migration, just need to enable it
flag = db_request.db.query(AdminFlag).get(
AdminFlagValue.DISALLOW_NEW_USER_REGISTRATION.value
flag = db_request.db.get(
AdminFlag, AdminFlagValue.DISALLOW_NEW_USER_REGISTRATION.value
)
flag.enabled = True

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/admin/views/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def test_deletes_user(self, db_request, monkeypatch):

db_request.db.flush()

assert not db_request.db.query(User).get(user.id)
assert not db_request.db.get(User, user.id)
assert db_request.db.query(Project).all() == [another_project]
assert db_request.route_path.calls == [pretend.call("admin.user.list")]
assert result.status_code == 303
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_deletes_user_bad_confirm(self, db_request, monkeypatch):

db_request.db.flush()

assert db_request.db.query(User).get(user.id)
assert db_request.db.get(User, user.id)
assert db_request.db.query(Project).all() == [project]
assert db_request.route_path.calls == [
pretend.call("admin.user.detail", username=user.username)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/manage/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6708,7 +6708,7 @@ def test_delete_oidc_publisher_not_found(self, monkeypatch, other_publisher):
session=pretend.stub(flash=pretend.call_recorder(lambda *a, **kw: None)),
POST=pretend.stub(),
db=pretend.stub(
query=lambda *a: pretend.stub(get=lambda id: other_publisher),
get=pretend.call_recorder(lambda *a, **kw: other_publisher),
),
remote_addr="0.0.0.0",
)
Expand Down
6 changes: 2 additions & 4 deletions tests/unit/subscriptions/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,11 +749,9 @@ def test_delete_subscription_price(self, subscription_service, db_request):
"""
subscription_price = StripeSubscriptionPriceFactory.create()

assert db_request.db.query(StripeSubscriptionPrice).get(subscription_price.id)
assert db_request.db.get(StripeSubscriptionPrice, subscription_price.id)

subscription_service.delete_subscription_price(subscription_price.id)
subscription_service.db.flush()

assert not (
db_request.db.query(StripeSubscriptionPrice).get(subscription_price.id)
)
assert not (db_request.db.get(StripeSubscriptionPrice, subscription_price.id))
10 changes: 4 additions & 6 deletions tests/unit/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sqlalchemy.exc import OperationalError

from warehouse import db
from warehouse.admin.flags import AdminFlagValue
from warehouse.admin.flags import AdminFlag, AdminFlagValue
from warehouse.db import (
DEFAULT_ISOLATION,
DatabaseNotAvailableError,
Expand Down Expand Up @@ -134,7 +134,7 @@ def raiser():
def test_create_session(monkeypatch, pyramid_services):
session_obj = pretend.stub(
close=pretend.call_recorder(lambda: None),
query=lambda *a: pretend.stub(get=lambda *a: None),
get=pretend.call_recorder(lambda *a: None),
)
session_cls = pretend.call_recorder(lambda bind: session_obj)
monkeypatch.setattr(db, "Session", session_cls)
Expand Down Expand Up @@ -190,9 +190,7 @@ def test_create_session_read_only_mode(
admin_flag, is_superuser, doom_calls, monkeypatch, pyramid_services
):
get = pretend.call_recorder(lambda *a: admin_flag)
session_obj = pretend.stub(
close=lambda: None, query=lambda *a: pretend.stub(get=get)
)
session_obj = pretend.stub(close=lambda: None, get=get)
session_cls = pretend.call_recorder(lambda bind: session_obj)
monkeypatch.setattr(db, "Session", session_cls)

Expand All @@ -218,7 +216,7 @@ def test_create_session_read_only_mode(
)

assert _create_session(request) is session_obj
assert get.calls == [pretend.call(AdminFlagValue.READ_ONLY.value)]
assert get.calls == [pretend.call(AdminFlag, AdminFlagValue.READ_ONLY.value)]
assert request.tm.doom.calls == doom_calls


Expand Down
10 changes: 6 additions & 4 deletions warehouse/accounts/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import requests

from passlib.context import CryptContext
from sqlalchemy import exists, select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import exists
from webauthn.helpers import bytes_to_base64url
from zope.interface import implementer

Expand Down Expand Up @@ -99,9 +99,11 @@ def _get_user(self, userid):
# object here.
# TODO: We need some sort of Anonymous User.
return (
self.db.query(User).options(joinedload(User.webauthn)).get(userid)
if userid
else None
self.db.scalars(
select(User).options(joinedload(User.webauthn)).where(User.id == userid)
)
.unique()
.one_or_none()
)

def get_user(self, userid):
Expand Down
4 changes: 2 additions & 2 deletions warehouse/accounts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,8 +1550,8 @@ def delete_pending_oidc_publisher(self):
)
return self.default_response

pending_publisher = self.request.db.query(PendingOIDCPublisher).get(
form.publisher_id.data
pending_publisher = self.request.db.get(
PendingOIDCPublisher, form.publisher_id.data
)

# pending_publisher will be `None` here if someone manually
Expand Down
2 changes: 1 addition & 1 deletion warehouse/admin/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def notifications(self):
)

def enabled(self, flag_member):
flag = self.request.db.query(AdminFlag).get(flag_member.value)
flag = self.request.db.get(AdminFlag, flag_member.value)
return flag.enabled if flag else False


Expand Down
2 changes: 1 addition & 1 deletion warehouse/admin/views/emails.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def email_mass(request):
rows = list(csv.DictReader(wrapper))
if rows:
for row in rows:
user = request.db.query(User).get(row["user_id"])
user = request.db.get(User, row["user_id"])
email = user.primary_email

if email:
Expand Down
2 changes: 1 addition & 1 deletion warehouse/admin/views/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_flags(request):
require_csrf=True,
)
def edit_flag(request):
flag = request.db.query(AdminFlag).get(request.POST["id"])
flag = request.db.get(AdminFlag, request.POST["id"])
flag.description = request.POST["description"]
flag.enabled = bool(request.POST.get("enabled"))

Expand Down
2 changes: 1 addition & 1 deletion warehouse/admin/views/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def delete_role(project, request):
confirm = request.POST.get("username")
role_id = request.matchdict.get("role_id")

role = request.db.query(Role).get(role_id)
role = request.db.get(Role, role_id)
if not role:
request.session.flash("This role no longer exists", queue="error")
raise HTTPSeeOther(
Expand Down
4 changes: 2 additions & 2 deletions warehouse/admin/views/verdicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_verdicts(request):
uses_session=True,
)
def get_verdict(request):
verdict = request.db.query(MalwareVerdict).get(request.matchdict["verdict_id"])
verdict = request.db.get(MalwareVerdict, request.matchdict["verdict_id"])

if verdict:
return {
Expand All @@ -76,7 +76,7 @@ def get_verdict(request):
require_csrf=True,
)
def review_verdict(request):
verdict = request.db.query(MalwareVerdict).get(request.matchdict["verdict_id"])
verdict = request.db.get(MalwareVerdict, request.matchdict["verdict_id"])

try:
classification = getattr(VerdictClassification, request.POST["classification"])
Expand Down
5 changes: 2 additions & 3 deletions warehouse/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from sqlalchemy import event, inspect
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.declarative import declarative_base # type: ignore
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import declarative_base, sessionmaker

from warehouse.metrics import IMetricsService
from warehouse.utils.attrs import make_repr
Expand Down Expand Up @@ -155,7 +154,7 @@ def cleanup(request):
# Check if we're in read-only mode
from warehouse.admin.flags import AdminFlag, AdminFlagValue

flag = session.query(AdminFlag).get(AdminFlagValue.READ_ONLY.value)
flag = session.get(AdminFlag, AdminFlagValue.READ_ONLY.value)
if flag and flag.enabled:
request.tm.doom()

Expand Down
2 changes: 1 addition & 1 deletion warehouse/legacy/api/pypi.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def browse(request):
except ValueError:
raise HTTPNotFound

classifier = request.db.query(Classifier).get(classifier_id)
classifier = request.db.get(Classifier, classifier_id)

if not classifier:
raise HTTPNotFound
Expand Down
2 changes: 1 addition & 1 deletion warehouse/malware/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def prepare(cls, request, obj_id):
kwargs = {"obj_id": obj_id}

model = getattr(models, cls.hooked_object)
kwargs["obj"] = request.db.query(model).get(obj_id)
kwargs["obj"] = request.db.get(model, obj_id)

if cls.hooked_object == "File":
kwargs["file_url"] = request.route_url(
Expand Down
2 changes: 1 addition & 1 deletion warehouse/manage/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,7 @@ def delete_oidc_publisher(self):
form = DeletePublisherForm(self.request.POST)

if form.validate():
publisher = self.request.db.query(OIDCPublisher).get(form.publisher_id.data)
publisher = self.request.db.get(OIDCPublisher, form.publisher_id.data)

# publisher will be `None` here if someone manually futzes with the form.
if publisher is None or publisher not in self.project.oidc_publishers:
Expand Down
6 changes: 3 additions & 3 deletions warehouse/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.

from alembic import context
from sqlalchemy import create_engine, pool
from sqlalchemy import create_engine, pool, text

from warehouse import db

Expand Down Expand Up @@ -50,8 +50,8 @@ def run_migrations_online():
connectable = create_engine(url, poolclass=pool.NullPool)

with connectable.connect() as connection:
connection.execute("SET statement_timeout = 5000")
connection.execute("SET lock_timeout = 4000")
connection.execute(text("SET statement_timeout = 5000"))
connection.execute(text("SET lock_timeout = 4000"))

context.configure(
connection=connection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
"""


import sqlalchemy as sa

from alembic import op

revision = "203f1f8dcf92"
Expand All @@ -27,7 +29,7 @@
def upgrade():
# We've seen this migration fail due to statement timeouts in production.
conn = op.get_bind()
conn.execute("SET statement_timeout = 60000")
conn.execute(sa.text("SET statement_timeout = 60000"))

op.drop_constraint("file_events_source_id_fkey", "file_events", type_="foreignkey")
op.create_foreign_key(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _get_num_rows(conn):

def upgrade():
conn = op.get_bind()
conn.execute("SET statement_timeout = 120000")
conn.execute(sa.text("SET statement_timeout = 120000"))
total_rows = _get_num_rows(conn)
max_loops = total_rows / 100000 * 2
loops = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Create Date: 2023-04-14 18:21:38.683694
"""

import sqlalchemy as sa

from alembic import op
from sqlalchemy.dialects import postgresql

Expand All @@ -26,8 +28,8 @@

def upgrade():
conn = op.get_bind()
conn.execute("SET statement_timeout = 120000")
conn.execute("SET lock_timeout = 120000")
conn.execute(sa.text("SET statement_timeout = 120000"))
conn.execute(sa.text("SET lock_timeout = 120000"))
op.alter_column(
"pending_oidc_publishers",
"added_by_id",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def upgrade():
conn = op.get_bind()
conn.execute("SET statement_timeout = 120000")
conn.execute(sa.text("SET statement_timeout = 120000"))
op.add_column(
"release_files",
sa.Column(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

def upgrade():
conn = op.get_bind()
conn.execute("SET statement_timeout = 120000")
conn.execute("SET lock_timeout = 120000")
conn.execute(sa.text("SET statement_timeout = 120000"))
conn.execute(sa.text("SET lock_timeout = 120000"))
op.drop_column("users", "has_oidc_beta_access")


Expand Down
Loading