Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 56 additions & 60 deletions python/py_vapid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

class VapidException(Exception):
"""An exception wrapper for Vapid."""

pass


Expand All @@ -34,6 +35,7 @@ class Vapid01(object):
https://tools.ietf.org/html/draft-ietf-webpush-vapid-01

"""

_private_key = None
_public_key = None
_schema = "WebPush"
Expand Down Expand Up @@ -65,14 +67,14 @@ def from_raw(cls, private_raw):
key = ec.derive_private_key(
int(binascii.hexlify(b64urldecode(private_raw)), 16),
curve=ec.SECP256R1(),
backend=default_backend())
backend=default_backend(),
)
return cls(key)

@classmethod
def from_raw_public(cls, public_raw):
key = ec.EllipticCurvePublicKey.from_encoded_point(
curve=ec.SECP256R1(),
data=b64urldecode(public_raw)
curve=ec.SECP256R1(), data=b64urldecode(public_raw)
)
ss = cls()
ss._public_key = key
Expand All @@ -87,8 +89,7 @@ def from_pem(cls, private_key):

"""
# not sure why, but load_pem_private_key fails to deserialize
return cls.from_der(
b''.join(private_key.splitlines()[1:-1]))
return cls.from_der(b"".join(private_key.splitlines()[1:-1]))

@classmethod
def from_der(cls, private_key):
Expand All @@ -98,9 +99,9 @@ def from_der(cls, private_key):
:type private_key: bytes

"""
key = serialization.load_der_private_key(b64urldecode(private_key),
password=None,
backend=default_backend())
key = serialization.load_der_private_key(
b64urldecode(private_key), password=None, backend=default_backend()
)
return cls(key)

@classmethod
Expand All @@ -118,13 +119,13 @@ def from_file(cls, private_key_file=None):
vapid.generate_keys()
vapid.save_key(private_key_file)
return vapid
with open(private_key_file, 'r') as file:
with open(private_key_file, "r") as file:
private_key = file.read()
try:
if "-----BEGIN" in private_key:
vapid = cls.from_pem(private_key.encode('utf8'))
vapid = cls.from_pem(private_key.encode("utf8"))
else:
vapid = cls.from_der(private_key.encode('utf8'))
vapid = cls.from_der(private_key.encode("utf8"))
return vapid
except Exception as exc:
logging.error("Could not open private key file: %s", repr(exc))
Expand Down Expand Up @@ -156,11 +157,10 @@ def verify(cls, key, auth):
type key: str

"""
tokens = auth.rsplit(' ', 1)[1].rsplit('.', 1)
tokens = auth.rsplit(" ", 1)[1].rsplit(".", 1)
kp = cls().from_raw_public(key.encode())
return kp.verify_token(
validation_token=tokens[0].encode(),
verification_token=tokens[1]
validation_token=tokens[0].encode(), verification_token=tokens[1]
)

@property
Expand Down Expand Up @@ -197,20 +197,19 @@ def public_key(self):

def generate_keys(self):
"""Generate a valid ECDSA Key Pair."""
self.private_key = ec.generate_private_key(ec.SECP256R1,
default_backend())
self.private_key = ec.generate_private_key(curve=ec.SECP256R1(), backend=default_backend())

def private_pem(self):
return self.private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
encryption_algorithm=serialization.NoEncryption(),
)

def public_pem(self):
return self.public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

def save_key(self, key_file):
Expand Down Expand Up @@ -245,38 +244,40 @@ def verify_token(self, validation_token, verification_token):
:rtype: boolean

"""
hsig = b64urldecode(verification_token.encode('utf8'))
hsig = b64urldecode(verification_token.encode("utf8"))
r = int(binascii.hexlify(hsig[:32]), 16)
s = int(binascii.hexlify(hsig[32:]), 16)
try:
self.public_key.verify(
ecutils.encode_dss_signature(r, s),
validation_token,
signature_algorithm=ec.ECDSA(hashes.SHA256())
signature_algorithm=ec.ECDSA(hashes.SHA256()),
)
return True
except InvalidSignature:
return False

def _base_sign(self, claims):
cclaims = copy.deepcopy(claims)
if not cclaims.get('exp'):
cclaims['exp'] = int(time.time()) + 86400
if not self.conf.get('no-strict', False):
valid = _check_sub(cclaims.get('sub', ''))
if not cclaims.get("exp"):
cclaims["exp"] = int(time.time()) + 86400
if not self.conf.get("no-strict", False):
valid = _check_sub(cclaims.get("sub", ""))
else:
valid = cclaims.get('sub') is not None
valid = cclaims.get("sub") is not None
if not valid:
raise VapidException(
"Missing 'sub' from claims. "
"'sub' is your admin email as a mailto: link.")
if not re.match(r"^https?://[^/:]+(:\d+)?$",
cclaims.get("aud", ""),
re.IGNORECASE):
"'sub' is your admin email as a mailto: link."
)
if not re.match(
r"^https?://[^/:]+(:\d+)?$", cclaims.get("aud", ""), re.IGNORECASE
):
raise VapidException(
"Missing 'aud' from claims. "
"'aud' is the scheme, host and optional port for this "
"transaction e.g. https://example.com:8080")
"transaction e.g. https://example.com:8080"
)
return cclaims

def sign(self, claims, crypto_key=None):
Expand All @@ -292,19 +293,22 @@ def sign(self, claims, crypto_key=None):

"""
sig = sign(self._base_sign(claims), self.private_key)
pkey = 'p256ecdsa='
pkey = "p256ecdsa="
pkey += b64urlencode(
self.public_key.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint
))
serialization.PublicFormat.UncompressedPoint,
)
)
if crypto_key:
crypto_key = crypto_key + ';' + pkey
crypto_key = crypto_key + ";" + pkey
else:
crypto_key = pkey

return {"Authorization": "{} {}".format(self._schema, sig.strip('=')),
"Crypto-Key": crypto_key}
return {
"Authorization": "{} {}".format(self._schema, sig.strip("=")),
"Crypto-Key": crypto_key,
}


class Vapid02(Vapid01):
Expand All @@ -313,6 +317,7 @@ class Vapid02(Vapid01):
https://tools.ietf.org/html/rfc8292

"""

_schema = "vapid"

def sign(self, claims, crypto_key=None):
Expand All @@ -329,14 +334,11 @@ def sign(self, claims, crypto_key=None):
"""
sig = sign(self._base_sign(claims), self.private_key)
pkey = self.public_key.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint
)
return{
serialization.Encoding.X962, serialization.PublicFormat.UncompressedPoint
)
return {
"Authorization": "{schema} t={t},k={k}".format(
schema=self._schema,
t=sig,
k=b64urlencode(pkey)
schema=self._schema, t=sig, k=b64urlencode(pkey)
)
}

Expand All @@ -349,27 +351,23 @@ def verify(cls, auth):
:rtype: bool

"""
pref_tok = auth.rsplit(' ', 1)
assert pref_tok[0].lower() == cls._schema, (
"Incorrect schema specified")
pref_tok = auth.rsplit(" ", 1)
assert pref_tok[0].lower() == cls._schema, "Incorrect schema specified"
parts = {}
for tok in pref_tok[1].split(','):
kv = tok.split('=', 1)
for tok in pref_tok[1].split(","):
kv = tok.split("=", 1)
parts[kv[0]] = kv[1]
assert 'k' in parts.keys(), (
"Auth missing public key 'k' value")
assert 't' in parts.keys(), (
"Auth missing token set 't' value")
kp = cls().from_raw_public(parts['k'].encode())
tokens = parts['t'].rsplit('.', 1)
assert "k" in parts.keys(), "Auth missing public key 'k' value"
assert "t" in parts.keys(), "Auth missing token set 't' value"
kp = cls().from_raw_public(parts["k"].encode())
tokens = parts["t"].rsplit(".", 1)
return kp.verify_token(
validation_token=tokens[0].encode(),
verification_token=tokens[1]
validation_token=tokens[0].encode(), verification_token=tokens[1]
)


def _check_sub(sub):
""" Check to see if the `sub` is a properly formatted `mailto:`
"""Check to see if the `sub` is a properly formatted `mailto:`

a `mailto:` should be a SMTP mail address. Mind you, since I run
YouFailAtEmail.com, you have every right to yell about how terrible
Expand All @@ -382,9 +380,7 @@ def _check_sub(sub):
:rtype: bool

"""
pattern = (
r"^(mailto:.+@((localhost|[%\w-]+(\.[%\w-]+)+|([0-9a-f]{1,4}):+([0-9a-f]{1,4})?)))|https:\/\/(localhost|[\w-]+\.[\w\.-]+|([0-9a-f]{1,4}:+)+([0-9a-f]{1,4})?)$" # noqa
)
pattern = r"^(mailto:.+@((localhost|[%\w-]+(\.[%\w-]+)+|([0-9a-f]{1,4}):+([0-9a-f]{1,4})?)))|https:\/\/(localhost|[\w-]+\.[\w\.-]+|([0-9a-f]{1,4}:+)+([0-9a-f]{1,4})?)$" # noqa
return re.match(pattern, sub, re.IGNORECASE) is not None


Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "py-vapid"
version = "1.9.2"
version = "1.9.3"
license = {text = "MPL-2.0"}
description = "Simple VAPID header generation library"
readme = "README.rst"
Expand Down