Skip to content

Commit

Permalink
Expose get_algorithm_by_name as new method
Browse files Browse the repository at this point in the history
Looking up an algorithm by name is used internally for signature
generation. This encapsulates that functionality in a dedicated method
and adds it to the public API. No new tests are needed to exercise the
functionality.

Rationale:

1. Inside of PyJWS, this improves the code. The KeyError handler is
   better scoped and the signing code reads more directly.

2. This is part of the path to supporting OIDC at_hash validation as a
   use-case (see: jpadilla#295, jpadilla#296, jpadilla#314).

This is arguably sufficient to consider that use-case supported and
close it. However, it is an improvement and step in the right
direction in either case.

A minor change was needed to satisfy mypy, as a union-typed variable
does not narrow its type based on assignments. The easiest resolution
is to use a new name, in this case, simply `algorithm -> algorithm_`.
  • Loading branch information
sirosen committed Jun 29, 2022
1 parent a863a73 commit 4058ae0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -19,6 +19,8 @@ Fixed
Added
~~~~~
- Add to_jwk static method to ECAlgorithm by @leonsmith in https://github.com/jpadilla/pyjwt/pull/732
- Add ``get_algorithm_by_name`` as a method of ``PyJWS`` objects, and expose
the global PyJWS method as part of the public API

`v2.4.0 <https://github.com/jpadilla/pyjwt/compare/2.3.0...2.4.0>`__
-----------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions jwt/__init__.py
@@ -1,6 +1,7 @@
from .api_jwk import PyJWK, PyJWKSet
from .api_jws import (
PyJWS,
get_algorithm_by_name,
get_unverified_header,
register_algorithm,
unregister_algorithm,
Expand Down Expand Up @@ -51,6 +52,7 @@
"get_unverified_header",
"register_algorithm",
"unregister_algorithm",
"get_algorithm_by_name",
# Exceptions
"DecodeError",
"ExpiredSignatureError",
Expand Down
40 changes: 25 additions & 15 deletions jwt/api_jws.py
Expand Up @@ -73,6 +73,23 @@ def get_algorithms(self):
"""
return list(self._valid_algs)

def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
"""
For a given string name, return the matching Algorithm object.
Example usage:
>>> jws_obj.get_algorithm_by_name("RS256")
"""
try:
return self._algorithms[alg_name]
except KeyError as e:
if not has_crypto and alg_name in requires_cryptography:
raise NotImplementedError(
f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
) from e
raise NotImplementedError("Algorithm not supported") from e

def encode(
self,
payload: bytes,
Expand All @@ -84,21 +101,21 @@ def encode(
) -> str:
segments = []

if algorithm is None:
algorithm = "none"
# declare a new var to narrow the type for type checkers
algorithm_: str = algorithm if algorithm is not None else "none"

# Prefer headers values if present to function parameters.
if headers:
headers_alg = headers.get("alg")
if headers_alg:
algorithm = headers["alg"]
algorithm_ = headers["alg"]

headers_b64 = headers.get("b64")
if headers_b64 is False:
is_payload_detached = True

# Header
header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any]
header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]

if headers:
self._validate_headers(headers)
Expand Down Expand Up @@ -128,17 +145,9 @@ def encode(
# Segments
signing_input = b".".join(segments)

try:
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

except KeyError as e:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?"
) from e
raise NotImplementedError("Algorithm not supported") from e
alg_obj = self.get_algorithm_by_name(algorithm_)
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

segments.append(base64url_encode(signature))

Expand Down Expand Up @@ -286,4 +295,5 @@ def _validate_kid(self, kid):
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header

0 comments on commit 4058ae0

Please sign in to comment.