Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify ChaCha20 Poly1305 implementation #2338

Merged
merged 6 commits into from
Jun 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions pyatv/support/chacha20.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""Transparent encryption layer using Chacha20_Poly1305."""
from functools import partial
from struct import Struct
from typing import Optional

from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable as ChaCha20Poly1305

NONCE_LENGTH = 12

# The first 4 bytes are always 0, followed by 8 bytes of counter
# for a total of 12 bytes.
PACK_NONCE = partial(Struct("<LQ").pack, 0)


class Chacha20Cipher:
"""CHACHA20 encryption/decryption layer."""
Expand All @@ -24,7 +30,7 @@ def out_nonce(self) -> bytes:
This is the nonce that will be used by encrypt in the _next_ call if no custom
nonce is specified.
"""
return self._out_counter.to_bytes(length=self._nonce_length, byteorder="little")
return PACK_NONCE(self._out_counter)

@property
def in_nonce(self) -> bytes:
Expand All @@ -33,7 +39,7 @@ def in_nonce(self) -> bytes:
This is the nonce that will be used by decrypt in the _next_ call if no custom
nonce is specified.
"""
return self._in_counter.to_bytes(length=self._nonce_length, byteorder="little")
return PACK_NONCE(self._in_counter)

def encrypt(
self, data: bytes, nonce: Optional[bytes] = None, aad: Optional[bytes] = None
Expand All @@ -42,10 +48,8 @@ def encrypt(
if nonce is None:
nonce = self.out_nonce
self._out_counter += 1

if len(nonce) < NONCE_LENGTH:
elif len(nonce) < NONCE_LENGTH:
nonce = b"\x00" * (NONCE_LENGTH - len(nonce)) + nonce

return self._enc_out.encrypt(nonce, data, aad)

def decrypt(
Expand All @@ -55,10 +59,6 @@ def decrypt(
if nonce is None:
nonce = self.in_nonce
self._in_counter += 1

if len(nonce) < NONCE_LENGTH:
elif len(nonce) < NONCE_LENGTH:
nonce = b"\x00" * (NONCE_LENGTH - len(nonce)) + nonce

decrypted = self._enc_in.decrypt(nonce, data, aad)

return bytes(decrypted)
return self._enc_in.decrypt(nonce, data, aad)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
Loading