diff --git a/tests/unit/cli/test_hashing.py b/tests/unit/cli/test_hashing.py index 3b68aeaa5d2c..7316e5bf8a67 100644 --- a/tests/unit/cli/test_hashing.py +++ b/tests/unit/cli/test_hashing.py @@ -38,7 +38,8 @@ def test_no_records_to_backfill(self, cli, db_request, monkeypatch): assert db_request.db.query(User.Event).count() == 0 - result = cli.invoke(hashing.backfill_ipaddrs, obj=config) + args = ["--event-type", "user"] + result = cli.invoke(hashing.backfill_ipaddrs, args, obj=config) assert result.exit_code == 0 assert result.output.strip() == "No rows to backfill. Done!" @@ -67,7 +68,8 @@ def test_backfill_with_no_ipaddr_obj(self, cli, db_session, monkeypatch): assert db_session.query(User.Event).count() == 3 assert db_session.query(IpAddress).count() == 0 - result = cli.invoke(hashing.backfill_ipaddrs, obj=config) + args = ["--event-type", "user"] + result = cli.invoke(hashing.backfill_ipaddrs, args, obj=config) assert result.exit_code == 0 assert db_session.query(IpAddress).count() == 3 @@ -96,6 +98,8 @@ def tests_backfills_records(self, cli, db_request, remote_addr, monkeypatch): assert db_request.db.query(User.Event).count() == 3 args = [ + "--event-type", + "user", "--batch-size", "2", ] @@ -141,6 +145,8 @@ def test_continue_until_done(self, cli, db_request, remote_addr, monkeypatch): ) args = [ + "--event-type", + "user", "--batch-size", "1", "--sleep-time", diff --git a/warehouse/cli/hashing.py b/warehouse/cli/hashing.py index 25c5b029105a..aad57739e8a7 100644 --- a/warehouse/cli/hashing.py +++ b/warehouse/cli/hashing.py @@ -28,6 +28,13 @@ def hashing(): @hashing.command() +@click.option( + "-e", + "--event-type", + type=click.Choice(["user", "project", "file", "organization", "team"]), + required=True, + help="Type of event to backfill", +) @click.option( "-b", "--batch-size", @@ -51,6 +58,7 @@ def hashing(): @click.pass_obj def backfill_ipaddrs( config, + event_type: str, batch_size: int, sleep_time: int, continue_until_done: bool, @@ -68,12 +76,15 @@ def backfill_ipaddrs( salt = config.registry.settings["warehouse.ip_salt"] - _backfill_ips(session, salt, batch_size, sleep_time, continue_until_done) + _backfill_ips( + session, salt, event_type, batch_size, sleep_time, continue_until_done + ) def _backfill_ips( session, salt: str, + event_type: str, batch_size: int, sleep_time: int, continue_until_done: bool, @@ -82,18 +93,26 @@ def _backfill_ips( Create missing IPAddress objects for events that don't have them. Broken out from the CLI command so that it can be called recursively. - - TODO: Currently operates on only User events, but should be expanded to - include Project events and others. """ from warehouse.accounts.models import User from warehouse.ip_addresses.models import IpAddress + from warehouse.organizations.models import Organization, Team + from warehouse.packaging.models import File, Project + + has_events = { + "user": User, + "organization": Organization, + "team": Team, + "project": Project, + "file": File, + } + model = has_events[event_type] # Get rows a batch at a time, only if the row doesn't have an `ip_address_id no_ip_obj_rows = session.scalars( - select(User.Event) - .where(User.Event.ip_address_id.is_(None)) # type: ignore[attr-defined] - .order_by(User.Event.time) # type: ignore[attr-defined] + select(model.Event) # type: ignore[attr-defined] + .where(model.Event.ip_address_id.is_(None)) # type: ignore[attr-defined] + .order_by(model.Event.time) # type: ignore[attr-defined] .limit(batch_size) ).all() @@ -137,6 +156,7 @@ def _backfill_ips( _backfill_ips( session, salt, + event_type, batch_size, sleep_time, continue_until_done,