Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 61 additions & 44 deletions tests/unit/oidc/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def test_interface_matches(self):
)

def test_verify_jwt_signature(self, monkeypatch):
issuer_url = "https://example.com"
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher=pretend.stub(),
issuer_url=pretend.stub(),
issuer_url=issuer_url,
audience="fakeaudience",
cache_url=pretend.stub(),
metrics=pretend.stub(),
Expand All @@ -72,12 +73,10 @@ def test_verify_jwt_signature(self, monkeypatch):
decoded = pretend.stub()
jwt = pretend.stub(decode=pretend.call_recorder(lambda t, **kwargs: decoded))
key = pretend.stub(key="fake-key")
monkeypatch.setattr(
service, "_get_key_for_token", pretend.call_recorder(lambda t: key)
)
monkeypatch.setattr(service, "_get_key_for_token", lambda t, i=issuer_url: key)
monkeypatch.setattr(services, "jwt", jwt)

assert service.verify_jwt_signature(token) == decoded
assert service.verify_jwt_signature(token, issuer_url) == decoded
assert jwt.decode.calls == [
pretend.call(
token,
Expand All @@ -93,7 +92,7 @@ def test_verify_jwt_signature(self, monkeypatch):
verify_nbf=True,
strict_aud=True,
),
issuer=service.issuer_url,
issuer=issuer_url,
audience="fakeaudience",
leeway=30,
)
Expand All @@ -119,7 +118,7 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch
pretend.call_recorder(lambda s: None),
)

assert service.verify_jwt_signature(token) is None
assert service.verify_jwt_signature(token, "https://none") is None
assert service.metrics.increment.calls == [
pretend.call(
"warehouse.oidc.verify_jwt_signature.malformed_jwt",
Expand All @@ -130,10 +129,11 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch

@pytest.mark.parametrize("exc", [PyJWTError, TypeError("foo")])
def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
issuer_url = "https://none"
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="fakepublisher",
issuer_url="https://none",
issuer_url=issuer_url,
audience="fakeaudience",
cache_url=pretend.stub(),
metrics=metrics,
Expand All @@ -143,7 +143,9 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
jwt = pretend.stub(decode=pretend.raiser(exc), PyJWTError=PyJWTError)
key = pretend.stub(key="fake-key")
monkeypatch.setattr(
service, "_get_key_for_token", pretend.call_recorder(lambda t: key)
service,
"_get_key_for_token",
pretend.call_recorder(lambda t, i=issuer_url: key),
)
monkeypatch.setattr(services, "jwt", jwt)
monkeypatch.setattr(
Expand All @@ -152,7 +154,7 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
pretend.call_recorder(lambda s: None),
)

assert service.verify_jwt_signature(token) is None
assert service.verify_jwt_signature(token, issuer_url) is None
assert service.metrics.increment.calls == [
pretend.call(
"warehouse.oidc.verify_jwt_signature.invalid_signature",
Expand All @@ -168,16 +170,17 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
assert services.sentry_sdk.capture_message.calls == []

def test_find_publisher(self, metrics, monkeypatch):
issuer_url = "https://none"
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="fakepublisher",
issuer_url="https://none",
issuer_url=issuer_url,
audience="fakeaudience",
cache_url=pretend.stub(),
metrics=metrics,
)

token = SignedClaims({})
token = SignedClaims({"iss": issuer_url})

publisher = pretend.stub(verify_claims=pretend.call_recorder(lambda c, s: True))
find_publisher_by_issuer = pretend.call_recorder(lambda *a, **kw: publisher)
Expand All @@ -198,10 +201,11 @@ def test_find_publisher(self, metrics, monkeypatch):
]

def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
issuer_url = "https://none"
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="fakepublisher",
issuer_url="https://none",
issuer_url=issuer_url,
audience="fakeaudience",
cache_url=pretend.stub(),
metrics=metrics,
Expand All @@ -212,7 +216,7 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
services, "find_publisher_by_issuer", find_publisher_by_issuer
)

claims = pretend.stub()
claims = SignedClaims({"iss": issuer_url})
with pytest.raises(errors.InvalidPublisherError):
service.find_publisher(claims)
assert service.metrics.increment.calls == [
Expand All @@ -227,10 +231,11 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
]

def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch):
issuer_url = "https://none"
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="fakepublisher",
issuer_url="https://none",
issuer_url=issuer_url,
audience="fakeaudience",
cache_url=pretend.stub(),
metrics=metrics,
Expand All @@ -246,7 +251,7 @@ def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch):
services, "find_publisher_by_issuer", find_publisher_by_issuer
)

claims = SignedClaims({})
claims = SignedClaims({"iss": issuer_url})
with pytest.raises(errors.InvalidPublisherError):
service.find_publisher(claims)
assert service.metrics.increment.calls == [
Expand Down Expand Up @@ -311,7 +316,7 @@ def test_get_keyset_not_cached(self, monkeypatch, mockredis):

monkeypatch.setattr(services.redis, "StrictRedis", mockredis)

keys, timeout = service._get_keyset()
keys, timeout = service._get_keyset("https://example.com")

assert not keys
assert timeout is False
Expand All @@ -329,17 +334,18 @@ def test_get_keyset_cached(self, monkeypatch, mockredis):
monkeypatch.setattr(services.redis, "StrictRedis", mockredis)

keyset = {"fake-key-id": {"foo": "bar"}}
service._store_keyset(keyset)
keys, timeout = service._get_keyset()
service._store_keyset("https://example.com", keyset)
keys, timeout = service._get_keyset("https://example.com")

assert keys == keyset
assert timeout is True

def test_refresh_keyset_timeout(self, metrics, monkeypatch, mockredis):
issuer_url = "https://example.com"
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="example",
issuer_url="https://example.com",
issuer_url=issuer_url,
audience="fakeaudience",
cache_url="rediss://fake.example.com",
metrics=metrics,
Expand All @@ -348,9 +354,9 @@ def test_refresh_keyset_timeout(self, metrics, monkeypatch, mockredis):
monkeypatch.setattr(services.redis, "StrictRedis", mockredis)

keyset = {"fake-key-id": {"foo": "bar"}}
service._store_keyset(keyset)
service._store_keyset(issuer_url, keyset)

keys = service._refresh_keyset()
keys = service._refresh_keyset(issuer_url)
assert keys == keyset
assert metrics.increment.calls == [
pretend.call(
Expand Down Expand Up @@ -380,7 +386,7 @@ def test_refresh_keyset_oidc_config_fails(self, metrics, monkeypatch, mockredis)
monkeypatch.setattr(services, "requests", requests)
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)

keys = service._refresh_keyset()
keys = service._refresh_keyset("https://example.com")

assert keys == {}
assert metrics.increment.calls == []
Expand Down Expand Up @@ -421,7 +427,7 @@ def test_refresh_keyset_oidc_config_no_jwks_uri(
monkeypatch.setattr(services, "requests", requests)
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)

keys = service._refresh_keyset()
keys = service._refresh_keyset("https://example.com")

assert keys == {}
assert metrics.increment.calls == []
Expand Down Expand Up @@ -472,7 +478,7 @@ def get(url, timeout=5):
monkeypatch.setattr(services, "requests", requests)
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)

keys = service._refresh_keyset()
keys = service._refresh_keyset("https://example.com")

assert keys == {}
assert metrics.increment.calls == []
Expand Down Expand Up @@ -524,7 +530,7 @@ def get(url, timeout=5):
monkeypatch.setattr(services, "requests", requests)
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)

keys = service._refresh_keyset()
keys = service._refresh_keyset("https://example.com")

assert keys == {}
assert metrics.increment.calls == []
Expand Down Expand Up @@ -573,7 +579,7 @@ def get(url, timeout=5):
monkeypatch.setattr(services, "requests", requests)
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)

keys = service._refresh_keyset()
keys = service._refresh_keyset("https://example.com")

assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}}
assert metrics.increment.calls == []
Expand All @@ -586,7 +592,7 @@ def get(url, timeout=5):
assert sentry_sdk.capture_message.calls == []

# Ensure that we also cached the updated keyset as part of refreshing.
keys, timeout = service._get_keyset()
keys, timeout = service._get_keyset("https://example.com")
assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}}
assert timeout is True

Expand All @@ -612,9 +618,11 @@ def test_get_key_cached(self, metrics, monkeypatch):
"x5t": "dummy",
}
}
monkeypatch.setattr(service, "_get_keyset", lambda: (keyset, True))
monkeypatch.setattr(
service, "_get_keyset", lambda issuer_url=None: (keyset, True)
)

key = service._get_key("fake-key-id")
key = service._get_key("fake-key-id", "https://example.com")
assert isinstance(key, PyJWK)
assert key.key_id == "fake-key-id"

Expand Down Expand Up @@ -642,10 +650,10 @@ def test_get_key_uncached(self, metrics, monkeypatch):
"x5t": "dummy",
}
}
monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False))
monkeypatch.setattr(service, "_refresh_keyset", lambda: keyset)
monkeypatch.setattr(service, "_get_keyset", lambda issuer_url=None: ({}, False))
monkeypatch.setattr(service, "_refresh_keyset", lambda issuer_url=None: keyset)

key = service._get_key("fake-key-id")
key = service._get_key("fake-key-id", "https://example.com")
assert isinstance(key, PyJWK)
assert key.key_id == "fake-key-id"

Expand All @@ -661,11 +669,14 @@ def test_get_key_refresh_fails(self, metrics, monkeypatch):
metrics=metrics,
)

monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False))
monkeypatch.setattr(service, "_refresh_keyset", lambda: {})
monkeypatch.setattr(service, "_get_keyset", lambda issuer_url=None: ({}, False))
monkeypatch.setattr(service, "_refresh_keyset", lambda issuer_url=None: {})

key = service._get_key("fake-key-id")
assert key is None
with pytest.raises(
jwt.PyJWTError,
match=r"Key ID 'fake-key-id' not found for issuer 'https://example.com'",
):
service._get_key("fake-key-id", "https://example.com")

assert metrics.increment.calls == [
pretend.call(
Expand All @@ -690,16 +701,20 @@ def test_get_key_for_token(self, monkeypatch):
cache_url="rediss://fake.example.com",
metrics=pretend.stub(),
)
monkeypatch.setattr(service, "_get_key", pretend.call_recorder(lambda kid: key))
monkeypatch.setattr(
service, "_get_key", pretend.call_recorder(lambda kid, i: key)
)

monkeypatch.setattr(
services.jwt,
"get_unverified_header",
pretend.call_recorder(lambda token: {"kid": "fake-key-id"}),
)

assert service._get_key_for_token(token) == key
assert service._get_key.calls == [pretend.call("fake-key-id")]
assert service._get_key_for_token(token, "https://example.com") == key
assert service._get_key.calls == [
pretend.call("fake-key-id", "https://example.com")
]
assert services.jwt.get_unverified_header.calls == [pretend.call(token)]

def test_reify_publisher(self, monkeypatch):
Expand Down Expand Up @@ -802,7 +817,9 @@ def test_verify_jwt_signature_malformed_jwt(self):
metrics=pretend.stub(),
)

assert service.verify_jwt_signature("malformed-jwt") is None
assert (
service.verify_jwt_signature("malformed-jwt", "https://example.com") is None
)

def test_verify_jwt_signature_missing_aud(self):
# {
Expand Down Expand Up @@ -830,7 +847,7 @@ def test_verify_jwt_signature_missing_aud(self):
metrics=pretend.stub(),
)

assert service.verify_jwt_signature(jwt) is None
assert service.verify_jwt_signature(jwt, "https://example.com") is None

def test_verify_jwt_signature_wrong_aud(self):
# {
Expand Down Expand Up @@ -860,7 +877,7 @@ def test_verify_jwt_signature_wrong_aud(self):
metrics=pretend.stub(),
)

assert service.verify_jwt_signature(jwt) is None
assert service.verify_jwt_signature(jwt, "https://example.com") is None

def test_verify_jwt_signature_strict_aud(self):
# {
Expand All @@ -885,7 +902,7 @@ def test_verify_jwt_signature_strict_aud(self):
metrics=pretend.stub(),
)

assert service.verify_jwt_signature(jwt) is None
assert service.verify_jwt_signature(jwt, "https://example.com") is None

def test_find_publisher(self, monkeypatch):
claims = SignedClaims(
Expand Down
Loading