diff --git a/tests/conftest.py b/tests/conftest.py index ec5e1cc2c3f9..d33ed9f6086c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -442,9 +442,7 @@ def db_request(pyramid_request, db_session): @pytest.fixture def enable_organizations(db_request): - flag = db_request.db.query(AdminFlag).get( - AdminFlagValue.DISABLE_ORGANIZATIONS.value - ) + flag = db_request.db.get(AdminFlag, AdminFlagValue.DISABLE_ORGANIZATIONS.value) flag.enabled = False yield flag.enabled = True diff --git a/tests/unit/accounts/test_views.py b/tests/unit/accounts/test_views.py index eff981a33d9f..238f0b3cfeda 100644 --- a/tests/unit/accounts/test_views.py +++ b/tests/unit/accounts/test_views.py @@ -1420,8 +1420,8 @@ def _find_service(service=None, name=None, context=None): def test_register_fails_with_admin_flag_set(self, db_request): # This flag was already set via migration, just need to enable it - flag = db_request.db.query(AdminFlag).get( - AdminFlagValue.DISALLOW_NEW_USER_REGISTRATION.value + flag = db_request.db.get( + AdminFlag, AdminFlagValue.DISALLOW_NEW_USER_REGISTRATION.value ) flag.enabled = True diff --git a/tests/unit/admin/views/test_users.py b/tests/unit/admin/views/test_users.py index 9d6d6c116a8d..ddff6ee1947b 100644 --- a/tests/unit/admin/views/test_users.py +++ b/tests/unit/admin/views/test_users.py @@ -331,7 +331,7 @@ def test_deletes_user(self, db_request, monkeypatch): db_request.db.flush() - assert not db_request.db.query(User).get(user.id) + assert not db_request.db.get(User, user.id) assert db_request.db.query(Project).all() == [another_project] assert db_request.route_path.calls == [pretend.call("admin.user.list")] assert result.status_code == 303 @@ -371,7 +371,7 @@ def test_deletes_user_bad_confirm(self, db_request, monkeypatch): db_request.db.flush() - assert db_request.db.query(User).get(user.id) + assert db_request.db.get(User, user.id) assert db_request.db.query(Project).all() == [project] assert db_request.route_path.calls == [ pretend.call("admin.user.detail", username=user.username) diff --git a/tests/unit/manage/test_views.py b/tests/unit/manage/test_views.py index 81b44a114a18..893de2efa3b6 100644 --- a/tests/unit/manage/test_views.py +++ b/tests/unit/manage/test_views.py @@ -6708,7 +6708,7 @@ def test_delete_oidc_publisher_not_found(self, monkeypatch, other_publisher): session=pretend.stub(flash=pretend.call_recorder(lambda *a, **kw: None)), POST=pretend.stub(), db=pretend.stub( - query=lambda *a: pretend.stub(get=lambda id: other_publisher), + get=pretend.call_recorder(lambda *a, **kw: other_publisher), ), remote_addr="0.0.0.0", ) diff --git a/tests/unit/subscriptions/test_services.py b/tests/unit/subscriptions/test_services.py index 35eab6d9b8cf..b5dde3ad0798 100644 --- a/tests/unit/subscriptions/test_services.py +++ b/tests/unit/subscriptions/test_services.py @@ -749,11 +749,9 @@ def test_delete_subscription_price(self, subscription_service, db_request): """ subscription_price = StripeSubscriptionPriceFactory.create() - assert db_request.db.query(StripeSubscriptionPrice).get(subscription_price.id) + assert db_request.db.get(StripeSubscriptionPrice, subscription_price.id) subscription_service.delete_subscription_price(subscription_price.id) subscription_service.db.flush() - assert not ( - db_request.db.query(StripeSubscriptionPrice).get(subscription_price.id) - ) + assert not (db_request.db.get(StripeSubscriptionPrice, subscription_price.id)) diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py index 63837634fafc..24fc3ac28350 100644 --- a/tests/unit/test_db.py +++ b/tests/unit/test_db.py @@ -24,7 +24,7 @@ from sqlalchemy.exc import OperationalError from warehouse import db -from warehouse.admin.flags import AdminFlagValue +from warehouse.admin.flags import AdminFlag, AdminFlagValue from warehouse.db import ( DEFAULT_ISOLATION, DatabaseNotAvailableError, @@ -134,7 +134,7 @@ def raiser(): def test_create_session(monkeypatch, pyramid_services): session_obj = pretend.stub( close=pretend.call_recorder(lambda: None), - query=lambda *a: pretend.stub(get=lambda *a: None), + get=pretend.call_recorder(lambda *a: None), ) session_cls = pretend.call_recorder(lambda bind: session_obj) monkeypatch.setattr(db, "Session", session_cls) @@ -190,9 +190,7 @@ def test_create_session_read_only_mode( admin_flag, is_superuser, doom_calls, monkeypatch, pyramid_services ): get = pretend.call_recorder(lambda *a: admin_flag) - session_obj = pretend.stub( - close=lambda: None, query=lambda *a: pretend.stub(get=get) - ) + session_obj = pretend.stub(close=lambda: None, get=get) session_cls = pretend.call_recorder(lambda bind: session_obj) monkeypatch.setattr(db, "Session", session_cls) @@ -218,7 +216,7 @@ def test_create_session_read_only_mode( ) assert _create_session(request) is session_obj - assert get.calls == [pretend.call(AdminFlagValue.READ_ONLY.value)] + assert get.calls == [pretend.call(AdminFlag, AdminFlagValue.READ_ONLY.value)] assert request.tm.doom.calls == doom_calls diff --git a/warehouse/accounts/services.py b/warehouse/accounts/services.py index 89f6e2e3f8d2..da5f6bebbb81 100644 --- a/warehouse/accounts/services.py +++ b/warehouse/accounts/services.py @@ -23,9 +23,9 @@ import requests from passlib.context import CryptContext +from sqlalchemy import exists, select from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import joinedload -from sqlalchemy.sql import exists from webauthn.helpers import bytes_to_base64url from zope.interface import implementer @@ -99,9 +99,11 @@ def _get_user(self, userid): # object here. # TODO: We need some sort of Anonymous User. return ( - self.db.query(User).options(joinedload(User.webauthn)).get(userid) - if userid - else None + self.db.scalars( + select(User).options(joinedload(User.webauthn)).where(User.id == userid) + ) + .unique() + .one_or_none() ) def get_user(self, userid): diff --git a/warehouse/accounts/views.py b/warehouse/accounts/views.py index df2169284ed0..ec03d5e26df3 100644 --- a/warehouse/accounts/views.py +++ b/warehouse/accounts/views.py @@ -1550,8 +1550,8 @@ def delete_pending_oidc_publisher(self): ) return self.default_response - pending_publisher = self.request.db.query(PendingOIDCPublisher).get( - form.publisher_id.data + pending_publisher = self.request.db.get( + PendingOIDCPublisher, form.publisher_id.data ) # pending_publisher will be `None` here if someone manually diff --git a/warehouse/admin/flags.py b/warehouse/admin/flags.py index ebb4d6a0affa..dee394f45166 100644 --- a/warehouse/admin/flags.py +++ b/warehouse/admin/flags.py @@ -48,7 +48,7 @@ def notifications(self): ) def enabled(self, flag_member): - flag = self.request.db.query(AdminFlag).get(flag_member.value) + flag = self.request.db.get(AdminFlag, flag_member.value) return flag.enabled if flag else False diff --git a/warehouse/admin/views/emails.py b/warehouse/admin/views/emails.py index b29171cd09dd..6e4be8bb2e74 100644 --- a/warehouse/admin/views/emails.py +++ b/warehouse/admin/views/emails.py @@ -90,7 +90,7 @@ def email_mass(request): rows = list(csv.DictReader(wrapper)) if rows: for row in rows: - user = request.db.query(User).get(row["user_id"]) + user = request.db.get(User, row["user_id"]) email = user.primary_email if email: diff --git a/warehouse/admin/views/flags.py b/warehouse/admin/views/flags.py index 0e17d9b19f82..a25f2df19a20 100644 --- a/warehouse/admin/views/flags.py +++ b/warehouse/admin/views/flags.py @@ -36,7 +36,7 @@ def get_flags(request): require_csrf=True, ) def edit_flag(request): - flag = request.db.query(AdminFlag).get(request.POST["id"]) + flag = request.db.get(AdminFlag, request.POST["id"]) flag.description = request.POST["description"] flag.enabled = bool(request.POST.get("enabled")) diff --git a/warehouse/admin/views/projects.py b/warehouse/admin/views/projects.py index 8d76a35cf7ba..11169ef209ca 100644 --- a/warehouse/admin/views/projects.py +++ b/warehouse/admin/views/projects.py @@ -429,7 +429,7 @@ def delete_role(project, request): confirm = request.POST.get("username") role_id = request.matchdict.get("role_id") - role = request.db.query(Role).get(role_id) + role = request.db.get(Role, role_id) if not role: request.session.flash("This role no longer exists", queue="error") raise HTTPSeeOther( diff --git a/warehouse/admin/views/verdicts.py b/warehouse/admin/views/verdicts.py index 4e2620ffa3d7..a8048d0052c5 100644 --- a/warehouse/admin/views/verdicts.py +++ b/warehouse/admin/views/verdicts.py @@ -56,7 +56,7 @@ def get_verdicts(request): uses_session=True, ) def get_verdict(request): - verdict = request.db.query(MalwareVerdict).get(request.matchdict["verdict_id"]) + verdict = request.db.get(MalwareVerdict, request.matchdict["verdict_id"]) if verdict: return { @@ -76,7 +76,7 @@ def get_verdict(request): require_csrf=True, ) def review_verdict(request): - verdict = request.db.query(MalwareVerdict).get(request.matchdict["verdict_id"]) + verdict = request.db.get(MalwareVerdict, request.matchdict["verdict_id"]) try: classification = getattr(VerdictClassification, request.POST["classification"]) diff --git a/warehouse/db.py b/warehouse/db.py index 6d46c14980d2..4df7be5b8a82 100644 --- a/warehouse/db.py +++ b/warehouse/db.py @@ -22,8 +22,7 @@ from sqlalchemy import event, inspect from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.ext.declarative import declarative_base # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import declarative_base, sessionmaker from warehouse.metrics import IMetricsService from warehouse.utils.attrs import make_repr @@ -155,7 +154,7 @@ def cleanup(request): # Check if we're in read-only mode from warehouse.admin.flags import AdminFlag, AdminFlagValue - flag = session.query(AdminFlag).get(AdminFlagValue.READ_ONLY.value) + flag = session.get(AdminFlag, AdminFlagValue.READ_ONLY.value) if flag and flag.enabled: request.tm.doom() diff --git a/warehouse/legacy/api/pypi.py b/warehouse/legacy/api/pypi.py index d37c016c26c7..02968e3909f3 100644 --- a/warehouse/legacy/api/pypi.py +++ b/warehouse/legacy/api/pypi.py @@ -100,7 +100,7 @@ def browse(request): except ValueError: raise HTTPNotFound - classifier = request.db.query(Classifier).get(classifier_id) + classifier = request.db.get(Classifier, classifier_id) if not classifier: raise HTTPNotFound diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index c6e3f4591e8d..c10e97d25986 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -29,7 +29,7 @@ def prepare(cls, request, obj_id): kwargs = {"obj_id": obj_id} model = getattr(models, cls.hooked_object) - kwargs["obj"] = request.db.query(model).get(obj_id) + kwargs["obj"] = request.db.get(model, obj_id) if cls.hooked_object == "File": kwargs["file_url"] = request.route_url( diff --git a/warehouse/manage/views/__init__.py b/warehouse/manage/views/__init__.py index d11068f91158..d389d1375967 100644 --- a/warehouse/manage/views/__init__.py +++ b/warehouse/manage/views/__init__.py @@ -1353,7 +1353,7 @@ def delete_oidc_publisher(self): form = DeletePublisherForm(self.request.POST) if form.validate(): - publisher = self.request.db.query(OIDCPublisher).get(form.publisher_id.data) + publisher = self.request.db.get(OIDCPublisher, form.publisher_id.data) # publisher will be `None` here if someone manually futzes with the form. if publisher is None or publisher not in self.project.oidc_publishers: diff --git a/warehouse/migrations/env.py b/warehouse/migrations/env.py index bb170a2919a7..23a05d64ba40 100644 --- a/warehouse/migrations/env.py +++ b/warehouse/migrations/env.py @@ -11,7 +11,7 @@ # limitations under the License. from alembic import context -from sqlalchemy import create_engine, pool +from sqlalchemy import create_engine, pool, text from warehouse import db @@ -50,8 +50,8 @@ def run_migrations_online(): connectable = create_engine(url, poolclass=pool.NullPool) with connectable.connect() as connection: - connection.execute("SET statement_timeout = 5000") - connection.execute("SET lock_timeout = 4000") + connection.execute(text("SET statement_timeout = 5000")) + connection.execute(text("SET lock_timeout = 4000")) context.configure( connection=connection, diff --git a/warehouse/migrations/versions/203f1f8dcf92_event_source_id_cascades_on_delete.py b/warehouse/migrations/versions/203f1f8dcf92_event_source_id_cascades_on_delete.py index 9877c8cb1a9b..8dcca01321ce 100644 --- a/warehouse/migrations/versions/203f1f8dcf92_event_source_id_cascades_on_delete.py +++ b/warehouse/migrations/versions/203f1f8dcf92_event_source_id_cascades_on_delete.py @@ -18,6 +18,8 @@ """ +import sqlalchemy as sa + from alembic import op revision = "203f1f8dcf92" @@ -27,7 +29,7 @@ def upgrade(): # We've seen this migration fail due to statement timeouts in production. conn = op.get_bind() - conn.execute("SET statement_timeout = 60000") + conn.execute(sa.text("SET statement_timeout = 60000")) op.drop_constraint("file_events_source_id_fkey", "file_events", type_="foreignkey") op.create_foreign_key( diff --git a/warehouse/migrations/versions/4490777c984f_migrate_existing_data_for_release_is_.py b/warehouse/migrations/versions/4490777c984f_migrate_existing_data_for_release_is_.py index 01d65f1b50c7..7a94448429e2 100644 --- a/warehouse/migrations/versions/4490777c984f_migrate_existing_data_for_release_is_.py +++ b/warehouse/migrations/versions/4490777c984f_migrate_existing_data_for_release_is_.py @@ -35,7 +35,7 @@ def _get_num_rows(conn): def upgrade(): conn = op.get_bind() - conn.execute("SET statement_timeout = 120000") + conn.execute(sa.text("SET statement_timeout = 120000")) total_rows = _get_num_rows(conn) max_loops = total_rows / 100000 * 2 loops = 0 diff --git a/warehouse/migrations/versions/75ba94852cd1_make_pendingoidcpublisher_added_by_id_.py b/warehouse/migrations/versions/75ba94852cd1_make_pendingoidcpublisher_added_by_id_.py index 9835911cfe6b..e9a15e2a316e 100644 --- a/warehouse/migrations/versions/75ba94852cd1_make_pendingoidcpublisher_added_by_id_.py +++ b/warehouse/migrations/versions/75ba94852cd1_make_pendingoidcpublisher_added_by_id_.py @@ -17,6 +17,8 @@ Create Date: 2023-04-14 18:21:38.683694 """ +import sqlalchemy as sa + from alembic import op from sqlalchemy.dialects import postgresql @@ -26,8 +28,8 @@ def upgrade(): conn = op.get_bind() - conn.execute("SET statement_timeout = 120000") - conn.execute("SET lock_timeout = 120000") + conn.execute(sa.text("SET statement_timeout = 120000")) + conn.execute(sa.text("SET lock_timeout = 120000")) op.alter_column( "pending_oidc_publishers", "added_by_id", diff --git a/warehouse/migrations/versions/d142f435bb39_add_archived_column_to_files.py b/warehouse/migrations/versions/d142f435bb39_add_archived_column_to_files.py index ba515378d3e2..d413f50ead64 100644 --- a/warehouse/migrations/versions/d142f435bb39_add_archived_column_to_files.py +++ b/warehouse/migrations/versions/d142f435bb39_add_archived_column_to_files.py @@ -27,7 +27,7 @@ def upgrade(): conn = op.get_bind() - conn.execute("SET statement_timeout = 120000") + conn.execute(sa.text("SET statement_timeout = 120000")) op.add_column( "release_files", sa.Column( diff --git a/warehouse/migrations/versions/d1771b942eb6_remove_user_has_oidc_beta_access_column.py b/warehouse/migrations/versions/d1771b942eb6_remove_user_has_oidc_beta_access_column.py index 2bf14bdaaaa0..ce2d68a67e2b 100644 --- a/warehouse/migrations/versions/d1771b942eb6_remove_user_has_oidc_beta_access_column.py +++ b/warehouse/migrations/versions/d1771b942eb6_remove_user_has_oidc_beta_access_column.py @@ -27,8 +27,8 @@ def upgrade(): conn = op.get_bind() - conn.execute("SET statement_timeout = 120000") - conn.execute("SET lock_timeout = 120000") + conn.execute(sa.text("SET statement_timeout = 120000")) + conn.execute(sa.text("SET lock_timeout = 120000")) op.drop_column("users", "has_oidc_beta_access") diff --git a/warehouse/organizations/services.py b/warehouse/organizations/services.py index 7701a98fdc50..76994a8d1c58 100644 --- a/warehouse/organizations/services.py +++ b/warehouse/organizations/services.py @@ -12,7 +12,7 @@ import datetime -from sqlalchemy import func +from sqlalchemy import delete, func, select from sqlalchemy.exc import NoResultFound from zope.interface import implementer @@ -47,7 +47,7 @@ def get_organization(self, organization_id): Return the organization object that represents the given organizationid, or None if there is no organization for that ID. """ - return self.db.query(Organization).get(organization_id) + return self.db.get(Organization, organization_id) def get_organization_by_name(self, name): """ @@ -80,7 +80,7 @@ def get_organizations(self): """ Return a list of all organization objects, or None if there are none. """ - return self.db.query(Organization).order_by(Organization.name).all() + return self.db.scalars(select(Organization).order_by(Organization.name)).all() def get_organizations_needing_approval(self): """ @@ -155,7 +155,7 @@ def get_organization_role(self, organization_role_id): Return the org role object that represents the given org role id, or None if there is no organization role for that ID. """ - return self.db.query(OrganizationRole).get(organization_role_id) + return self.db.get(OrganizationRole, organization_role_id) def get_organization_role_by_user(self, organization_id, user_id): """ @@ -213,7 +213,7 @@ def get_organization_invite(self, organization_invite_id): Return the org invite object that represents the given org invite id, or None if there is no organization invite for that ID. """ - return self.db.query(OrganizationInvitation).get(organization_invite_id) + return self.db.get(OrganizationInvitation, organization_invite_id) def get_organization_invite_by_user(self, organization_id, user_id): """ @@ -477,13 +477,17 @@ def get_teams_by_organization(self, organization_id): Return a list of all team objects for the specified organization, or None if there are none. """ - return self.db.query(Team).filter(Team.organization_id == organization_id).all() + return ( + self.db.execute(select(Team).where(Team.organization_id == organization_id)) + .scalars() + .all() + ) def get_team(self, team_id): """ Return a team object for the specified identifier, """ - return self.db.query(Team).get(team_id) + return self.db.get(Team, team_id) def find_teamid(self, organization_id, team_name): """ @@ -545,11 +549,11 @@ def delete_team(self, team_id): """ team = self.get_team(team_id) # Delete team members - self.db.query(TeamRole).filter_by(team=team).delete() + self.db.execute(delete(TeamRole).filter_by(team=team)) # Delete projects - self.db.query(TeamProjectRole).filter_by(team=team).delete() + self.db.execute(delete(TeamProjectRole).filter_by(team=team)) # Delete team - self.db.delete(team) + self.db.execute(delete(Team).where(Team.id == team_id)) def delete_teams_by_organization(self, organization_id): """ @@ -563,7 +567,7 @@ def get_team_role(self, team_role_id): """ Return the team role object that represents the given team role id, """ - return self.db.query(TeamRole).get(team_role_id) + return self.db.get(TeamRole, team_role_id) def get_team_roles(self, team_id): """ @@ -600,7 +604,7 @@ def get_team_project_role(self, team_project_role_id): Return the team project role object that represents the given team project role id, """ - return self.db.query(TeamProjectRole).get(team_project_role_id) + return self.db.get(TeamProjectRole, team_project_role_id) def add_team_project_role(self, team_id, project_id, role_name): """ diff --git a/warehouse/packaging/tasks.py b/warehouse/packaging/tasks.py index c66d63203303..6951e367b9b9 100644 --- a/warehouse/packaging/tasks.py +++ b/warehouse/packaging/tasks.py @@ -31,7 +31,7 @@ @tasks.task(ignore_result=True, acks_late=True) def sync_file_to_archive(request, file_id): - file = request.db.query(File).get(file_id) + file = request.db.get(File, file_id) if not file.archived: primary_storage = request.find_service(IFileStorage, name="primary") archive_storage = request.find_service(IFileStorage, name="archive") diff --git a/warehouse/search/tasks.py b/warehouse/search/tasks.py index 87394b2eef73..518f4179d1c3 100644 --- a/warehouse/search/tasks.py +++ b/warehouse/search/tasks.py @@ -21,7 +21,7 @@ from elasticsearch.helpers import parallel_bulk from elasticsearch_dsl import serializer -from sqlalchemy import func +from sqlalchemy import func, text from sqlalchemy.orm import aliased from warehouse import tasks @@ -183,7 +183,7 @@ def reindex(self, request): # From this point on, if any error occurs, we want to be able to delete our # in progress index. try: - request.db.execute("SET statement_timeout = '600s'") + request.db.execute(text("SET statement_timeout = '600s'")) for _ in parallel_bulk( client, _project_docs(request.db), index=new_index_name diff --git a/warehouse/subscriptions/services.py b/warehouse/subscriptions/services.py index 5253b1e9619c..310a9624f096 100644 --- a/warehouse/subscriptions/services.py +++ b/warehouse/subscriptions/services.py @@ -379,7 +379,7 @@ def get_subscription(self, id): """ Get a subscription by id """ - return self.db.query(StripeSubscription).get(id) + return self.db.get(StripeSubscription, id) def find_subscriptionid(self, subscription_id): """ @@ -495,7 +495,7 @@ def get_stripe_customer(self, stripe_customer_id): """ Get a stripe customer by id """ - return self.db.query(StripeCustomer).get(stripe_customer_id) + return self.db.get(StripeCustomer, stripe_customer_id) def find_stripe_customer_id(self, customer_id): """ @@ -560,7 +560,7 @@ def get_subscription_product(self, subscription_product_id): """ Get a product by subscription product id """ - return self.db.query(StripeSubscriptionProduct).get(subscription_product_id) + return self.db.get(StripeSubscriptionProduct, subscription_product_id) def get_subscription_products(self): """ @@ -661,7 +661,7 @@ def get_subscription_price(self, subscription_price_id): """ Get a subscription price by id """ - return self.db.query(StripeSubscriptionPrice).get(subscription_price_id) + return self.db.get(StripeSubscriptionPrice, subscription_price_id) def get_subscription_prices(self): """