diff --git a/tests/unit/oidc/test_services.py b/tests/unit/oidc/test_services.py index 8b6c88079e16..bbc80b07e34b 100644 --- a/tests/unit/oidc/test_services.py +++ b/tests/unit/oidc/test_services.py @@ -59,10 +59,11 @@ def test_interface_matches(self): ) def test_verify_jwt_signature(self, monkeypatch): + issuer_url = "https://example.com" service = services.OIDCPublisherService( session=pretend.stub(), publisher=pretend.stub(), - issuer_url=pretend.stub(), + issuer_url=issuer_url, audience="fakeaudience", cache_url=pretend.stub(), metrics=pretend.stub(), @@ -72,12 +73,10 @@ def test_verify_jwt_signature(self, monkeypatch): decoded = pretend.stub() jwt = pretend.stub(decode=pretend.call_recorder(lambda t, **kwargs: decoded)) key = pretend.stub(key="fake-key") - monkeypatch.setattr( - service, "_get_key_for_token", pretend.call_recorder(lambda t: key) - ) + monkeypatch.setattr(service, "_get_key_for_token", lambda t, i=issuer_url: key) monkeypatch.setattr(services, "jwt", jwt) - assert service.verify_jwt_signature(token) == decoded + assert service.verify_jwt_signature(token, issuer_url) == decoded assert jwt.decode.calls == [ pretend.call( token, @@ -93,7 +92,7 @@ def test_verify_jwt_signature(self, monkeypatch): verify_nbf=True, strict_aud=True, ), - issuer=service.issuer_url, + issuer=issuer_url, audience="fakeaudience", leeway=30, ) @@ -119,7 +118,7 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch pretend.call_recorder(lambda s: None), ) - assert service.verify_jwt_signature(token) is None + assert service.verify_jwt_signature(token, "https://none") is None assert service.metrics.increment.calls == [ pretend.call( "warehouse.oidc.verify_jwt_signature.malformed_jwt", @@ -130,10 +129,11 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch @pytest.mark.parametrize("exc", [PyJWTError, TypeError("foo")]) def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc): + issuer_url = "https://none" service = services.OIDCPublisherService( session=pretend.stub(), publisher="fakepublisher", - issuer_url="https://none", + issuer_url=issuer_url, audience="fakeaudience", cache_url=pretend.stub(), metrics=metrics, @@ -143,7 +143,9 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc): jwt = pretend.stub(decode=pretend.raiser(exc), PyJWTError=PyJWTError) key = pretend.stub(key="fake-key") monkeypatch.setattr( - service, "_get_key_for_token", pretend.call_recorder(lambda t: key) + service, + "_get_key_for_token", + pretend.call_recorder(lambda t, i=issuer_url: key), ) monkeypatch.setattr(services, "jwt", jwt) monkeypatch.setattr( @@ -152,7 +154,7 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc): pretend.call_recorder(lambda s: None), ) - assert service.verify_jwt_signature(token) is None + assert service.verify_jwt_signature(token, issuer_url) is None assert service.metrics.increment.calls == [ pretend.call( "warehouse.oidc.verify_jwt_signature.invalid_signature", @@ -168,16 +170,17 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc): assert services.sentry_sdk.capture_message.calls == [] def test_find_publisher(self, metrics, monkeypatch): + issuer_url = "https://none" service = services.OIDCPublisherService( session=pretend.stub(), publisher="fakepublisher", - issuer_url="https://none", + issuer_url=issuer_url, audience="fakeaudience", cache_url=pretend.stub(), metrics=metrics, ) - token = SignedClaims({}) + token = SignedClaims({"iss": issuer_url}) publisher = pretend.stub(verify_claims=pretend.call_recorder(lambda c, s: True)) find_publisher_by_issuer = pretend.call_recorder(lambda *a, **kw: publisher) @@ -198,10 +201,11 @@ def test_find_publisher(self, metrics, monkeypatch): ] def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch): + issuer_url = "https://none" service = services.OIDCPublisherService( session=pretend.stub(), publisher="fakepublisher", - issuer_url="https://none", + issuer_url=issuer_url, audience="fakeaudience", cache_url=pretend.stub(), metrics=metrics, @@ -212,7 +216,7 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch): services, "find_publisher_by_issuer", find_publisher_by_issuer ) - claims = pretend.stub() + claims = SignedClaims({"iss": issuer_url}) with pytest.raises(errors.InvalidPublisherError): service.find_publisher(claims) assert service.metrics.increment.calls == [ @@ -227,10 +231,11 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch): ] def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch): + issuer_url = "https://none" service = services.OIDCPublisherService( session=pretend.stub(), publisher="fakepublisher", - issuer_url="https://none", + issuer_url=issuer_url, audience="fakeaudience", cache_url=pretend.stub(), metrics=metrics, @@ -246,7 +251,7 @@ def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch): services, "find_publisher_by_issuer", find_publisher_by_issuer ) - claims = SignedClaims({}) + claims = SignedClaims({"iss": issuer_url}) with pytest.raises(errors.InvalidPublisherError): service.find_publisher(claims) assert service.metrics.increment.calls == [ @@ -311,7 +316,7 @@ def test_get_keyset_not_cached(self, monkeypatch, mockredis): monkeypatch.setattr(services.redis, "StrictRedis", mockredis) - keys, timeout = service._get_keyset() + keys, timeout = service._get_keyset("https://example.com") assert not keys assert timeout is False @@ -329,17 +334,18 @@ def test_get_keyset_cached(self, monkeypatch, mockredis): monkeypatch.setattr(services.redis, "StrictRedis", mockredis) keyset = {"fake-key-id": {"foo": "bar"}} - service._store_keyset(keyset) - keys, timeout = service._get_keyset() + service._store_keyset("https://example.com", keyset) + keys, timeout = service._get_keyset("https://example.com") assert keys == keyset assert timeout is True def test_refresh_keyset_timeout(self, metrics, monkeypatch, mockredis): + issuer_url = "https://example.com" service = services.OIDCPublisherService( session=pretend.stub(), publisher="example", - issuer_url="https://example.com", + issuer_url=issuer_url, audience="fakeaudience", cache_url="rediss://fake.example.com", metrics=metrics, @@ -348,9 +354,9 @@ def test_refresh_keyset_timeout(self, metrics, monkeypatch, mockredis): monkeypatch.setattr(services.redis, "StrictRedis", mockredis) keyset = {"fake-key-id": {"foo": "bar"}} - service._store_keyset(keyset) + service._store_keyset(issuer_url, keyset) - keys = service._refresh_keyset() + keys = service._refresh_keyset(issuer_url) assert keys == keyset assert metrics.increment.calls == [ pretend.call( @@ -380,7 +386,7 @@ def test_refresh_keyset_oidc_config_fails(self, metrics, monkeypatch, mockredis) monkeypatch.setattr(services, "requests", requests) monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) - keys = service._refresh_keyset() + keys = service._refresh_keyset("https://example.com") assert keys == {} assert metrics.increment.calls == [] @@ -421,7 +427,7 @@ def test_refresh_keyset_oidc_config_no_jwks_uri( monkeypatch.setattr(services, "requests", requests) monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) - keys = service._refresh_keyset() + keys = service._refresh_keyset("https://example.com") assert keys == {} assert metrics.increment.calls == [] @@ -472,7 +478,7 @@ def get(url, timeout=5): monkeypatch.setattr(services, "requests", requests) monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) - keys = service._refresh_keyset() + keys = service._refresh_keyset("https://example.com") assert keys == {} assert metrics.increment.calls == [] @@ -524,7 +530,7 @@ def get(url, timeout=5): monkeypatch.setattr(services, "requests", requests) monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) - keys = service._refresh_keyset() + keys = service._refresh_keyset("https://example.com") assert keys == {} assert metrics.increment.calls == [] @@ -573,7 +579,7 @@ def get(url, timeout=5): monkeypatch.setattr(services, "requests", requests) monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) - keys = service._refresh_keyset() + keys = service._refresh_keyset("https://example.com") assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}} assert metrics.increment.calls == [] @@ -586,7 +592,7 @@ def get(url, timeout=5): assert sentry_sdk.capture_message.calls == [] # Ensure that we also cached the updated keyset as part of refreshing. - keys, timeout = service._get_keyset() + keys, timeout = service._get_keyset("https://example.com") assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}} assert timeout is True @@ -612,9 +618,11 @@ def test_get_key_cached(self, metrics, monkeypatch): "x5t": "dummy", } } - monkeypatch.setattr(service, "_get_keyset", lambda: (keyset, True)) + monkeypatch.setattr( + service, "_get_keyset", lambda issuer_url=None: (keyset, True) + ) - key = service._get_key("fake-key-id") + key = service._get_key("fake-key-id", "https://example.com") assert isinstance(key, PyJWK) assert key.key_id == "fake-key-id" @@ -642,10 +650,10 @@ def test_get_key_uncached(self, metrics, monkeypatch): "x5t": "dummy", } } - monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False)) - monkeypatch.setattr(service, "_refresh_keyset", lambda: keyset) + monkeypatch.setattr(service, "_get_keyset", lambda issuer_url=None: ({}, False)) + monkeypatch.setattr(service, "_refresh_keyset", lambda issuer_url=None: keyset) - key = service._get_key("fake-key-id") + key = service._get_key("fake-key-id", "https://example.com") assert isinstance(key, PyJWK) assert key.key_id == "fake-key-id" @@ -661,11 +669,14 @@ def test_get_key_refresh_fails(self, metrics, monkeypatch): metrics=metrics, ) - monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False)) - monkeypatch.setattr(service, "_refresh_keyset", lambda: {}) + monkeypatch.setattr(service, "_get_keyset", lambda issuer_url=None: ({}, False)) + monkeypatch.setattr(service, "_refresh_keyset", lambda issuer_url=None: {}) - key = service._get_key("fake-key-id") - assert key is None + with pytest.raises( + jwt.PyJWTError, + match=r"Key ID 'fake-key-id' not found for issuer 'https://example.com'", + ): + service._get_key("fake-key-id", "https://example.com") assert metrics.increment.calls == [ pretend.call( @@ -690,7 +701,9 @@ def test_get_key_for_token(self, monkeypatch): cache_url="rediss://fake.example.com", metrics=pretend.stub(), ) - monkeypatch.setattr(service, "_get_key", pretend.call_recorder(lambda kid: key)) + monkeypatch.setattr( + service, "_get_key", pretend.call_recorder(lambda kid, i: key) + ) monkeypatch.setattr( services.jwt, @@ -698,8 +711,10 @@ def test_get_key_for_token(self, monkeypatch): pretend.call_recorder(lambda token: {"kid": "fake-key-id"}), ) - assert service._get_key_for_token(token) == key - assert service._get_key.calls == [pretend.call("fake-key-id")] + assert service._get_key_for_token(token, "https://example.com") == key + assert service._get_key.calls == [ + pretend.call("fake-key-id", "https://example.com") + ] assert services.jwt.get_unverified_header.calls == [pretend.call(token)] def test_reify_publisher(self, monkeypatch): @@ -802,7 +817,9 @@ def test_verify_jwt_signature_malformed_jwt(self): metrics=pretend.stub(), ) - assert service.verify_jwt_signature("malformed-jwt") is None + assert ( + service.verify_jwt_signature("malformed-jwt", "https://example.com") is None + ) def test_verify_jwt_signature_missing_aud(self): # { @@ -830,7 +847,7 @@ def test_verify_jwt_signature_missing_aud(self): metrics=pretend.stub(), ) - assert service.verify_jwt_signature(jwt) is None + assert service.verify_jwt_signature(jwt, "https://example.com") is None def test_verify_jwt_signature_wrong_aud(self): # { @@ -860,7 +877,7 @@ def test_verify_jwt_signature_wrong_aud(self): metrics=pretend.stub(), ) - assert service.verify_jwt_signature(jwt) is None + assert service.verify_jwt_signature(jwt, "https://example.com") is None def test_verify_jwt_signature_strict_aud(self): # { @@ -885,7 +902,7 @@ def test_verify_jwt_signature_strict_aud(self): metrics=pretend.stub(), ) - assert service.verify_jwt_signature(jwt) is None + assert service.verify_jwt_signature(jwt, "https://example.com") is None def test_find_publisher(self, monkeypatch): claims = SignedClaims( diff --git a/tests/unit/oidc/test_views.py b/tests/unit/oidc/test_views.py index 9e959820356a..20ecea5ed367 100644 --- a/tests/unit/oidc/test_views.py +++ b/tests/unit/oidc/test_views.py @@ -247,7 +247,7 @@ def body(self): @pytest.mark.parametrize( - ("token", "service_name"), + ("token", "service_name", "unverified_issuer"), [ ( ( @@ -256,6 +256,7 @@ def body(self): "ZWPGfHu-0EEQMlVyO5UVdQ" ), "github", + "https://token.actions.githubusercontent.com", ), ( ( @@ -263,6 +264,7 @@ def body(self): "nRzLmdvb2dsZS5jb20ifQ.2RJ6Y52Rap0LEj61yBGDokUg8r92SYQq6l3cflSWBVI" ), "google", + "https://accounts.google.com", ), ( ( @@ -271,11 +273,12 @@ def body(self): "sHtF7obbcnu4w_ZSU" ), "gitlab", + "https://gitlab.com", ), ], ) def test_mint_token_from_oidc_creates_expected_service( - monkeypatch, token, service_name + monkeypatch, token, service_name, unverified_issuer ): mint_token = pretend.call_recorder(lambda *a: pretend.stub()) monkeypatch.setattr(views, "mint_token", mint_token) @@ -293,12 +296,15 @@ def test_mint_token_from_oidc_creates_expected_service( assert request.find_service.calls == [ pretend.call(IOIDCPublisherService, name=service_name) ] - assert mint_token.calls == [pretend.call(oidc_service, token, request)] + assert mint_token.calls == [ + pretend.call(oidc_service, token, unverified_issuer, request) + ] def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(): + claims = {"iss": "https://none"} oidc_service = pretend.stub( - verify_jwt_signature=pretend.call_recorder(lambda token: None), + verify_jwt_signature=pretend.call_recorder(lambda token, issuer_url=None: None), ) request = pretend.stub( response=pretend.stub(status=None), @@ -306,7 +312,9 @@ def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(): flags=pretend.stub(disallow_oidc=lambda *a: False), ) - response = views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, request) + response = views.mint_token( + oidc_service, DUMMY_GITHUB_OIDC_JWT, claims["iss"], request + ) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -319,15 +327,17 @@ def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(): } assert oidc_service.verify_jwt_signature.calls == [ - pretend.call(DUMMY_GITHUB_OIDC_JWT) + pretend.call(DUMMY_GITHUB_OIDC_JWT, claims["iss"]) ] def test_mint_token_trusted_publisher_lookup_fails(): - claims = pretend.stub() + claims = {"iss": "https://none"} message = "some message" oidc_service = pretend.stub( - verify_jwt_signature=pretend.call_recorder(lambda token: claims), + verify_jwt_signature=pretend.call_recorder( + lambda token, issuer_url=None: claims + ), find_publisher=pretend.call_recorder( pretend.raiser(errors.InvalidPublisherError(message)) ), @@ -338,7 +348,9 @@ def test_mint_token_trusted_publisher_lookup_fails(): flags=pretend.stub(disallow_oidc=lambda *a: False), ) - response = views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, request) + response = views.mint_token( + oidc_service, DUMMY_GITHUB_OIDC_JWT, claims["iss"], request + ) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -353,7 +365,7 @@ def test_mint_token_trusted_publisher_lookup_fails(): } assert oidc_service.verify_jwt_signature.calls == [ - pretend.call(DUMMY_GITHUB_OIDC_JWT) + pretend.call(DUMMY_GITHUB_OIDC_JWT, claims["iss"]) ] assert oidc_service.find_publisher.calls == [ pretend.call(claims, pending=True), @@ -368,9 +380,11 @@ def find_publishers_mockup(_, pending: bool = False): else: raise errors.InvalidPublisherError("some message") - claims = pretend.stub() + claims = {"iss": "https://none"} oidc_service = pretend.stub( - verify_jwt_signature=pretend.call_recorder(lambda token: claims), + verify_jwt_signature=pretend.call_recorder( + lambda token, issuer_url=None: claims + ), find_publisher=find_publishers_mockup, ) request = pretend.stub( @@ -379,7 +393,9 @@ def find_publishers_mockup(_, pending: bool = False): flags=pretend.stub(disallow_oidc=lambda *a: False), ) - response = views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, request) + response = views.mint_token( + oidc_service, DUMMY_GITHUB_OIDC_JWT, claims["iss"], request + ) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -400,16 +416,20 @@ def test_mint_token_pending_publisher_project_already_exists(db_request): db_request.flags.disallow_oidc = lambda f=None: False - claims = pretend.stub() + claims = {"iss": "https://none"} oidc_service = pretend.stub( - verify_jwt_signature=pretend.call_recorder(lambda token: claims), + verify_jwt_signature=pretend.call_recorder( + lambda token, issuer_url=None: claims + ), find_publisher=pretend.call_recorder( lambda claims, pending=False: pending_publisher ), ) db_request.find_service = pretend.call_recorder(lambda *a, **kw: oidc_service) - resp = views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, db_request) + resp = views.mint_token( + oidc_service, DUMMY_GITHUB_OIDC_JWT, claims["iss"], db_request + ) assert db_request.response.status_code == 422 assert resp == { "message": "Token request failed", @@ -422,7 +442,7 @@ def test_mint_token_pending_publisher_project_already_exists(db_request): } assert oidc_service.verify_jwt_signature.calls == [ - pretend.call(DUMMY_GITHUB_OIDC_JWT) + pretend.call(DUMMY_GITHUB_OIDC_JWT, "https://none") ] assert oidc_service.find_publisher.calls == [pretend.call(claims, pending=True)] @@ -577,6 +597,10 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others( def test_mint_token_no_pending_publisher_ok( monkeypatch, db_request, claims_in_token, claims_input ): + # Ensure the `iss` claim is set to match the GitHub OIDC issuer, as that's + # what the GitHubPublisherFactory implies. + claims_in_token.update({"iss": "https://token.actions.githubusercontent.com"}) + time = pretend.stub(time=pretend.call_recorder(lambda: 0)) monkeypatch.setattr(views, "time", time) @@ -598,7 +622,9 @@ def _find_publisher(claims, pending=False): return publisher oidc_service = pretend.stub( - verify_jwt_signature=pretend.call_recorder(lambda token: claims_in_token), + verify_jwt_signature=pretend.call_recorder( + lambda token, issuer_url=None: claims_in_token + ), find_publisher=pretend.call_recorder(_find_publisher), ) @@ -618,7 +644,12 @@ def find_service(iface, **kw): monkeypatch.setattr(db_request, "find_service", find_service) monkeypatch.setattr(db_request, "domain", "fakedomain") - response = views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, db_request) + response = views.mint_token( + oidc_service, + DUMMY_GITHUB_OIDC_JWT, + "https://token.actions.githubusercontent.com", + db_request, + ) assert response == { "success": True, "token": "raw-macaroon", @@ -626,7 +657,9 @@ def find_service(iface, **kw): } assert oidc_service.verify_jwt_signature.calls == [ - pretend.call(DUMMY_GITHUB_OIDC_JWT) + pretend.call( + DUMMY_GITHUB_OIDC_JWT, "https://token.actions.githubusercontent.com" + ) ] assert oidc_service.find_publisher.calls == [ pretend.call(claims_in_token, pending=True), @@ -663,7 +696,12 @@ def find_service(iface, **kw): def test_mint_token_warn_constrain_environment(monkeypatch, db_request): - claims_in_token = {"ref": "someref", "sha": "somesha", "environment": "fakeenv"} + claims_in_token = { + "ref": "someref", + "sha": "somesha", + "environment": "fakeenv", + "iss": "https://token.actions.githubusercontent.com", + } claims_input = {"ref": "someref", "sha": "somesha"} time = pretend.stub(time=pretend.call_recorder(lambda: 0)) monkeypatch.setattr(views, "time", time) @@ -698,7 +736,9 @@ def _find_publisher(claims, pending=False): return publisher oidc_service = pretend.stub( - verify_jwt_signature=pretend.call_recorder(lambda token: claims_in_token), + verify_jwt_signature=pretend.call_recorder( + lambda token, issuer_url=None: claims_in_token + ), find_publisher=pretend.call_recorder(_find_publisher), ) @@ -718,7 +758,9 @@ def find_service(iface, **kw): monkeypatch.setattr(db_request, "find_service", find_service) monkeypatch.setattr(db_request, "domain", "fakedomain") - response = views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, db_request) + response = views.mint_token( + oidc_service, DUMMY_GITHUB_OIDC_JWT, claims_in_token["iss"], db_request + ) assert response == { "success": True, "token": "raw-macaroon", @@ -726,7 +768,9 @@ def find_service(iface, **kw): } assert oidc_service.verify_jwt_signature.calls == [ - pretend.call(DUMMY_GITHUB_OIDC_JWT) + pretend.call( + DUMMY_GITHUB_OIDC_JWT, "https://token.actions.githubusercontent.com" + ) ] assert oidc_service.find_publisher.calls == [ pretend.call(claims_in_token, pending=True), @@ -851,6 +895,7 @@ def test_mint_token_with_invalid_name_fails(monkeypatch, db_request): [ ( { + "iss": "https://token.actions.githubusercontent.com", "ref": "someref", "sha": "somesha", "workflow_ref": "org/repo/.github/workflows/parent.yml@someref", @@ -861,6 +906,7 @@ def test_mint_token_with_invalid_name_fails(monkeypatch, db_request): ), ( { + "iss": "https://token.actions.githubusercontent.com", "ref": "someref", "sha": "somesha", "workflow_ref": "org/repo/.github/workflows/workflow.yml@someref", @@ -871,6 +917,7 @@ def test_mint_token_with_invalid_name_fails(monkeypatch, db_request): ), ( { + "iss": "https://gitlab.com", "ref": "someref", "sha": "somesha", }, @@ -907,7 +954,9 @@ def _find_publisher(claims, pending=False): return publisher oidc_service = pretend.stub( - verify_jwt_signature=pretend.call_recorder(lambda token: claims_in_token), + verify_jwt_signature=pretend.call_recorder( + lambda token, issuer_url=None: claims_in_token + ), find_publisher=pretend.call_recorder(_find_publisher), ) @@ -929,7 +978,7 @@ def find_service(iface, **kw): monkeypatch.setattr(db_request, "find_service", find_service) monkeypatch.setattr(db_request, "domain", "fakedomain") - views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, db_request) + views.mint_token(oidc_service, DUMMY_GITHUB_OIDC_JWT, claims_in_token, db_request) if is_reusable: assert metrics.increment.calls == [ diff --git a/warehouse/oidc/interfaces.py b/warehouse/oidc/interfaces.py index 260e32ab7b31..f92972ad167a 100644 --- a/warehouse/oidc/interfaces.py +++ b/warehouse/oidc/interfaces.py @@ -16,7 +16,9 @@ class IOIDCPublisherService(Interface): - def verify_jwt_signature(unverified_token: str) -> SignedClaims | None: + def verify_jwt_signature( + unverified_token: str, issuer_url: str + ) -> SignedClaims | None: """ Verify the given JWT's signature, returning its signed claims if valid. If the signature is invalid, `None` is returned. diff --git a/warehouse/oidc/services.py b/warehouse/oidc/services.py index f1b39a12c974..861d8fd99efd 100644 --- a/warehouse/oidc/services.py +++ b/warehouse/oidc/services.py @@ -46,9 +46,12 @@ def __init__( ) self.db = session + self.publisher = publisher self.issuer_url = issuer_url - def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None: + def verify_jwt_signature( + self, unverified_token: str, issuer_url: str + ) -> SignedClaims | None: try: return SignedClaims( jwt.decode( @@ -121,36 +124,37 @@ def __init__( self.cache_url = cache_url self.metrics = metrics - self._publisher_jwk_key = f"/warehouse/oidc/jwks/{self.issuer_url}" - self._publisher_timeout_key = f"{self._publisher_jwk_key}/timeout" - - def _store_keyset(self, keys: dict) -> None: + def _store_keyset(self, issuer_url: str, keys: dict) -> None: """ - Store the given keyset for the given publisher, setting the timeout key + Store the given keyset for the given **issuer**, setting the timeout key in the process. """ + _publisher_jwk_key = f"/warehouse/oidc/jwks/{issuer_url}" + _publisher_timeout_key = f"{_publisher_jwk_key}/timeout" with redis.StrictRedis.from_url(self.cache_url) as r: - r.set(self._publisher_jwk_key, json.dumps(keys)) - r.setex(self._publisher_timeout_key, 60, "placeholder") + r.set(_publisher_jwk_key, json.dumps(keys)) + r.setex(_publisher_timeout_key, 60, "placeholder") - def _get_keyset(self) -> tuple[dict[str, dict], bool]: + def _get_keyset(self, issuer_url: str) -> tuple[dict[str, dict], bool]: """ - Return the cached keyset for the given publisher, or an empty + Return the cached keyset for the given issuer, or an empty keyset if no keys are currently cached. """ + _publisher_jwk_key = f"/warehouse/oidc/jwks/{issuer_url}" + _publisher_timeout_key = f"{_publisher_jwk_key}/timeout" with redis.StrictRedis.from_url(self.cache_url) as r: - keys = r.get(self._publisher_jwk_key) - timeout = bool(r.exists(self._publisher_timeout_key)) + keys = r.get(_publisher_jwk_key) + timeout = bool(r.exists(_publisher_timeout_key)) if keys is not None: - return (json.loads(keys), timeout) + return json.loads(keys), timeout else: - return ({}, timeout) + return {}, timeout - def _refresh_keyset(self) -> dict[str, dict]: + def _refresh_keyset(self, issuer_url: str) -> dict[str, dict]: """ - Attempt to refresh the keyset from the OIDC publisher, assuming no + Attempt to refresh the keyset from the OIDC issuer, assuming no timeout is in effect. Returns the refreshed keyset, or the cached keyset if a timeout is @@ -160,15 +164,15 @@ def _refresh_keyset(self) -> dict[str, dict]: """ # Fast path: we're in a cooldown from a previous refresh. - keys, timeout = self._get_keyset() + keys, timeout = self._get_keyset(issuer_url=issuer_url) if timeout: self.metrics.increment( "warehouse.oidc.refresh_keyset.timeout", - tags=[f"publisher:{self.publisher}", f"issuer_url:{self.issuer_url}"], + tags=[f"publisher:{self.publisher}", f"issuer_url:{issuer_url}"], ) return keys - oidc_url = f"{self.issuer_url}/.well-known/openid-configuration" + oidc_url = f"{issuer_url}/.well-known/openid-configuration" resp = requests.get(oidc_url, timeout=5) @@ -219,19 +223,19 @@ def _refresh_keyset(self) -> dict[str, dict]: return keys keys = {key["kid"]: key for key in new_keys} - self._store_keyset(keys) + self._store_keyset(issuer_url, keys) return keys - def _get_key(self, key_id: str) -> jwt.PyJWK | None: + def _get_key(self, key_id: str, issuer_url: str) -> jwt.PyJWK: """ Return a JWK for the given key ID, or None if the key can't be found in this publisher's keyset. """ - keyset, _ = self._get_keyset() + keyset, _ = self._get_keyset(issuer_url) if key_id not in keyset: - keyset = self._refresh_keyset() + keyset = self._refresh_keyset(issuer_url) if key_id not in keyset: self.metrics.increment( "warehouse.oidc.get_key.error", @@ -241,10 +245,12 @@ def _get_key(self, key_id: str) -> jwt.PyJWK | None: f"issuer_url:{self.issuer_url}", ], ) - return None + raise jwt.PyJWTError( + f"Key ID {key_id!r} not found for issuer {issuer_url!r}" + ) return jwt.PyJWK(keyset[key_id]) - def _get_key_for_token(self, token): + def _get_key_for_token(self, token, issuer_url: str) -> jwt.PyJWK: """ Return a JWK suitable for verifying the given JWT. @@ -252,7 +258,7 @@ def _get_key_for_token(self, token): prior to any verification. """ unverified_header = jwt.get_unverified_header(token) - return self._get_key(unverified_header["kid"]) + return self._get_key(unverified_header["kid"], issuer_url) def jwt_identifier_exists(self, jti: str) -> bool: """ @@ -276,15 +282,20 @@ def store_jwt_identifier(self, jti: str, expiration: int) -> None: nx=True, ) - def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None: + def verify_jwt_signature( + self, unverified_token: str, issuer_url: str + ) -> SignedClaims | None: + """ + Verify the signature of the given JWT, returning the signed claims. + """ try: - key = self._get_key_for_token(unverified_token) + key = self._get_key_for_token(unverified_token, issuer_url) except jwt.PyJWTError: # The user might feed us an entirely nonsense JWT, e.g. one # with missing components. self.metrics.increment( "warehouse.oidc.verify_jwt_signature.malformed_jwt", - tags=[f"publisher:{self.publisher}", f"issuer_url:{self.issuer_url}"], + tags=[f"publisher:{self.publisher}", f"issuer_url:{issuer_url}"], ) return None @@ -311,7 +322,7 @@ def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None: # want to be the ONLY audience listed. strict_aud=True, ), - issuer=self.issuer_url, + issuer=issuer_url, audience=self.audience, leeway=30, ) @@ -319,7 +330,7 @@ def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None: except Exception as e: self.metrics.increment( "warehouse.oidc.verify_jwt_signature.invalid_signature", - tags=[f"publisher:{self.publisher}", f"issuer_url:{self.issuer_url}"], + tags=[f"publisher:{self.publisher}", f"issuer_url:{issuer_url}"], ) if not isinstance(e, jwt.PyJWTError): with sentry_sdk.new_scope() as scope: @@ -334,7 +345,10 @@ def find_publisher( self, signed_claims: SignedClaims, *, pending: bool = False ) -> OIDCPublisher | PendingOIDCPublisher: """Returns a publisher for the given claims, or raises an error.""" - metrics_tags = [f"publisher:{self.publisher}", f"issuer_url:{self.issuer_url}"] + metrics_tags = [ + f"publisher:{self.publisher}", + f"issuer_url:{signed_claims['iss']}", + ] self.metrics.increment( "warehouse.oidc.find_publisher.attempt", tags=metrics_tags, diff --git a/warehouse/oidc/views.py b/warehouse/oidc/views.py index 2f157e69085c..d6897db5a03a 100644 --- a/warehouse/oidc/views.py +++ b/warehouse/oidc/views.py @@ -162,13 +162,16 @@ def mint_token_from_oidc(request: Request): IOIDCPublisherService, name=service_name ) - return mint_token(oidc_service, unverified_jwt, request) + return mint_token(oidc_service, unverified_jwt, unverified_issuer, request) def mint_token( - oidc_service: OIDCPublisherService, unverified_jwt: str, request: Request + oidc_service: OIDCPublisherService, + unverified_jwt: str, + unverified_issuer: str, + request: Request, ) -> JsonResponse: - claims = oidc_service.verify_jwt_signature(unverified_jwt) + claims = oidc_service.verify_jwt_signature(unverified_jwt, unverified_issuer) if not claims: return _invalid( errors=[