Skip to content

Commit

Permalink
Annotations in tests, part 1 (#725)
Browse files Browse the repository at this point in the history
* Annotations for tests.utils

* Annotations for tests.test_aead

* Annotations for tests.test_signing

* Annotations for tests.test_bindings

* Annotations for tests.test_box

* Move check_type_error to tests.utils

* Start running mypy on tests
  • Loading branch information
DMRobertson committed Dec 23, 2021
1 parent b9a14e1 commit 6f732e6
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 96 deletions.
32 changes: 30 additions & 2 deletions pyproject.toml
Expand Up @@ -39,7 +39,15 @@ warn_unreachable = true
no_implicit_reexport = true
strict_equality = true

files = ["src/nacl"]
# Include test files manually for now. Can tidy up later.
files = [
"src/nacl",
"tests/test_aead.py",
"tests/test_bindings.py",
"tests/test_box.py",
"tests/test_signing.py",
"tests/utils.py",
]

[[tool.mypy.overrides]]
module = [
Expand All @@ -51,7 +59,7 @@ ignore_missing_imports = true
# nacl._sodium return `Any` as far as mypy is concerned. It's not worth it to
# stub the C functions or cast() their uses. But this means there are more
# `Any`s floating around. So the more restrictive any checks we'd like to use
# should only be turned out outside of `bindings`.
# should only be turned on outside of `bindings`.

[[tool.mypy.overrides]]
module = [
Expand All @@ -60,3 +68,23 @@ module = [
disallow_any_expr = false
warn_return_any = false

# Loosen some of the checks within the tests.
# For now this is an explicit list rather than a wildcard "test.*", to make
# it a little easier to run the strict checks on modules first. We can clean
# this up later. Note that we _do_ run the strict checks on `test.utils`.

[[tool.mypy.overrides]]
module = [
"tests.test_aead",
"tests.test_bindings",
"tests.test_box",
"tests.test_signing",
]
# Some library helpers types' involve `Any`, in particular `pytest.mark.parameterize`
# and `hypothesis.strategies.sampledfrom`.
disallow_any_expr = false
disallow_any_decorated = false

# It's not useful to annotate each test function as `-> None`.
disallow_untyped_defs = false
disallow_incomplete_defs = false
4 changes: 2 additions & 2 deletions src/nacl/bindings/crypto_scalarmult.py
Expand Up @@ -20,8 +20,8 @@

has_crypto_scalarmult_ed25519 = bool(lib.PYNACL_HAS_CRYPTO_SCALARMULT_ED25519)

crypto_scalarmult_BYTES = lib.crypto_scalarmult_bytes()
crypto_scalarmult_SCALARBYTES = lib.crypto_scalarmult_scalarbytes()
crypto_scalarmult_BYTES: int = lib.crypto_scalarmult_bytes()
crypto_scalarmult_SCALARBYTES: int = lib.crypto_scalarmult_scalarbytes()

crypto_scalarmult_ed25519_BYTES = 0
crypto_scalarmult_ed25519_SCALARBYTES = 0
Expand Down
37 changes: 22 additions & 15 deletions tests/test_aead.py
Expand Up @@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import binascii
from collections import namedtuple
from typing import Callable, Dict, List, NamedTuple, Optional

from hypothesis import given, settings
from hypothesis.strategies import binary, sampled_from
Expand All @@ -27,28 +25,32 @@
from .utils import read_kv_test_vectors


def chacha20poly1305_agl_vectors():
def chacha20poly1305_agl_vectors() -> List[Dict[str, bytes]]:
# NIST vectors derived format
DATA = "chacha20-poly1305-agl_ref.txt"
return read_kv_test_vectors(DATA, delimiter=b":", newrecord=b"AEAD")


def chacha20poly1305_ietf_vectors():
def chacha20poly1305_ietf_vectors() -> List[Dict[str, bytes]]:
# NIST vectors derived format
DATA = "chacha20-poly1305-ietf_ref.txt"
return read_kv_test_vectors(DATA, delimiter=b":", newrecord=b"AEAD")


def xchacha20poly1305_ietf_vectors():
def xchacha20poly1305_ietf_vectors() -> List[Dict[str, bytes]]:
# NIST vectors derived format
DATA = "xchacha20-poly1305-ietf_ref.txt"
return read_kv_test_vectors(DATA, delimiter=b":", newrecord=b"AEAD")


Construction = namedtuple("Construction", "encrypt, decrypt, NPUB, KEYBYTES")
class Construction(NamedTuple):
encrypt: Callable[[bytes, Optional[bytes], bytes, bytes], bytes]
decrypt: Callable[[bytes, Optional[bytes], bytes, bytes], bytes]
NPUB: int
KEYBYTES: int


def _getconstruction(construction):
def _getconstruction(construction: bytes) -> Construction:
if construction == b"chacha20-poly1305-old":
encrypt = b.crypto_aead_chacha20poly1305_encrypt
decrypt = b.crypto_aead_chacha20poly1305_decrypt
Expand All @@ -74,7 +76,7 @@ def _getconstruction(construction):
+ chacha20poly1305_ietf_vectors()
+ xchacha20poly1305_ietf_vectors(),
)
def test_chacha20poly1305_variants_kat(kv):
def test_chacha20poly1305_variants_kat(kv: Dict[str, bytes]):
msg = binascii.unhexlify(kv["IN"])
ad = binascii.unhexlify(kv["AD"])
nonce = binascii.unhexlify(kv["NONCE"])
Expand Down Expand Up @@ -103,7 +105,7 @@ def test_chacha20poly1305_variants_kat(kv):
)
@settings(deadline=None, max_examples=20)
def test_chacha20poly1305_variants_roundtrip(
construction, message, aad, nonce, key
construction: bytes, message: bytes, aad: bytes, nonce: bytes, key: bytes
):

c = _getconstruction(construction)
Expand All @@ -123,30 +125,35 @@ def test_chacha20poly1305_variants_roundtrip(
"construction",
[b"chacha20-poly1305-old", b"chacha20-poly1305", b"xchacha20-poly1305"],
)
def test_chacha20poly1305_variants_wrong_params(construction):
def test_chacha20poly1305_variants_wrong_params(construction: bytes):
c = _getconstruction(construction)
nonce = b"\x00" * c.NPUB
key = b"\x00" * c.KEYBYTES
aad = None
c.encrypt(b"", aad, nonce, key)
# The first two checks call encrypt with a nonce/key that's too short. Otherwise,
# the types are fine. (TODO: Should this raise ValueError rather than TypeError?
# Doing so would be a breaking change.)
with pytest.raises(exc.TypeError):
c.encrypt(b"", aad, nonce[:-1], key)
with pytest.raises(exc.TypeError):
c.encrypt(b"", aad, nonce, key[:-1])
# Type safety: mypy spots these next two errors, but we want to check that they're
# spotted at runtime too.
with pytest.raises(exc.TypeError):
c.encrypt(b"", aad, nonce.decode("utf-8"), key)
c.encrypt(b"", aad, nonce.decode("utf-8"), key) # type: ignore[arg-type]
with pytest.raises(exc.TypeError):
c.encrypt(b"", aad, nonce, key.decode("utf-8"))
c.encrypt(b"", aad, nonce, key.decode("utf-8")) # type: ignore[arg-type]


@pytest.mark.parametrize(
"construction",
[b"chacha20-poly1305-old", b"chacha20-poly1305", b"xchacha20-poly1305"],
)
def test_chacha20poly1305_variants_str_msg(construction):
def test_chacha20poly1305_variants_str_msg(construction: bytes):
c = _getconstruction(construction)
nonce = b"\x00" * c.NPUB
key = b"\x00" * c.KEYBYTES
aad = None
with pytest.raises(exc.TypeError):
c.encrypt("", aad, nonce, key)
c.encrypt("", aad, nonce, key) # type: ignore[arg-type]
36 changes: 21 additions & 15 deletions tests/test_bindings.py
Expand Up @@ -15,6 +15,7 @@

import hashlib
from binascii import hexlify, unhexlify
from typing import List, Tuple

from hypothesis import given, settings
from hypothesis.strategies import binary, integers
Expand All @@ -28,7 +29,7 @@
from .utils import flip_byte, read_crypto_test_vectors


def tohex(b):
def tohex(b: bytes) -> str:
return hexlify(b).decode("ascii")


Expand Down Expand Up @@ -119,20 +120,22 @@ def test_box():
def test_box_wrong_lengths():
A_pubkey, A_secretkey = c.crypto_box_keypair()
with pytest.raises(ValueError):
c.crypto_box(b"abc", "\x00", A_pubkey, A_secretkey)
c.crypto_box(b"abc", b"\x00", A_pubkey, A_secretkey)
with pytest.raises(ValueError):
c.crypto_box(
b"abc", "\x00" * c.crypto_box_NONCEBYTES, b"", A_secretkey
b"abc", b"\x00" * c.crypto_box_NONCEBYTES, b"", A_secretkey
)
with pytest.raises(ValueError):
c.crypto_box(b"abc", "\x00" * c.crypto_box_NONCEBYTES, A_pubkey, b"")
c.crypto_box(b"abc", b"\x00" * c.crypto_box_NONCEBYTES, A_pubkey, b"")

with pytest.raises(ValueError):
c.crypto_box_open(b"", b"", b"", b"")
with pytest.raises(ValueError):
c.crypto_box_open(b"", "\x00" * c.crypto_box_NONCEBYTES, b"", b"")
c.crypto_box_open(b"", b"\x00" * c.crypto_box_NONCEBYTES, b"", b"")
with pytest.raises(ValueError):
c.crypto_box_open(b"", "\x00" * c.crypto_box_NONCEBYTES, A_pubkey, b"")
c.crypto_box_open(
b"", b"\x00" * c.crypto_box_NONCEBYTES, A_pubkey, b""
)

with pytest.raises(ValueError):
c.crypto_box_beforenm(b"", b"")
Expand Down Expand Up @@ -173,7 +176,7 @@ def test_sign_wrong_lengths():
c.crypto_sign_seed_keypair(b"")


def secret_scalar():
def secret_scalar() -> Tuple[bytes, bytes]:
pubkey, secretkey = c.crypto_box_keypair()
assert len(secretkey) == c.crypto_box_SECRETKEYBYTES
assert c.crypto_box_SECRETKEYBYTES == c.crypto_scalarmult_BYTES
Expand Down Expand Up @@ -272,17 +275,18 @@ def test_box_seal_wrong_lengths():

def test_box_seal_wrong_types():
A_pubkey, A_secretkey = c.crypto_box_keypair()
# type safety: mypy can spot these errors, but we want to spot them at runtime too.
with pytest.raises(TypeError):
c.crypto_box_seal(b"abc", dict())
c.crypto_box_seal(b"abc", dict()) # type: ignore[arg-type]
with pytest.raises(TypeError):
c.crypto_box_seal_open(b"abc", None, A_secretkey)
c.crypto_box_seal_open(b"abc", None, A_secretkey) # type: ignore[arg-type]
with pytest.raises(TypeError):
c.crypto_box_seal_open(b"abc", A_pubkey, None)
c.crypto_box_seal_open(b"abc", A_pubkey, None) # type: ignore[arg-type]
with pytest.raises(TypeError):
c.crypto_box_seal_open(None, A_pubkey, A_secretkey)
c.crypto_box_seal_open(None, A_pubkey, A_secretkey) # type: ignore[arg-type]


def _box_from_seed_vectors():
def _box_from_seed_vectors() -> List[Tuple[bytes, bytes, bytes]]:
# Fmt: <seed> <tab> <public_key> || <secret_key>
DATA = "box_from_seed.txt"
lines = read_crypto_test_vectors(DATA, maxels=2, delimiter=b"\t")
Expand All @@ -299,7 +303,9 @@ def _box_from_seed_vectors():
@pytest.mark.parametrize(
("seed", "public_key", "secret_key"), _box_from_seed_vectors()
)
def test_box_seed_keypair_reference(seed, public_key, secret_key):
def test_box_seed_keypair_reference(
seed: bytes, public_key: bytes, secret_key: bytes
):
seed = unhexlify(seed)
pk, sk = c.crypto_box_seed_keypair(seed)
assert pk == unhexlify(public_key)
Expand Down Expand Up @@ -336,7 +342,7 @@ def test_unpad_not_padded():
binary(min_size=0, max_size=2049), integers(min_value=16, max_value=256)
)
@settings(max_examples=20)
def test_pad_sizes(msg, bl_sz):
def test_pad_sizes(msg: bytes, bl_sz: int):
padded = c.sodium_pad(msg, bl_sz)
assert len(padded) > len(msg)
assert len(padded) >= bl_sz
Expand All @@ -347,7 +353,7 @@ def test_pad_sizes(msg, bl_sz):
binary(min_size=0, max_size=2049), integers(min_value=16, max_value=256)
)
@settings(max_examples=20)
def test_pad_roundtrip(msg, bl_sz):
def test_pad_roundtrip(msg: bytes, bl_sz: int):
padded = c.sodium_pad(msg, bl_sz)
assert len(padded) > len(msg)
assert len(padded) >= bl_sz
Expand Down

0 comments on commit 6f732e6

Please sign in to comment.